Initial upload of Calcifications model
Browse files- Calc.png +0 -0
- README.md +36 -0
- critic_size_512_599.pth +3 -0
- generator_size_512_599.pth +3 -0
- progan_model.py +207 -0
- requirements.txt +10 -0
Calc.png
ADDED
|
README.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- gan
|
| 5 |
+
- progan
|
| 6 |
+
- generative-ai
|
| 7 |
+
- medical-imaging
|
| 8 |
+
- pytorch
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# ProGAN-Mammography-Calcifications
|
| 12 |
+
|
| 13 |
+
## 🖼️ Model Description
|
| 14 |
+
This model is an implementation of **Progressive Growing of GANs (ProGAN)**, meticulously trained to generate medical images of mammograms with the presence of calcifications, mainly microcalcifications.. Its objective is to synthesize realistic images for data augmentation, research, or studying complex patterns in mammograms.
|
| 15 |
+
|
| 16 |
+
> This model is part of a broader research effort on the application of GANs in medical mammography imaging.
|
| 17 |
+
|
| 18 |
+
## ⚙️ Architecture Details
|
| 19 |
+
* **GAN Type:** Progressive Growing of GANs (ProGAN)
|
| 20 |
+
* **Generator:** The generator's architecture is defined in `progan_model.py`. This file includes the `Generator` class necessary to instantiate the model.
|
| 21 |
+
* **Generator Weights:** The main generator weights are found in the file `generator_size_512_599.pth`. This is the checkpoint with the highest resolution and training epoch achieved.
|
| 22 |
+
* **Critic/Discriminator Weights:** (Optional) The critic/discriminator weights are found in the file `critic_size_512_599.pth`.
|
| 23 |
+
|
| 24 |
+
## 📊 Training Dataset
|
| 25 |
+
The model was trained using the following dataset:
|
| 26 |
+
> This model was trained exclusively on a subset of the 'VinDr-Mammogram' dataset, consisting of mammograms showcasing **confirmed calcifications**. The VinDr-Mammogram dataset was meticulously curated and labeled by experienced radiologists, and its labeling scheme is unique. This model was developed as part of a Bachelor's Final Project (TFG) at the University of Extremadura (UEX).
|
| 27 |
+
|
| 28 |
+
It is recommended to review the original dataset documentation for more details on its composition and characteristics.
|
| 29 |
+
|
| 30 |
+
## 🚀 How to Use This Model
|
| 31 |
+
|
| 32 |
+
### Requirements
|
| 33 |
+
Make sure you have the following Python libraries installed:
|
| 34 |
+
```bash
|
| 35 |
+
pip install torch
|
| 36 |
+
pip install huggingface_hub
|
critic_size_512_599.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f0183bc36d13e70bf3bc178f589cde43babc51c492f0ad3297168f32a298489b
|
| 3 |
+
size 305232851
|
generator_size_512_599.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd7c5df518472061333e940b51ee033547d782e76f8f073b63005ade28c43eed
|
| 3 |
+
size 276898727
|
progan_model.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Implementation of ProGAN generator and discriminator with the key
|
| 3 |
+
attributions from the paper. We have tried to make the implementation
|
| 4 |
+
compact but a goal is also to keep it readable and understandable.
|
| 5 |
+
Specifically the key points implemented are:
|
| 6 |
+
|
| 7 |
+
1) Progressive growing (of model and layers)
|
| 8 |
+
2) Minibatch std on Discriminator
|
| 9 |
+
3) Normalization with PixelNorm
|
| 10 |
+
4) Equalized Learning Rate (here I cheated and only did it on Conv layers)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from math import log2
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Factors is used in Discrmininator and Generator for how much
|
| 20 |
+
the channels should be multiplied and expanded for each layer,
|
| 21 |
+
so specifically the first 5 layers the channels stay the same,
|
| 22 |
+
whereas when we increase the img_size (towards the later layers)
|
| 23 |
+
we decrease the number of chanels by 1/2, 1/4, etc.
|
| 24 |
+
"""
|
| 25 |
+
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class WSConv2d(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
Weight scaled Conv2d (Equalized Learning Rate)
|
| 31 |
+
Note that input is multiplied rather than changing weights
|
| 32 |
+
this will have the same result.
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
|
| 38 |
+
):
|
| 39 |
+
super(WSConv2d, self).__init__()
|
| 40 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
| 41 |
+
self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
|
| 42 |
+
self.bias = self.conv.bias
|
| 43 |
+
self.conv.bias = None
|
| 44 |
+
|
| 45 |
+
nn.init.normal_(self.conv.weight)
|
| 46 |
+
nn.init.zeros_(self.bias)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
return self.conv(x) * self.scale + self.bias.view(1, self.bias.shape[0], 1, 1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class PixelNorm(nn.Module):
|
| 53 |
+
def __init__(self):
|
| 54 |
+
super(PixelNorm, self).__init__()
|
| 55 |
+
self.epsilon = 1e-8
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ConvBlock(nn.Module):
|
| 62 |
+
def __init__(self, in_channels, out_channels, use_pixelnorm=True):
|
| 63 |
+
super(ConvBlock, self).__init__()
|
| 64 |
+
self.use_pn = use_pixelnorm
|
| 65 |
+
self.conv1 = WSConv2d(in_channels, out_channels)
|
| 66 |
+
self.conv2 = WSConv2d(out_channels, out_channels)
|
| 67 |
+
self.leaky = nn.LeakyReLU(0.2)
|
| 68 |
+
self.pn = PixelNorm()
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
x = self.leaky(self.conv1(x))
|
| 72 |
+
x = self.pn(x) if self.use_pn else x
|
| 73 |
+
x = self.leaky(self.conv2(x))
|
| 74 |
+
x = self.pn(x) if self.use_pn else x
|
| 75 |
+
return x
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Generator(nn.Module):
|
| 79 |
+
def __init__(self, z_dim, in_channels, img_channels=1):
|
| 80 |
+
super(Generator, self).__init__()
|
| 81 |
+
|
| 82 |
+
# initial takes 1x1 -> 4x4
|
| 83 |
+
self.initial = nn.Sequential(
|
| 84 |
+
PixelNorm(),
|
| 85 |
+
nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
|
| 86 |
+
nn.LeakyReLU(0.2),
|
| 87 |
+
WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
|
| 88 |
+
nn.LeakyReLU(0.2),
|
| 89 |
+
PixelNorm(),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.initial_rgb = WSConv2d(
|
| 93 |
+
in_channels, img_channels, kernel_size=1, stride=1, padding=0
|
| 94 |
+
)
|
| 95 |
+
self.prog_blocks, self.rgb_layers = (
|
| 96 |
+
nn.ModuleList([]),
|
| 97 |
+
nn.ModuleList([self.initial_rgb]),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
for i in range(
|
| 101 |
+
len(factors) - 1
|
| 102 |
+
): # -1 to prevent index error because of factors[i+1]
|
| 103 |
+
conv_in_c = int(in_channels * factors[i])
|
| 104 |
+
conv_out_c = int(in_channels * factors[i + 1])
|
| 105 |
+
self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
|
| 106 |
+
self.rgb_layers.append(
|
| 107 |
+
WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def fade_in(self, alpha, upscaled, generated):
|
| 111 |
+
# alpha should be scalar within [0, 1], and upscale.shape == generated.shape
|
| 112 |
+
return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
|
| 113 |
+
|
| 114 |
+
def forward(self, x, alpha, steps):
|
| 115 |
+
out = self.initial(x)
|
| 116 |
+
|
| 117 |
+
if steps == 0:
|
| 118 |
+
return self.initial_rgb(out)
|
| 119 |
+
|
| 120 |
+
for step in range(steps):
|
| 121 |
+
upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
|
| 122 |
+
out = self.prog_blocks[step](upscaled)
|
| 123 |
+
|
| 124 |
+
final_upscaled = self.rgb_layers[steps - 1](upscaled)
|
| 125 |
+
final_out = self.rgb_layers[steps](out)
|
| 126 |
+
return self.fade_in(alpha, final_upscaled, final_out)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Discriminator(nn.Module):
|
| 130 |
+
def __init__(self, z_dim, in_channels, img_channels=1):
|
| 131 |
+
super(Discriminator, self).__init__()
|
| 132 |
+
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
|
| 133 |
+
self.leaky = nn.LeakyReLU(0.2)
|
| 134 |
+
|
| 135 |
+
for i in range(len(factors) - 1, 0, -1):
|
| 136 |
+
conv_in = int(in_channels * factors[i])
|
| 137 |
+
conv_out = int(in_channels * factors[i - 1])
|
| 138 |
+
self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
|
| 139 |
+
self.rgb_layers.append(
|
| 140 |
+
WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size
|
| 144 |
+
# did this to "mirror" the generator initial_rgb
|
| 145 |
+
self.initial_rgb = WSConv2d(
|
| 146 |
+
img_channels, in_channels, kernel_size=1, stride=1, padding=0
|
| 147 |
+
)
|
| 148 |
+
self.rgb_layers.append(self.initial_rgb)
|
| 149 |
+
self.avg_pool = nn.AvgPool2d(
|
| 150 |
+
kernel_size=2, stride=2
|
| 151 |
+
) # down sampling using avg pool
|
| 152 |
+
|
| 153 |
+
# this is the block for 4x4 input size
|
| 154 |
+
self.final_block = nn.Sequential(
|
| 155 |
+
# +1 to in_channels because we concatenate from MiniBatch std
|
| 156 |
+
WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
|
| 157 |
+
nn.LeakyReLU(0.2),
|
| 158 |
+
WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
|
| 159 |
+
nn.LeakyReLU(0.2),
|
| 160 |
+
WSConv2d(
|
| 161 |
+
in_channels, 1, kernel_size=1, padding=0, stride=1
|
| 162 |
+
), # we use this instead of linear layer
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def fade_in(self, alpha, downscaled, out):
|
| 166 |
+
"""Used to fade in downscaled using avg pooling and output from CNN"""
|
| 167 |
+
# alpha should be scalar within [0, 1], and upscale.shape == generated.shape
|
| 168 |
+
return alpha * out + (1 - alpha) * downscaled
|
| 169 |
+
|
| 170 |
+
def minibatch_std(self, x):
|
| 171 |
+
batch_statistics = torch.std(x, dim=0, unbiased=False).mean()
|
| 172 |
+
batch_statistics = batch_statistics.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
|
| 173 |
+
# we take the std for each example (across all channels, and pixels) then we repeat it
|
| 174 |
+
# for a single channel and concatenate it with the image. In this way the discriminator
|
| 175 |
+
# will get information about the variation in the batch/image
|
| 176 |
+
return torch.cat([x, batch_statistics], dim=1)
|
| 177 |
+
|
| 178 |
+
def forward(self, x, alpha, steps):
|
| 179 |
+
# where we should start in the list of prog_blocks, maybe a bit confusing but
|
| 180 |
+
# the last is for the 4x4. So example let's say steps=1, then we should start
|
| 181 |
+
# at the second to last because input_size will be 8x8. If steps==0 we just
|
| 182 |
+
# use the final block
|
| 183 |
+
cur_step = len(self.prog_blocks) - steps
|
| 184 |
+
|
| 185 |
+
# convert from rgb as initial step, this will depend on
|
| 186 |
+
# the image size (each will have it's on rgb layer)
|
| 187 |
+
out = self.leaky(self.rgb_layers[cur_step](x))
|
| 188 |
+
|
| 189 |
+
if steps == 0: # i.e, image is 4x4
|
| 190 |
+
out = self.minibatch_std(out)
|
| 191 |
+
return self.final_block(out).view(out.shape[0], -1)
|
| 192 |
+
|
| 193 |
+
# because prog_blocks might change the channels, for down scale we use rgb_layer
|
| 194 |
+
# from previous/smaller size which in our case correlates to +1 in the indexing
|
| 195 |
+
downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
|
| 196 |
+
out = self.avg_pool(self.prog_blocks[cur_step](out))
|
| 197 |
+
|
| 198 |
+
# the fade_in is done first between the downscaled and the input
|
| 199 |
+
# this is opposite from the generator
|
| 200 |
+
out = self.fade_in(alpha, downscaled, out)
|
| 201 |
+
|
| 202 |
+
for step in range(cur_step + 1, len(self.prog_blocks)):
|
| 203 |
+
out = self.prog_blocks[step](out)
|
| 204 |
+
out = self.avg_pool(out)
|
| 205 |
+
|
| 206 |
+
out = self.minibatch_std(out)
|
| 207 |
+
return self.final_block(out).view(out.shape[0], -1)
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Pillow==10.4.0
|
| 2 |
+
customtkinter==5.2.2
|
| 3 |
+
matplotlib==3.9.2
|
| 4 |
+
numpy==2.1.3
|
| 5 |
+
opencv-python==4.11.0
|
| 6 |
+
scipy==1.14.1
|
| 7 |
+
torch==2.5.1+cu124
|
| 8 |
+
torchmetrics==1.7.1
|
| 9 |
+
torchvision==0.20.1+cu124
|
| 10 |
+
tqdm==4.67.0
|