pgatoula committed on
Commit
b620cf3
·
0 Parent(s):

Initial commit

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/github_tide.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="tf2" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="GOOGLE" />
10
+ <option name="myDocStringFormat" value="Google" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
5
+ <Languages>
6
+ <language minSize="66" name="Python" />
7
+ </Languages>
8
+ </inspection_tool>
9
+ </profile>
10
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.12" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="tf2" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/github_tide.iml" filepath="$PROJECT_DIR$/.idea/github_tide.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
__pycache__/compare_ckpts.cpython-310.pyc ADDED
Binary file (5.55 kB). View file
 
compare_ckpts.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import glob
3
+ import numpy as np
4
+ import pandas as pd
5
+ import importlib
6
+ import tensorflow as tf
7
+
8
+ from PIL import Image
9
+ from re import split, compile
10
+ from tensorflow.keras.applications.inception_v3 import preprocess_input
11
+
12
+ import fid_kid
13
+
14
+ #TODO : uncomment & import appropriate paths if msgastrovae_smc.py (TIDE model) & my_convnext.py (TIDE-2 model)
15
+ # are not in the same folder with this script (for importlib modules below)
16
+ # sys.path.append(f'{tide_path}')
17
+ # sys.path.append(f'{tide2_path}')
18
+
19
+
20
def list_saved_models(results_dir):
    """Collect all saved model files under *results_dir*, natural-sorted.

    Picks up TF checkpoint index files ('weights/vae_checkpoints/*.index')
    and Keras '.h5' files ('weights/*.h5'), ordered so that e.g.
    ckpt-2 comes before ckpt-10.
    """
    digit_runs = compile(r'(\d+)')

    def _natural_key(path):
        # Split path into digit / non-digit tokens; compare digits numerically.
        return [int(tok) if tok.isdigit() else tok.lower()
                for tok in split(digit_runs, path)]

    found = glob.glob(f"{results_dir}/weights/vae_checkpoints/*.index")  # checkpoints
    found += glob.glob(f"{results_dir}/weights/*.h5")  # models
    return sorted(found, key=_natural_key)
26
+
27
+
28
def get_real_kid_filenames(label):
    """Return the sorted list of real KID-dataset image paths for *label*.

    Labels map onto subdirectories of the shared KID dataset; an unknown
    label yields an empty list (a count of 0 is printed either way).
    """
    kid_dir = '/mnt/storage/shared/ckaitanidis/datasets/kid/kid-dataset-2/'
    # Label -> dataset subdirectory.
    subdirs = {
        'inflammatory': 'inflammatory',
        'polypoid': 'polypoid',
        'vascular': 'vascular',
        'normaleso': 'normal-esophagus',
        'normalstom': 'normal-stomach',
        'normalcolon': 'normal-colon',
        'normalsb': 'normal-small-bowel',
    }
    files = []
    if label in subdirs:
        # BUG FIX: the normal-* branches previously globbed the directory
        # itself (no '/*.png'), returning at most the directory path instead
        # of its images. All labels now glob for PNG files consistently.
        files = glob.glob('{}/{}/*.png'.format(kid_dir, subdirs[label]))
    print('Real images found: {}'.format(len(files)))
    return sorted(files)
44
+
45
+
46
def init_vae_model(model_name, latent_dim, input_shape):
    """Build an untrained VAE of the requested architecture.

    Args:
        model_name: 'tide' (msgastrovae_smc module) or 'tide2' (my_convnext).
        latent_dim: dimensionality of the latent space.
        input_shape: (H, W, C) image shape the encoder expects.

    Returns:
        A built (but untrained) VAE model.

    Raises:
        ValueError: for an unknown model_name (previously this fell through
            and silently returned None, crashing later at the call site).
    """
    if model_name == 'tide':
        vae = importlib.import_module("msgastrovae_smc")
        encoder = vae.create_encoder(latent_dim=latent_dim, input_shape=input_shape)
        decoder = vae.create_decoder(latent_dim=latent_dim)
    elif model_name == 'tide2':
        vae = importlib.import_module("my_convnext")
        encoder = vae.create_encoder_tiny(latent_dim=latent_dim, input_shape=input_shape)
        decoder = vae.create_decoder_tiny(latent_dim=latent_dim)
    else:
        raise ValueError(
            "Unknown model_name: {!r} (expected 'tide' or 'tide2')".format(model_name))
    vae_model = vae.VAE(encoder, decoder)
    # Build with a batched input shape so weights exist before restoring.
    vae_model.build(input_shape=[(None,) + input_shape])
    return vae_model
59
+
60
+
61
def load_weights(vae, weights_path):
    """Restore trained weights into *vae* from a TF checkpoint or a .h5 file.

    Args:
        vae: the built VAE model to load weights into.
        weights_path: either a 'ckpt-*.index' checkpoint index path or a
            Keras '.h5' weights file.

    Returns:
        The same vae instance with weights restored.

    Raises:
        ValueError: if the path is neither a checkpoint nor a .h5 file
            (previously this fell through and returned None).
    """
    print("Loading weights from {}".format(weights_path))
    if "ckpt-" in weights_path:
        # tf.train.Checkpoint expects the prefix, not the ".index" file.
        weights_path = weights_path.split(".index")[0]
        ckpt = tf.train.Checkpoint(vae=vae)
        # expect_partial(): the checkpoint may also contain optimizer slots
        # we deliberately do not restore for inference.
        ckpt.restore(weights_path).expect_partial()
        return vae
    if ".h5" in weights_path:
        vae.load_weights(weights_path, by_name=True)
        return vae
    raise ValueError(
        "Unsupported weights file: {!r} (expected 'ckpt-*.index' or '*.h5')".format(weights_path))
71
+
72
+
73
def debug_weights_loading(vae):
    """Print a quick fingerprint of the decoder's first weight tensor.

    Useful to eyeball whether a restore actually changed the weights:
    freshly initialised weights print the same sample every run.
    """
    first_kernel = vae.decoder.get_weights()[0]
    print("Decoder layer 0 weights shape:", first_kernel.shape)
    print("Decoder layer 0 weights sample:", first_kernel.flatten()[:5])
78
+
79
+
80
def get_noise_seeded(noise_shape):
    """Return a deterministic standard-normal noise array of *noise_shape*.

    Every call yields the identical array so that generated images are
    comparable across checkpoints.

    Uses a locally seeded RandomState instead of np.random.seed(0): the
    output stream is byte-identical to the original implementation, but the
    process-global NumPy RNG is no longer clobbered as a side effect.
    """
    rng = np.random.RandomState(0)  # same stream as np.random.seed(0) + np.random.normal
    return rng.normal(0, 1, noise_shape)
84
+
85
+
86
def decode_noise(trained_vae, noise, return_list=False):
    """Decode latent noise vectors into fake images scaled to [0, 255].

    The decoder output is multiplied by 255 because Inception's
    preprocess_input expects pixel values in [0, 255].

    Args:
        trained_vae: model exposing decoder.predict.
        noise: latent vectors, one per image.
        return_list: when True, return a list of per-image arrays instead
            of one stacked array.
    """
    print("Generating fake images ...")
    decoded = trained_vae.decoder.predict(noise, batch_size=1)
    scaled = decoded * 255.0
    return list(scaled) if return_list else scaled
95
+
96
+
97
def visualize_debug(image, name='output.png'):
    """Save a float image in [-1, 1] to disk as an 8-bit PNG for inspection.

    Maps -1 -> 0 and +1 -> 255, clamping anything outside that range.
    """
    as_uint8 = ((image + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(as_uint8).save(name)
100
+
101
+
102
def kid_dataset_center_crop(img, crop_size=(320, 320)):
    """Center-crop an image to *crop_size*.

    Args:
        img: array-like image, H x W (x C) — channel dimension optional
            (the original required exactly 3 dims; 2-D grayscale now works).
        crop_size: (crop_height, crop_width) window.

    Returns:
        The central crop_size region of img (a view, not a copy).

    Raises:
        ValueError: if the image is smaller than the crop window. The
            original silently produced a wrong slice via negative offsets.
    """
    if not isinstance(img, np.ndarray):
        img = np.array(img)

    h, w = img.shape[:2]
    ch, cw = crop_size
    if h < ch or w < cw:
        raise ValueError(
            "Image {}x{} is smaller than crop window {}x{}".format(h, w, ch, cw))

    top = (h - ch) // 2
    left = (w - cw) // 2

    return img[top:top + ch, left:left + cw]
113
+
114
+
115
+ if __name__ == "__main__":
116
+
117
+ # Change these two to run across all models
118
+ model_name = 'tide2' # 'tide' 'tide2'
119
+ label = "inflammatory" # 'inflammatory' 'vascular' 'polypoid' 'normaleso' 'normalcolon' 'normalsb' 'normalstom'
120
+
121
+ results_dir = "/mnt/storage/shared/ckaitanidis/"
122
+ if model_name == 'tide':
123
+ results_dir += 'kid_latent6_sr96_50000ep_{}'.format(label)
124
+ elif model_name == 'tide2':
125
+ results_dir += 'kid_latent8_tide2_sr96_50000ep_{}'.format(label)
126
+ print(results_dir)
127
+
128
+ # Params auto
129
+ latent_dim = 6 if model_name == 'tide' else 8
130
+ input_shape = (96, 96, 3) if model_name == 'tide' else (256, 256, 3)
131
+ crop_dim = (320, 320)
132
+
133
+ trained_weights = list_saved_models(results_dir)
134
+ real_filenames = get_real_kid_filenames(label)
135
+ real_images = fid_kid.get_images_inception(real_filenames, crop_dim=crop_dim) # this returns np.array, float32, [-1, 1], (batch, 299, 299, 3)
136
+ # print(type(real_images), real_images.shape, real_images.dtype, real_images.min(), real_images.max())
137
+ visualize_debug(real_images[0], name='output1.png')
138
+
139
+ vae = init_vae_model(model_name, latent_dim, input_shape)
140
+ noise_vector = get_noise_seeded((len(real_filenames), latent_dim))
141
+
142
+
143
+ results = []
144
+
145
+ # Ignore these - my weights for debug
146
+ # trained_weights = ['/mnt/storage/pgatoula-private/codes/tide-panagiota/results_kid/kid_inflammatory_latent6/weights/vae_checkpoints/ckpt-4500.index']
147
+ # trained_weights = ['/mnt/storage/pgatoula-private/general-results/convnext/kid_inflammatory_latent8_lbfcn_sr96/weights/vae_checkpoints/ckpt-1400.index',
148
+ # '/mnt/storage/pgatoula-private/general-results/convnext/kid_inflammatory_latent8_lbfcn_sr96/weights/vae_checkpoints/ckpt-1600.index']
149
+
150
+ for weights in trained_weights:
151
+ # Load weights
152
+ vae = load_weights(vae, weights)
153
+ vae.trainable = False
154
+ # try:
155
+ # debug_weights_loading(vae)
156
+ # except Exception as e:
157
+ # print(f"Skipping {weights} due to load failure: {e}")
158
+ # continue
159
+
160
+ # Generate Fakes
161
+ fake_images = decode_noise(vae, noise_vector, return_list=True)
162
+ fake_images = preprocess_input(fake_images)
163
+ # print(type(fake_images), fake_images.shape, fake_images.dtype, fake_images.min(), fake_images.max())
164
+ fake_images = tf.image.resize(fake_images, size=(299, 299), method='bicubic').numpy()
165
+ # print(type(fake_images), fake_images.shape, fake_images.dtype, fake_images.min(), fake_images.max())
166
+ visualize_debug(fake_images[0], name='output2.png')
167
+
168
+ # Calculate metrics
169
+ fid_score = fid_kid.calculate_fid(real_images, fake_images)
170
+ kid_mean, kid_std = fid_kid.calculate_kid(real_images, fake_images)
171
+
172
+ fid_score = round(fid_score, 4)
173
+ kid_mean = round(kid_mean, 4)
174
+ kid_std = round(kid_std, 4)
175
+
176
+ print("{}: FID={} KID={} ± {}".format(weights, fid_score, kid_mean, kid_std))
177
+
178
+ results.append({'weights': weights,
179
+ 'fid': fid_score,
180
+ 'kid_mean': kid_mean,
181
+ 'kid_std': kid_std,
182
+ })
183
+
184
+ # Save in xlxs
185
+ df = pd.DataFrame(results)
186
+ excel_path = f"ckpt_metrics_{model_name}.xlsx"
187
+
188
+ # TODO: pip install XlsxWriter if not installed
189
+ with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
190
+ df.to_excel(writer, sheet_name=label, index=False)
191
+
192
+ print(f"Results saved to: {excel_path}")
generate_images.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

from argparse import ArgumentParser
from utils.inference_utils import init_vae_model, load_weights, get_noise_seeded, decode_noise, save_images


if __name__ == "__main__":
    # CLI tool: restore a trained VAE and save num_of_images synthetic images
    # decoded from seeded noise into save_dir.
    parser = ArgumentParser()
    parser.add_argument("--model_name", required=True, type=str, choices=['tide', 'tidev2'], help='VAE model')
    parser.add_argument("--weights_path", required=True, type=str, help='Path to restore trained weights')
    parser.add_argument("--latent_dim", default=8, type=int, help='Dimensionality of latent space')
    parser.add_argument("--save_dir", default="./fake_images", type=str, help='Path to save synthetic images')
    parser.add_argument("--num_of_images", default=10, type=int, help='Number of images to generate')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    # BUG FIX: previously this only printed "Not a valid path" and carried on,
    # so load_weights later failed with a confusing error. Abort up front.
    if not os.path.exists(args.weights_path):
        parser.error("Not a valid path: {}".format(args.weights_path))

    vae = init_vae_model(args.model_name, args.latent_dim)
    # Seeded noise: the same arguments always produce the same images.
    noise_vector = get_noise_seeded((args.num_of_images, args.latent_dim))

    # Load trained weights and freeze the model for inference.
    vae = load_weights(vae, args.weights_path)
    vae.trainable = False

    # Generate & save images.
    fake_images = decode_noise(vae, noise_vector, return_list=True)
    save_images(args.save_dir, fake_images)
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (158 Bytes). View file
 
model/__pycache__/convnext_modules.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
model/__pycache__/tidev2.cpython-310.pyc ADDED
Binary file (3.98 kB). View file
 
model/__pycache__/tidev2_utils.cpython-310.pyc ADDED
Binary file (2.15 kB). View file
 
model/__pycache__/vae.cpython-310.pyc ADDED
Binary file (2.17 kB). View file
 
model/convnext_modules.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow.keras.layers as layers
3
+
4
+ from tensorflow.keras import backend
5
+
6
+
7
class LayerScale(layers.Layer):
    """Learnable per-channel scaling (ConvNeXt "layer scale").

    Multiplies its input by a trainable vector gamma of length
    projection_dim, initialised to init_values (typically a small value such
    as 1e-6 so the residual branch starts near identity).
    """

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # init_values: initial value for every element of gamma.
        self.init_values = init_values
        # projection_dim: number of channels to scale (length of gamma).
        self.projection_dim = projection_dim

    def build(self, input_shape):
        # One trainable scalar per channel, created lazily at first call.
        # NOTE(review): created via tf.Variable rather than self.add_weight —
        # it is still tracked as an attribute, but confirm it serialises as
        # expected with the chosen checkpointing scheme.
        self.gamma = tf.Variable(self.init_values * tf.ones((self.projection_dim,)))

    def call(self, x):
        # Broadcast multiply over the channel (last) dimension.
        return x * self.gamma

    def get_config(self):
        # Serialise constructor args so the layer can be re-created from config.
        config = super().get_config()
        config.update(
            {
                "init_values": self.init_values,
                "projection_dim": self.projection_dim,
            }
        )
        return config
28
+
29
+
30
class StochasticDepth(layers.Layer):
    """Per-sample residual-branch dropout ("DropPath" / stochastic depth).

    During training, each sample's activation is zeroed with probability
    drop_path_rate; survivors are rescaled by 1/keep_prob so the expected
    value is unchanged. At inference (training falsy) it is the identity.
    """

    def __init__(self, drop_path_rate, **kwargs):
        super().__init__(**kwargs)
        # Probability of dropping the whole branch for a given sample.
        self.drop_path_rate = drop_path_rate

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_path_rate
            # One Bernoulli draw per sample, broadcast over all other dims.
            # NOTE(review): len(tf.shape(x)) relies on the rank being
            # statically known; tf.rank or x.shape.rank would be safer —
            # confirm inputs always have a static rank.
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            # floor() turns keep_prob + U[0,1) into 1 w.p. keep_prob, else 0.
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

    def get_config(self):
        config = super().get_config()
        config.update({"drop_path_rate": self.drop_path_rate})
        return config
48
+
49
+
50
class ConvNeXtBlock(layers.Layer):
    """ConvNeXt residual block.

    7x7 depthwise conv -> 1x1 expand to 4x width (Dense, GELU) -> 1x1
    project back, with optional LayerScale and StochasticDepth applied to
    the residual branch before the skip addition.
    """

    def __init__(self, projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name_prefix=None):
        # Without an explicit prefix, generate a unique name via Keras uids.
        super().__init__(name=name_prefix or f"prestem{backend.get_uid('prestem')}")
        # groups == filters makes this a per-channel (depthwise) 7x7 conv.
        self.depthwise_conv = layers.Conv2D(
            filters=projection_dim, kernel_size=7, padding="same", groups=projection_dim,
            name=self.name + "_depthwise_conv"
        )
        # Dense layers act as 1x1 convolutions on the channel dimension.
        self.pointwise_conv1 = layers.Dense(4 * projection_dim, name=self.name + "_pointwise_conv_1")
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(projection_dim, name=self.name + "_pointwise_conv_2")
        # LayerScale is skipped entirely when init value is None.
        self.layer_scale = LayerScale(layer_scale_init_value, projection_dim, name=self.name + "_layer_scale") \
            if layer_scale_init_value is not None else None
        # With drop_path_rate == 0 (falsy) an identity activation stands in.
        self.stochastic_depth = StochasticDepth(drop_path_rate, name=self.name + "_stochastic_depth") \
            if drop_path_rate else layers.Activation("linear", name=self.name + "_identity")

    def call(self, inputs, training=False):
        x = self.depthwise_conv(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.layer_scale:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        # Residual connection; spatial dims/channels are preserved above.
        return inputs + x
74
+
75
+
76
class ConvNeXtBlockTransposed(layers.Layer):
    """Decoder-side ConvNeXt residual block using a transposed depthwise conv.

    Mirrors ConvNeXtBlock: 7x7 depthwise transposed conv (stride 1, same
    padding, so spatial size is preserved) -> 1x1 expand (GELU) -> 1x1
    project, with optional LayerScale / StochasticDepth before the skip add.
    """

    def __init__(self, projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name_prefix=None):
        # Without an explicit prefix, generate a unique name via Keras uids.
        super().__init__(name=name_prefix or f"poststem{backend.get_uid('poststem')}")
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        # NOTE(review): Conv2DTranspose with groups == filters (a "depthwise
        # transpose") has limited backend support in some TF versions —
        # confirm this builds on the targeted TF release.
        self.depthwise_conv_trans = layers.Conv2DTranspose(
            filters=projection_dim, kernel_size=7, padding="same",
            groups=projection_dim, name=self.name + "_depthwise_conv_trans"
        )
        # Dense layers act as 1x1 convolutions on the channel dimension.
        self.pointwise_conv1 = layers.Dense(4 * projection_dim, name=self.name + "_pointwise_conv_1")
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(projection_dim, name=self.name + "_pointwise_conv_2")

        # LayerScale is skipped entirely when init value is None.
        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(layer_scale_init_value, projection_dim, name=self.name + "_layer_scale")
        else:
            self.layer_scale = None

        # With drop_path_rate == 0 (falsy) an identity activation stands in.
        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(drop_path_rate, name=self.name + "_stochastic_depth")
        else:
            self.stochastic_depth = layers.Activation("linear", name=self.name + "_identity")

    def call(self, inputs, training=False):
        x = self.depthwise_conv_trans(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.layer_scale:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        # Residual connection; shapes are preserved by the layers above.
        return inputs + x
110
+
111
+
model/tidev2.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow.keras.layers as layers
3
+
4
+ from tensorflow.keras import Model
5
+ from tensorflow.keras import Sequential
6
+
7
+ from model.tidev2_utils import TopLayer, Sampling
8
+ from model.convnext_modules import ConvNeXtBlock, ConvNeXtBlockTransposed
9
+
10
+
11
class ConvNeXtEncoderTiny(Model):
    """ConvNeXt-Tiny style encoder with an optional VAE latent head.

    Args:
        depths: number of ConvNeXt blocks per stage (4 stages).
        projection_dims: channel width per stage.
        drop_path_rate: maximum stochastic-depth rate, ramped linearly
            from 0 across all blocks.
        layer_scale_init_value: LayerScale init; None disables LayerScale.
        model_name: prefix used for all layer names.
        latent_dim: if set, a flatten/dense/z_mean/z_log_var/sampling head is
            added and call() returns [z, z_mean, z_log_var]; if None, call()
            returns the final feature map.

    Note: depths/projection_dims defaults are tuples (was lists) to avoid
    the mutable-default-argument pitfall; they are only read, never mutated,
    so callers are unaffected.
    """

    def __init__(self,
                 depths=(3, 3, 9, 3),
                 projection_dims=(96, 192, 384, 768),
                 drop_path_rate=0.0,
                 layer_scale_init_value=1e-6,
                 model_name="convnext",
                 latent_dim=None):
        super().__init__(name=model_name)
        self.latent_dim = latent_dim
        self.depths = depths
        self.projection_dims = projection_dims

        # Stem: 4x4 stride-4 "patchify" convolution.
        self.stem = Sequential([
            layers.Conv2D(projection_dims[0], kernel_size=4, strides=4, name=model_name + "_stem_conv"),
        ], name=model_name + "_stem")

        # Between-stage downsampling: 2x2 stride-2 convolutions.
        self.downsample_layers = [self.stem]
        for i in range(3):
            self.downsample_layers.append(
                Sequential([
                    layers.Conv2D(projection_dims[i + 1], kernel_size=2, strides=2,
                                  name=model_name + f"_downsampling_conv_{i}")
                ], name=model_name + f"_downsampling_block_{i}")
            )

        # Per-block stochastic-depth rates, increasing linearly with depth.
        self.depth_drop_rates = np.linspace(0.0, drop_path_rate, sum(depths)).astype(float)

        # Four stages of ConvNeXt blocks.
        self.stages = []
        cur = 0
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(
                    ConvNeXtBlock(projection_dim=projection_dims[i],
                                  drop_path_rate=self.depth_drop_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_value,
                                  name_prefix=model_name + f"_stage_{i}_block_{j}")
                )
            self.stages.append(stage_blocks)
            cur += depths[i]

        # Optional VAE latent head.
        if latent_dim is not None:
            self.flatten = layers.Flatten()
            self.dense_proj = layers.Dense(256, activation="relu", name="dense_proj")
            self.z_mean = layers.Dense(latent_dim, name="z_mean")
            self.z_log_var = layers.Dense(latent_dim, name="z_log_var")
            self.sampling = Sampling()

    def call(self, inputs, training=False):
        x = inputs
        for i in range(4):
            x = self.downsample_layers[i](x)
            for block in self.stages[i]:
                x = block(x, training=training)

        # Plain feature extractor when no latent head was requested.
        if self.latent_dim is None:
            return x

        x = self.flatten(x)
        x = self.dense_proj(x)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        # Reparameterised sample for the VAE.
        z = self.sampling([z_mean, z_log_var])
        return [z, z_mean, z_log_var]
81
+
82
+
83
class ConvNeXtDecoderTiny(Model):
    """ConvNeXt-Tiny style decoder mapping a latent vector to an RGB image.

    Structure: Dense + reshape to a 10x10 feature map -> three 2x stride-2
    transposed-conv upsamplings interleaved with transposed ConvNeXt stages
    -> 4x stride-4 transposed conv -> multi-scale TopLayer -> 1x1 sigmoid
    prediction layer (3 channels).

    Args:
        depths: number of transposed ConvNeXt blocks per stage.
        projection_dims: channel width per stage (wide to narrow).
        drop_path_rate: maximum stochastic-depth rate, ramped linearly.
        layer_scale_init_value: LayerScale init; None disables LayerScale.
        model_name: prefix used for all layer names.
        latent_dim: required; dimensionality of the input latent vector.

    Raises:
        ValueError: if latent_dim is None.

    Note: depths/projection_dims defaults are tuples (was lists) to avoid
    the mutable-default-argument pitfall; they are only read, never mutated.
    """

    def __init__(self,
                 depths=(3, 9, 3, 3),
                 projection_dims=(768, 384, 192, 96),
                 drop_path_rate=0.0,
                 layer_scale_init_value=1e-6,
                 model_name="convnext",
                 latent_dim=None):
        super().__init__(name=model_name)

        if latent_dim is None:
            raise ValueError("latent_dim must be specified for decoder")

        # Intro: project latent to a 10x10 feature map at full width.
        self.intro = Sequential([
            layers.Dense(10 * 10 * projection_dims[0], activation="relu"),
            layers.Reshape((10, 10, projection_dims[0]))
        ], name=model_name + "_intro")

        # Between-stage 2x upsampling via transposed convs.
        self.upsample_layers = [self.intro]
        for i in range(3):
            self.upsample_layers.append(
                Sequential([
                    layers.Conv2DTranspose(projection_dims[i + 1], kernel_size=2, strides=2,
                                           name=model_name + f"_upsampling_conv_{i}")
                ], name=model_name + f"_upsampling_block_{i}")
            )

        # Per-block stochastic-depth rates, increasing linearly with depth.
        self.depth_drop_rates = np.linspace(0.0, drop_path_rate, sum(depths)).astype(float)

        # Four stages of transposed ConvNeXt blocks.
        self.stages = []
        cur = 0
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(
                    ConvNeXtBlockTransposed(projection_dim=projection_dims[i],
                                            drop_path_rate=self.depth_drop_rates[cur + j],
                                            layer_scale_init_value=layer_scale_init_value,
                                            name_prefix=model_name + f"_stage_{i}_block_{j}")
                )
            self.stages.append(stage_blocks)
            cur += depths[i]

        # Final 4x upsampling to output resolution.
        self.top = Sequential([
            layers.Conv2DTranspose(projection_dims[3], kernel_size=4, strides=4, name=model_name + "_top_conv")
        ], name=model_name + "_top")

        # TopLayer's residual add requires its filter count to match the
        # incoming channels; was hard-coded to 96 (== projection_dims[3]
        # default) — now follows projection_dims for custom widths.
        self.top_layer = TopLayer(filters=projection_dims[3])
        self.pred_layer = layers.Conv2DTranspose(3, kernel_size=1, activation="sigmoid",
                                                 padding="same", name="pred_layer")

    def call(self, inputs, training=False):
        x = inputs
        for i in range(4):
            x = self.upsample_layers[i](x)
            for block in self.stages[i]:
                x = block(x, training=training)
        x = self.top(x)
        x = self.top_layer(x)
        # Sigmoid output: pixel values in [0, 1].
        return self.pred_layer(x)
148
+
149
+
150
+ if __name__ == "__main__":
151
+ # Encoder
152
+ encoder = ConvNeXtEncoderTiny(latent_dim=8)
153
+ encoder.build((None, 320, 320, 3))
154
+ encoder.summary()
155
+
156
+ # Decoder
157
+ decoder = ConvNeXtDecoderTiny(latent_dim=8)
158
+ decoder.build((None, 8))
159
+ decoder.summary()
160
+
161
+
model/tidev2_utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow.keras.layers as layers
3
+
4
+
5
class TopLayer(layers.Layer):
    """Multi-scale refinement head applied before the final prediction layer.

    A 1x1 conv feeds three parallel convs at different receptive fields
    (2x2, 4x4, 8x8); their concatenation is projected back to `filters`
    channels, fused, added residually to the head input, then passed through
    GELU and a final 1x1 conv. Spatial size and channel count are preserved.
    """

    def __init__(self, filters):
        super().__init__()
        # filters: channel width; must equal the input's channel count for
        # the residual addition in call() to be valid.
        self.filters = filters

        # NOTE(review): every sub-layer here shares the literal name
        # "_top_layer" — duplicate names can break by-name weight save/load.
        # Left unchanged to keep existing checkpoints loadable; confirm.
        self.conv_1x1 = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                      name="_top_layer")
        # Parallel multi-scale branches, filters // 3 channels each.
        self.conv_2x2 = layers.Conv2D(self.filters//3, (2, 2), activation='relu', strides=1, padding="same",
                                      name="_top_layer")
        self.conv_4x4 = layers.Conv2D(self.filters//3, (4, 4), activation='relu', strides=1, padding="same",
                                      name="_top_layer")
        self.conv_8x8 = layers.Conv2D(self.filters//3, (8, 8), activation='relu', strides=1, padding="same",
                                      name="_top_layer")

        self.concat = layers.Concatenate(axis=-1)
        # Project 3 * (filters // 3) concatenated channels back to filters.
        self.point_wise_conv = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                             padding='same', name="_top_layer")
        self.feat_fusion = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                         padding='same', name="_top_layer")

        self.addition = layers.Add()
        self.gelu = layers.Activation('gelu')
        self.final_conv = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                        name="_top_layer")

    def call(self, inputs, training=False):
        x = self.conv_1x1(inputs, training=training)

        feats_2x2 = self.conv_2x2(x, training=training)
        feats_4x4 = self.conv_4x4(x, training=training)
        feats_8x8 = self.conv_8x8(x, training=training)

        concatenated = self.concat([feats_2x2, feats_4x4, feats_8x8])
        concatenated = self.point_wise_conv(concatenated)

        concatenated = self.feat_fusion(concatenated)
        # Residual: fuse the multi-scale features back into the head input.
        x = self.addition([inputs, concatenated])
        x = self.gelu(x)
        x = self.final_conv(x)
        return x
45
+
46
+
47
class Sampling(layers.Layer):
    """VAE reparameterization trick: z = mean + sigma * eps, eps ~ N(0, I).

    Sampling is expressed as a deterministic function of (z_mean, z_log_var)
    plus external noise so gradients can flow through mean and variance.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        # inputs: [z_mean, z_log_var], each of shape (batch, latent_dim).
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
model/vae.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.models import Model
3
+
4
+
5
class VAE(Model):
    """Variational autoencoder wrapper with a custom training loop.

    Ties together an encoder that returns [z, z_mean, z_log_var] and a
    decoder mapping z back to an image; trains with binary cross-entropy
    reconstruction loss plus a KL divergence term.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

        # Running means reported by fit(); reset by Keras between epochs
        # because they are exposed via the `metrics` property below.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @tf.function()
    def call(self, x):
        """Forward pass: encode x, sample z, decode a reconstruction."""
        z, z_mean, z_log_var, = self.encoder(x)
        reconstruction = self.decoder(z)
        return reconstruction

    def full_summary(self):
        # Print a summary for each sub-model (encoder and decoder).
        for layer in self.layers:
            print(layer.summary())

    @tf.function()
    def train_step(self, x):
        """One optimisation step; returns the running-mean loss dict."""
        with tf.GradientTape() as tape:
            z, z_mean, z_log_var, = self.encoder(x)
            reconstruction = self.decoder(z)

            # Per-image BCE summed over spatial dims, averaged over batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(x, reconstruction), axis=(1, 2)
                )
            )
            # KL(q(z|x) || N(0, I)), summed over latent dims, batch-averaged.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            # NOTE(review): NaN/Inf guard replacing the KL term with float32
            # max — relies on AutoGraph converting this Python `if` over a
            # tensor inside tf.function; confirm it traces as intended.
            if tf.math.is_nan(kl_loss) or tf.math.is_inf(kl_loss):
                kl_loss = tf.float32.max
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
train.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tensorflow as tf
3
+
4
+ from json import dump
5
+ from argparse import ArgumentParser
6
+
7
+
8
+ from model import tidev2
9
+ from model.vae import VAE
10
+ from utils.callbacks import VisualizeCallback, CheckpointCallback
11
+ from utils.dataloader import list_filenames, Dataset
12
+ from utils.plots import visualize_from_latent_space
13
+
14
+
15
if __name__ == '__main__':
    # Training entry point: parse config, set up output dirs + dataset,
    # build the VAE, and run fit() with visualization/checkpoint callbacks.
    parser = ArgumentParser()
    parser.add_argument("--model_name", required=True, type=str, choices=['tide', 'tidev2'], help='VAE model')
    parser.add_argument("--output_path", default='./results/', type=str, help='Path to store the results')
    # VAE model
    # NOTE(review): type=tuple is broken for CLI use — tuple("320,320,3")
    # yields a tuple of characters, so only the defaults actually work for
    # --input_shape/--crop_dim; consider nargs with type=int instead.
    parser.add_argument("--input_shape", default=(320, 320, 3), type=tuple, help='Image shape for training')
    parser.add_argument("--dim_latent", default=8, type=int, help='Dimensionality of latent space')
    # Training
    parser.add_argument("--epochs", default=5000, type=int, help='Number of training epochs')
    parser.add_argument("--batch_size", default=4, type=int, help='Number of training batch size')
    parser.add_argument("--learning_rate", default=0.0002, type=float, help='Learning rate')
    parser.add_argument("--ckpt_interval", default=200, type=int, help='Epoch interval for saving checkpoints')
    parser.add_argument("--visualization_interval", default=25, type=int, help='Epoch interval for visualizing results')
    # Data
    parser.add_argument("--datadir", default='./kid/inflammatory', type=str, help='Folder with images for training')
    parser.add_argument("--files_ext", default='png', type=str, help='Extension of training files')
    parser.add_argument("--files_prefix", default=None, type=str,
                        help='Prefix of training files. Ignore if datadir contains only the appropriate files')
    parser.add_argument("--crop_dim", default=None, type=tuple,
                        help='Dimensions for cropping images. Ignore if images are already cropped')
    args = parser.parse_args()

    # Create folders & save the full training config for reproducibility.
    os.makedirs(args.output_path, exist_ok=True)
    log_dir = os.path.join(args.output_path, 'logs')
    ckpt_dir = os.path.join(args.output_path, 'checkpoints')
    visualize_dir = os.path.join(args.output_path, 'visualize')

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(visualize_dir, exist_ok=True)

    with open(os.path.join(args.output_path, "training_config.json"), 'w') as fp:
        dump(vars(args), fp)

    # Setup training data
    filenames = list_filenames(data_path=args.datadir,
                               img_extension=args.files_ext,
                               filename_prefix=args.files_prefix)
    images = Dataset(filenames,
                     batch_size=args.batch_size,
                     crop_dim=args.crop_dim,
                     resize_dim=args.input_shape[:2],)

    # Create Model
    # NOTE(review): only 'tidev2' is handled — with --model_name tide, `vae`
    # is never assigned and vae.compile below raises NameError; confirm
    # whether 'tide' training lives elsewhere.
    if args.model_name == 'tidev2':
        vae = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=args.dim_latent),
                  tidev2.ConvNeXtDecoderTiny(latent_dim=args.dim_latent)
                  )
    vae.compile(optimizer=tf.keras.optimizers.Adam(args.learning_rate))

    # Training: periodic latent-space visualizations, periodic checkpoints,
    # and TensorBoard logging.
    callbacks = [VisualizeCallback(args.visualization_interval, lambda model, epoch: visualize_from_latent_space(
        latent_dim=args.dim_latent,
        input_shape=args.input_shape,
        vae=model,
        output_path=visualize_dir,
        epoch=epoch,
        num_items=10,)),
        CheckpointCallback(vae=vae,
                           path=ckpt_dir,
                           epoch_interval=args.ckpt_interval,
                           restore_training=False,
                           restore_path=None),
        tf.keras.callbacks.TensorBoard(log_dir=log_dir)]

    vae.fit(x=images,
            epochs=args.epochs,
            batch_size=args.batch_size,
            callbacks=callbacks,
            shuffle=True,
            initial_epoch=0)
87
+
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (158 Bytes). View file
 
utils/__pycache__/callbacks.cpython-310.pyc ADDED
Binary file (2.5 kB). View file
 
utils/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (3.29 kB). View file
 
utils/__pycache__/inference_utils.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
utils/__pycache__/plots.cpython-310.pyc ADDED
Binary file (1 kB). View file
 
utils/callbacks.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import tensorflow as tf
4
+
5
+
6
class VisualizeCallback(tf.keras.callbacks.Callback):
    """Keras callback that invokes a user-supplied visualization hook.

    Every `epoch_interval` epochs (epoch 0 is skipped), `func` is called
    with the model being trained and the current zero-based epoch index.
    """

    def __init__(self, epoch_interval=1, func=lambda model, epoch: None):
        super(VisualizeCallback, self).__init__()
        self.epoch_interval = epoch_interval
        self.func = func

    def on_epoch_end(self, epoch, logs=None):
        # Epoch indices are zero-based; the first epoch is deliberately skipped.
        due = epoch > 0 and epoch % self.epoch_interval == 0
        if due:
            self.func(self.model, epoch)
15
+
16
+
17
class CheckpointCallback(tf.keras.callbacks.Callback):
    """Keras callback that checkpoints the VAE (and its optimizer) during training.

    Saves a `tf.train.Checkpoint` every `epoch_interval` epochs (skipping
    epoch 0).  When `restore_training` is True, the latest checkpoint in
    `path` (or the explicit `restore_path`) is restored before training
    begins.  On train end the final weights are saved under
    `<path>/trained-vae`.
    """

    def __init__(self, vae, path, epoch_interval=1, restore_training=False, restore_path=None):
        super(CheckpointCallback, self).__init__()
        self.epoch_interval = epoch_interval
        self.path = path
        self.vae = vae

        # Track both the model and its optimizer so training can resume exactly.
        self.ckpt = tf.train.Checkpoint(vae=vae,
                                        vae_optimizer=vae.optimizer)
        # max_to_keep=None keeps every checkpoint (no rotation).
        self.ckpt_manager = tf.train.CheckpointManager(checkpoint=self.ckpt,
                                                       directory=self.path,
                                                       max_to_keep=None)
        self.restore_training = restore_training
        self.restore_path = restore_path

    def on_epoch_end(self, epoch, logs=None):
        # Epoch indices are zero-based; epoch 0 is deliberately skipped.
        if epoch % self.epoch_interval == 0 and epoch > 0:
            self.ckpt_manager.save(checkpoint_number=epoch)

    def on_train_begin(self, logs=None):
        if self.restore_training:
            if self.restore_path is None:
                # BUG FIX: was `.except_partial()`, which raises AttributeError;
                # the TF API method is `expect_partial()`.
                self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial()
                print("Resume training from checkpoint ", self.ckpt_manager.latest_checkpoint, "\n")
            else:
                self.ckpt.restore(self.restore_path)
                print("resume training from checkpoint ", self.restore_path, "\n")

    def on_train_end(self, logs=None):
        # Persist final weights alongside the periodic checkpoints.
        weights_path = os.path.join(self.path, "trained-vae")
        self.ckpt.save(file_prefix=weights_path)
49
+
utils/dataloader.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ from re import split, compile
6
+ from tensorflow.keras.utils import Sequence
7
+
8
+
9
def list_filenames(data_path, img_extension='png', filename_prefix=None):
    """Collect image file paths under *data_path*, naturally sorted.

    Keeps files whose names end with *img_extension*; when
    *filename_prefix* is given, names must also start with it.  Sorting is
    "natural": digit runs inside names compare as integers (img2 < img10).
    Returns full paths joined onto *data_path*.
    """
    def _keep(name):
        if not name.endswith(img_extension):
            return False
        return filename_prefix is None or name.startswith(filename_prefix)

    def _natural_key(name):
        return [int(tok) if tok.isdigit() else tok.lower()
                for tok in split(compile(r'(\d+)'), name)]

    files_list = sorted((f for f in os.listdir(data_path) if _keep(f)),
                        key=_natural_key)
    files_list = [os.path.join(data_path, f) for f in files_list]
    print('Found {} files in {}'.format(len(files_list), data_path))
    return files_list
19
+
20
+
21
class Dataset(Sequence):
    """Keras Sequence yielding batches of float32 RGB images in [0, 1].

    Each image is optionally center-cropped to `crop_dim` (crop_h, crop_w)
    and then resized to `resize_dim` ((width, height), as PIL's `resize`
    expects — TODO confirm against caller, which passes input_shape[:2]).
    """

    def __init__(self, file_list, batch_size=32, crop_dim=None, resize_dim=None, shuffle=True):
        self.files_list = file_list
        self.batch_size = batch_size

        self.crop_dim = crop_dim
        self.resize_dim = resize_dim
        self.shuffle = shuffle
        # Shuffle once up front so the first epoch is randomized too.
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.files_list) / self.batch_size))

    def __getitem__(self, idx):
        batch_files = self.files_list[idx * self.batch_size : (idx + 1) * self.batch_size]
        images = [self.load_images(f) for f in batch_files]
        return np.stack(images)

    def on_epoch_end(self):
        # In-place reshuffle between epochs (lists are supported by np.random.shuffle).
        if self.shuffle:
            np.random.shuffle(self.files_list)

    @staticmethod
    def center_crop(image, crop_dim):
        """Center-crop a PIL image to `crop_dim` = (crop_h, crop_w).

        BUG FIX: PIL's `image.size` is (width, height), but the original
        unpacked it as `h, w`, so for non-square `crop_dim` the crop height
        was applied to the width axis and vice versa.  Square crops are
        unaffected by this fix.
        """
        width, height = image.size
        crop_h, crop_w = crop_dim

        left = max(0, (width - crop_w) // 2)
        top = max(0, (height - crop_h) // 2)
        right = min(width, (width + crop_w) // 2)
        bottom = min(height, (height + crop_h) // 2)

        return image.crop((left, top, right, bottom))

    def load_images(self, filepath):
        """Load one image: force RGB, optional crop/resize, scale to [0, 1]."""
        image = Image.open(filepath).convert('RGB')
        if self.crop_dim:
            image = self.center_crop(image, crop_dim=self.crop_dim)
        if self.resize_dim:
            image = image.resize(self.resize_dim)

        image = np.array(image).astype(np.float32)
        image = image / 255.0
        return image
utils/inference_utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+
5
+ from PIL import Image
6
+
7
+ from model.vae import VAE
8
+ from model import tidev2
9
+
10
+
11
def init_vae_model(model_name, latent_dim):
    """Build an untrained VAE for the given architecture name.

    Args:
        model_name: architecture identifier; currently only 'tidev2'.
        latent_dim: dimensionality of the latent space.

    Returns:
        A VAE wrapping the tidev2 ConvNeXt tiny encoder/decoder pair.

    Raises:
        ValueError: for an unknown model_name.  (The original silently
        returned None here, deferring the failure to the caller.)
    """
    if model_name == 'tidev2':
        vae_model = VAE(tidev2.ConvNeXtEncoderTiny(latent_dim=latent_dim),
                        tidev2.ConvNeXtDecoderTiny(latent_dim=latent_dim)
                        )
        return vae_model
    raise ValueError("Unknown model_name: {!r}".format(model_name))
17
+
18
+
19
def load_weights(vae, weights_path):
    """Load trained weights into *vae* from a checkpoint or weights file.

    Paths containing "ckpt-" are restored via `tf.train.Checkpoint`
    (partially, so optimizer slots may be absent); paths containing ".TF"
    are loaded with `load_weights(..., by_name=True)`.

    Returns:
        The same *vae* instance with weights loaded.

    Raises:
        ValueError: if the path matches neither recognized format.  (The
        original silently fell through and returned None here.)
    """
    print("Loading weights from {}".format(weights_path))
    if "ckpt-" in weights_path:
        ckpt = tf.train.Checkpoint(vae=vae)
        ckpt.restore(weights_path).expect_partial()
        return vae
    if ".TF" in weights_path:
        vae.load_weights(weights_path, by_name=True)
        return vae
    raise ValueError("Unrecognized weights path format: {}".format(weights_path))
28
+
29
+
30
def get_noise_seeded(noise_shape):
    """Return a reproducible standard-normal sample of shape *noise_shape*.

    The global NumPy seed is reset to 0 on every call, so repeated calls
    with the same shape always yield identical noise.
    """
    fixed_seed = 0
    np.random.seed(fixed_seed)
    return np.random.normal(0, 1, noise_shape)
34
+
35
+
36
def decode_noise(trained_vae, noise, return_list=False):
    """Decode latent *noise* with the trained VAE's decoder.

    The decoder output is scaled by 255 (presumably from a [0, 1] range —
    TODO confirm the decoder's output activation).  Returns the stacked
    prediction array, or a list of per-image arrays when *return_list*.
    """
    print("Generating synthetic images ...")
    decoded = trained_vae.decoder.predict(noise, batch_size=1)
    decoded *= 255.0
    if not return_list:
        return decoded
    return list(decoded)
45
+
46
+
47
def save_images(save_folder, images):
    """Save each image in *images* as `image-<i>.jpg` under *save_folder*.

    Accepts either a list of HxWxC arrays or a stacked ndarray batch.
    BUG FIX: the original only handled `list` inputs and silently saved
    nothing for an ndarray — which is exactly what `decode_noise` returns
    by default (`return_list=False`).  Iterating covers both cases.
    """
    print(f"Saving synthetic images into {save_folder}")
    for i, image in enumerate(images):
        image = np.asarray(image).astype(np.uint8)
        Image.fromarray(image).save(os.path.join(save_folder, f"image-{i}.jpg"))
utils/plots.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imageio
2
+ import numpy as np
3
+
4
+
5
def visualize_from_latent_space(latent_dim, input_shape, vae, output_path, epoch="final", num_items=10,):
    """Decode a num_items x num_items grid of random latents and save a collage.

    Args:
        latent_dim: dimensionality of the latent space.
        input_shape: (height, width, channels) of decoded images; assumed
            square (only the first dimension sizes the collage tiles).
        vae: trained model exposing `decoder.predict`.
        output_path: directory the collage JPEG is written into.
        epoch: label embedded in the output filename.
        num_items: tiles per collage side.
    """
    image_size, _, img_channels = input_shape
    # FIX: use the model's channel count instead of hard-coding 3 (the
    # original unpacked img_channels but never used it).
    figure = np.zeros((image_size * num_items, image_size * num_items, img_channels))

    # NOTE(review): grid_x/grid_y are computed but only their indices are
    # used below — sampling is seeded-random, not a latent-grid traversal.
    scale = 1.0
    grid_x = np.linspace(-scale, scale, num_items)
    grid_y = np.linspace(-scale, scale, num_items)[::-1]

    np.random.seed(42)  # fixed seed so collages are comparable across epochs
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            random_z = np.random.normal(0, 1, (1, latent_dim))
            x_decoded = vae.decoder.predict(random_z)
            image = x_decoded[0].reshape(input_shape)
            figure[i * image_size: (i + 1) * image_size, j * image_size: (j + 1) * image_size, ] = image
    print(f'Saving collage in {output_path}/decoding-noise-ep{epoch}.jpg')
    # Clip before the uint8 cast so out-of-range decoder values don't wrap around.
    imageio.imsave(f'{output_path}/decoding-noise-ep{epoch}.jpg',
                   (np.clip(figure, 0.0, 1.0) * 255).astype('uint8'))
23
+
24
+
25
+