farrell236 committed on
Commit
e99a83c
·
1 Parent(s): 57d0fed
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
app.py ADDED
@@ -0,0 +1,294 @@
+ import argparse
+ from pathlib import Path
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ from augmentations import IMAGENET_MEAN, IMAGENET_STD
+ from models import build_model
+
+
+ APP_STATE = {}
+
+
+ def load_model(args, device):
+     model = build_model(
+         model_name=args.model,
+         num_classes=1,
+         in_channels=3,
+         image_size=args.image_size,
+         backbone=args.backbone,
+         pretrained=False,
+         base_channels=args.base_channels,
+         dropout=args.dropout,
+     )
+
+     checkpoint = torch.load(args.checkpoint, map_location="cpu")
+
+     if "model_state_dict" in checkpoint:
+         state_dict = checkpoint["model_state_dict"]
+     else:
+         state_dict = checkpoint
+
+     model.load_state_dict(state_dict, strict=True)
+     model.to(device)
+     model.eval()
+
+     return model
+
+
+ def preprocess_image(image, image_size):
+     if isinstance(image, Image.Image):
+         image = np.array(image.convert("RGB"))
+     else:
+         image = np.array(image)
+
+     if image.ndim == 2:
+         image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+     if image.shape[-1] == 4:
+         image = image[..., :3]
+
+     original_rgb = image.copy()
+
+     resized = cv2.resize(
+         image,
+         (image_size, image_size),
+         interpolation=cv2.INTER_LINEAR,
+     )
+
+     resized = resized.astype(np.float32) / 255.0
+
+     mean = np.array(IMAGENET_MEAN, dtype=np.float32).reshape(1, 1, 3)
+     std = np.array(IMAGENET_STD, dtype=np.float32).reshape(1, 1, 3)
+
+     resized = (resized - mean) / std
+     tensor = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).float()
+
+     return tensor, original_rgb
+
+
+ def overlay_mask(image_rgb, mask, alpha=0.45):
+     image_rgb = image_rgb.astype(np.uint8)
+
+     red = np.zeros_like(image_rgb)
+     red[..., 0] = 255
+
+     mask_3ch = mask[..., None]
+
+     overlay = image_rgb * (1 - alpha * mask_3ch) + red * (alpha * mask_3ch)
+     overlay = np.clip(overlay, 0, 255).astype(np.uint8)
+
+     return overlay
+
+
+ def run_inference(image, threshold):
+     tensor, original_rgb = preprocess_image(
+         image=image,
+         image_size=APP_STATE["image_size"],
+     )
+
+     tensor = tensor.to(APP_STATE["device"])
+
+     with torch.no_grad():
+         logits = APP_STATE["model"](tensor)
+         probs = torch.sigmoid(logits)
+
+     prob_map = probs[0, 0].detach().cpu().numpy()
+
+     original_h, original_w = original_rgb.shape[:2]
+
+     prob_map = cv2.resize(
+         prob_map,
+         (original_w, original_h),
+         interpolation=cv2.INTER_LINEAR,
+     )
+
+     pred_mask = (prob_map >= threshold).astype(np.float32)
+
+     return original_rgb, prob_map, pred_mask
+
+
+ def predict(image, threshold, alpha):
+     if image is None:
+         return None, None, None
+
+     original_rgb, prob_map, pred_mask = run_inference(image, threshold)
+
+     overlay = overlay_mask(original_rgb, pred_mask, alpha=alpha)
+     prob_vis = (prob_map * 255).clip(0, 255).astype(np.uint8)
+     mask_vis = (pred_mask * 255).astype(np.uint8)
+
+     return overlay, prob_vis, mask_vis
+
+
+ def build_app():
+     css = """
+     #input_image {
+         height: 430px !important;
+     }
+
+     #input_image img {
+         object-fit: contain !important;
+         max-height: 430px !important;
+     }
+
+     #overlay_output {
+         height: 200px !important;
+     }
+
+     #overlay_output img {
+         object-fit: contain !important;
+         max-height: 200px !important;
+     }
+
+     #prob_output {
+         height: 200px !important;
+     }
+
+     #prob_output img {
+         object-fit: contain !important;
+         max-height: 200px !important;
+     }
+
+     #mask_output {
+         height: 430px !important;
+     }
+
+     #mask_output img {
+         object-fit: contain !important;
+         max-height: 430px !important;
+     }
+     """
+
+     with gr.Blocks(title="Retina Vessel Segmentation", css=css) as demo:
+         gr.Markdown("# Retina Vessel Segmentation")
+         gr.Markdown(
+             f"Model: `{APP_STATE['model_name']}` | "
+             f"Backbone: `{APP_STATE['backbone']}` | "
+             f"Image size: `{APP_STATE['image_size']}`"
+         )
+
+         with gr.Row(equal_height=False):
+             with gr.Column(scale=1):
+                 input_image = gr.Image(
+                     type="pil",
+                     label="Input CFP Image",
+                     elem_id="input_image",
+                     height=430,
+                 )
+
+                 threshold = gr.Slider(
+                     minimum=0.05,
+                     maximum=0.95,
+                     value=0.5,
+                     step=0.05,
+                     label="Prediction Threshold",
+                 )
+
+                 alpha = gr.Slider(
+                     minimum=0.1,
+                     maximum=0.9,
+                     value=0.45,
+                     step=0.05,
+                     label="Overlay Alpha",
+                 )
+
+                 run_button = gr.Button("Segment")
+
+             with gr.Column(scale=1.2):
+                 with gr.Row():
+                     overlay_output = gr.Image(
+                         type="numpy",
+                         label="Overlay",
+                         elem_id="overlay_output",
+                         height=200,
+                     )
+
+                     prob_output = gr.Image(
+                         type="numpy",
+                         label="Probability Map",
+                         elem_id="prob_output",
+                         height=200,
+                     )
+
+                 mask_output = gr.Image(
+                     type="numpy",
+                     label="Binary Mask",
+                     elem_id="mask_output",
+                     height=430,
+                 )
+
+         run_button.click(
+             fn=predict,
+             inputs=[input_image, threshold, alpha],
+             outputs=[overlay_output, prob_output, mask_output],
+         )
+
+         threshold.change(
+             fn=predict,
+             inputs=[input_image, threshold, alpha],
+             outputs=[overlay_output, prob_output, mask_output],
+         )
+
+         alpha.change(
+             fn=predict,
+             inputs=[input_image, threshold, alpha],
+             outputs=[overlay_output, prob_output, mask_output],
+         )
+
+     return demo
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Gradio app for retina vessel segmentation.")
+     parser.add_argument("--checkpoint", type=str, default="checkpoints/fives_resunet/best.pt")
+     parser.add_argument("--image-size", type=int, default=1024)
+     parser.add_argument("--model", type=str, default="resunet", choices=["resunet", "deeplabv3", "vit"])
+     parser.add_argument("--backbone", type=str, default="resnet50")
+     parser.add_argument("--base-channels", type=int, default=32)
+     parser.add_argument("--dropout", type=float, default=0.0)
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--server-name", type=str, default="127.0.0.1")
+     parser.add_argument("--server-port", type=int, default=7860)
+     parser.add_argument("--share", action="store_true")
+
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     device = args.device
+     if device == "cuda" and not torch.cuda.is_available():
+         device = "cpu"
+
+     checkpoint_path = Path(args.checkpoint)
+     if not checkpoint_path.exists():
+         raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+     APP_STATE["device"] = torch.device(device)
+     APP_STATE["image_size"] = args.image_size
+     APP_STATE["model_name"] = args.model
+     APP_STATE["backbone"] = args.backbone
+
+     APP_STATE["model"] = load_model(
+         args=args,
+         device=APP_STATE["device"],
+     )
+
+     print(f"Loaded checkpoint: {checkpoint_path}")
+     print(f"Device: {APP_STATE['device']}")
+     print(f"Model: {APP_STATE['model_name']}")
+     print(f"Backbone: {APP_STATE['backbone']}")
+     print(f"Image size: {APP_STATE['image_size']}")
+
+     demo = build_app()
+     demo.launch(
+         # server_name=args.server_name,
+         # server_port=args.server_port,
+         # share=args.share,
+     )
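
As a usage sketch, the app above can be launched locally like so, using the checkpoint shipped in this commit and the argparse defaults (the main block falls back to CPU when CUDA is unavailable):

    python app.py --checkpoint checkpoints/fives_resunet/best.pt --image-size 1024 --device cpu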
augmentations.py ADDED
@@ -0,0 +1,138 @@
+ """
+ augmentations.py
+
+ Simple camera-style augmentations for color fundus photography (CFP)
+ segmentation and classification.
+
+ Expected input:
+     RGB NumPy image, shape (H, W, 3)
+
+ Dependencies:
+     pip install albumentations opencv-python
+ """
+
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+ def get_train_transforms(
+     image_size=1024,
+     mean=IMAGENET_MEAN,
+     std=IMAGENET_STD,
+ ):
+     """
+     Training transforms.
+     """
+     return A.Compose([
+         A.Resize(image_size, image_size),
+
+         A.HorizontalFlip(p=0.5),
+
+         A.ShiftScaleRotate(
+             shift_limit=0.02,
+             scale_limit=0.05,
+             rotate_limit=7,
+             border_mode=0,
+             value=0,
+             p=0.3,
+         ),
+
+         A.RandomBrightnessContrast(
+             brightness_limit=0.15,
+             contrast_limit=0.15,
+             p=0.5,
+         ),
+
+         A.RandomGamma(
+             gamma_limit=(85, 115),
+             p=0.3,
+         ),
+
+         A.HueSaturationValue(
+             hue_shift_limit=3,
+             sat_shift_limit=10,
+             val_shift_limit=10,
+             p=0.25,
+         ),
+
+         A.OneOf([
+             A.GaussianBlur(blur_limit=(3, 5)),
+             A.Downscale(scale_min=0.80, scale_max=0.95),
+             A.ImageCompression(quality_lower=75, quality_upper=100),
+         ], p=0.2),
+
+         A.Normalize(mean=mean, std=std),
+         ToTensorV2(),
+     ])
+
+
+ def get_val_transforms(
+     image_size=1024,
+     mean=IMAGENET_MEAN,
+     std=IMAGENET_STD,
+ ):
+     """
+     Validation/test transforms.
+     """
+     return A.Compose([
+         A.Resize(image_size, image_size),
+         A.Normalize(mean=mean, std=std),
+         ToTensorV2(),
+     ])
+
+
+ # -------------------------------------------------------------------------
+ # Suggested CFP augmentation parameter sets
+ # -------------------------------------------------------------------------
+ #
+ # 1) DEFAULT / CONSERVATIVE
+ #    Use this as a general starting point for CFP classification tasks.
+ #
+ #    Rationale:
+ #    - Simulates common camera/acquisition variability.
+ #    - Keeps color and image-quality perturbations mild.
+ #    - Good first choice when the disease signal may depend on subtle color,
+ #      contrast, texture, or anatomical context.
+ #
+ #    brightness_limit = 0.15
+ #    contrast_limit = 0.15
+ #    gamma_limit = (85, 115)   # approximately gamma 0.85–1.15
+ #    hue_shift_limit = 3       # intentionally small for fundus color realism
+ #    sat_shift_limit = 10
+ #    val_shift_limit = 10
+ #    rotate_limit = 7
+ #    shift_limit = 0.02
+ #    scale_limit = 0.05
+ #    blur_limit = (3, 5)
+ #    downscale_range = (0.80, 0.95)
+ #    jpeg_quality = (75, 100)
+ #
+ #
+ # 2) MORE AGGRESSIVE / DOMAIN-ROBUSTNESS
+ #    Use this when robustness across different CFP cameras, sites, image
+ #    qualities, or acquisition pipelines is more important, and confirm using
+ #    external or camera/site-held-out validation.
+ #
+ #    Rationale:
+ #    - Simulates broader variation across CFP devices and acquisition conditions.
+ #    - May improve domain robustness.
+ #    - Higher risk of altering disease-relevant appearance, so it should be
+ #      validated carefully for the target task.
+ #
+ #    brightness_limit = 0.25
+ #    contrast_limit = 0.25
+ #    gamma_limit = (75, 130)   # approximately gamma 0.75–1.30
+ #    hue_shift_limit = 5       # still limited for fundus color realism
+ #    sat_shift_limit = 18
+ #    val_shift_limit = 18
+ #    rotate_limit = 12
+ #    shift_limit = 0.04
+ #    scale_limit = 0.10
+ #    blur_limit = (3, 7)
+ #    downscale_range = (0.65, 0.95)
+ #    jpeg_quality = (55, 100)
+ # -------------------------------------------------------------------------
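
As a minimal usage sketch of the transforms above (assuming albumentations with ToTensorV2, as imported in this file), passing both image= and mask= keeps the geometric augmentations synchronized between the two:

import numpy as np
from augmentations import get_train_transforms

transforms = get_train_transforms(image_size=512)

# Dummy inputs; real inputs are RGB uint8 images and {0, 1} masks.
image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (1024, 1024), dtype=np.uint8)

out = transforms(image=image, mask=mask)
print(out["image"].shape)  # torch.Size([3, 512, 512]) after ToTensorV2
print(out["mask"].shape)   # torch.Size([512, 512]); the datasets unsqueeze to [1, H, W]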
checkpoints/fives_resunet/best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f89b779afb9de2859fa57a0282dd5e3e252fab39ab8fdcfa1cc0ce794108bbd
+ size 97523253
datasets/DRIVE.py ADDED
@@ -0,0 +1,223 @@
+ from pathlib import Path
+
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ import torchvision.transforms.functional as TF
+
+
+ class DRIVEDataset(Dataset):
+     """
+     PyTorch Dataset for the DRIVE retinal vessel segmentation dataset.
+
+     Expected structure:
+         DRIVE/
+         ├── training/
+         │   ├── images/
+         │   ├── 1st_manual/
+         │   └── mask/
+         └── test/
+             ├── images/
+             └── mask/
+
+     For training split:
+         image:       21_training.tif
+         vessel mask: 21_manual1.gif
+         FOV mask:    21_training_mask.gif
+
+     For test split:
+         image:       01_test.tif
+         FOV mask:    01_test_mask.gif
+         no vessel mask is included in the provided tree
+     """
+
+     def __init__(
+         self,
+         root,
+         split="training",
+         image_size=None,
+         return_fov=True,
+         transform=None,
+     ):
+         self.root = Path(root)
+         self.split = split
+         self.image_size = image_size
+         self.return_fov = return_fov
+         self.transform = transform
+
+         if split not in ["training", "test"]:
+             raise ValueError("split must be either 'training' or 'test'")
+
+         self.split_dir = self.root / split
+         self.image_dir = self.split_dir / "images"
+         self.fov_dir = self.split_dir / "mask"
+
+         if not self.image_dir.exists():
+             raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
+
+         self.image_paths = sorted(self.image_dir.glob("*.tif"))
+
+         if len(self.image_paths) == 0:
+             raise RuntimeError(f"No .tif images found in {self.image_dir}")
+
+         if split == "training":
+             self.label_dir = self.split_dir / "1st_manual"
+             if not self.label_dir.exists():
+                 raise FileNotFoundError(f"Label directory not found: {self.label_dir}")
+         else:
+             self.label_dir = None
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def _get_case_id(self, image_path):
+         """
+         Examples:
+             21_training.tif -> 21
+             01_test.tif     -> 01
+         """
+         return image_path.stem.split("_")[0]
+
+     def _load_image(self, path):
+         image = Image.open(path).convert("RGB")
+         return image
+
+     def _load_mask(self, path):
+         mask = Image.open(path).convert("L")
+         return mask
+
+     def _resize_if_needed(self, image, label=None, fov=None):
+         if self.image_size is None:
+             return image, label, fov
+
+         size = self.image_size
+         if isinstance(size, int):
+             size = (size, size)
+
+         image = TF.resize(image, size, interpolation=TF.InterpolationMode.BILINEAR)
+
+         if label is not None:
+             label = TF.resize(label, size, interpolation=TF.InterpolationMode.NEAREST)
+
+         if fov is not None:
+             fov = TF.resize(fov, size, interpolation=TF.InterpolationMode.NEAREST)
+
+         return image, label, fov
+
+     def __getitem__(self, idx):
+         image_path = self.image_paths[idx]
+         case_id = self._get_case_id(image_path)
+
+         image = self._load_image(image_path)
+
+         if self.split == "training":
+             label_path = self.label_dir / f"{case_id}_manual1.gif"
+             label = self._load_mask(label_path)
+         else:
+             label = None
+
+         fov_path = self.fov_dir / f"{case_id}_{self.split}_mask.gif"
+         fov = self._load_mask(fov_path)
+
+         image, label, fov = self._resize_if_needed(image, label, fov)
+
+         if self.transform is not None:
+             image, label, fov = self.transform(image, label, fov)
+
+         image = TF.to_tensor(image)
+
+         sample = {
+             "image": image,
+             "case_id": case_id,
+         }
+
+         if label is not None:
+             label = TF.to_tensor(label)
+             label = (label > 0.5).float()
+             sample["label"] = label
+
+         if self.return_fov:
+             fov = TF.to_tensor(fov)
+             fov = (fov > 0.5).float()
+             sample["fov"] = fov
+
+         return sample
+
+
+ if __name__ == "__main__":
+     import matplotlib.pyplot as plt
+
+     root = "/data/MIDS/datasets/retina/DRIVE"
+
+     dataset = DRIVEDataset(
+         root=root,
+         split="training",
+         image_size=512,
+         return_fov=True,
+     )
+
+     loader = DataLoader(
+         dataset,
+         batch_size=4,
+         shuffle=True,
+         num_workers=0,
+     )
+
+     batch = next(iter(loader))
+
+     print("Number of samples:", len(dataset))
+     print("Batch keys:", batch.keys())
+     print("Image shape:", batch["image"].shape)
+
+     if "label" in batch:
+         print("Label shape:", batch["label"].shape)
+         print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
+
+     if "fov" in batch:
+         print("FOV shape:", batch["fov"].shape)
+         print("FOV min/max:", batch["fov"].min().item(), batch["fov"].max().item())
+
+     print("Case IDs:", batch["case_id"])
+
+     # -------------------------
+     # Matplotlib visualization
+     # -------------------------
+     image = batch["image"][0]  # [3, H, W]
+     label = batch.get("label", None)
+     fov = batch.get("fov", None)
+
+     image_np = image.permute(1, 2, 0).cpu().numpy()
+
+     fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+     axes[0].imshow(image_np)
+     axes[0].set_title("Image")
+     axes[0].axis("off")
+
+     if label is not None:
+         label_np = label[0, 0].cpu().numpy()
+
+         axes[1].imshow(label_np, cmap="gray")
+         axes[1].set_title("Vessel Label")
+         axes[1].axis("off")
+
+         axes[2].imshow(image_np)
+         axes[2].imshow(label_np, cmap="Reds", alpha=0.45)
+         axes[2].set_title("Image + Vessel Overlay")
+         axes[2].axis("off")
+     else:
+         axes[1].axis("off")
+         axes[2].axis("off")
+
+     if fov is not None:
+         fov_np = fov[0, 0].cpu().numpy()
+
+         axes[3].imshow(image_np)
+         axes[3].imshow(fov_np, cmap="gray", alpha=0.25)
+         axes[3].set_title("Image + FOV Overlay")
+         axes[3].axis("off")
+     else:
+         axes[3].axis("off")
+
+     plt.tight_layout()
+     plt.show()
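
Note that DRIVEDataset's transform hook differs from the albumentations style used by the other datasets in this commit: it is called with, and must return, the (image, label, fov) triple as PIL images, after resizing and before tensor conversion. A minimal sketch of a compatible joint transform (joint_hflip is a hypothetical helper, not part of the commit):

import random
import torchvision.transforms.functional as TF
from datasets.DRIVE import DRIVEDataset

def joint_hflip(image, label, fov, p=0.5):
    # Flip image, vessel label, and FOV mask together so they stay aligned.
    # label is None for the "test" split and must be passed through as-is.
    if random.random() < p:
        image = TF.hflip(image)
        label = TF.hflip(label) if label is not None else None
        fov = TF.hflip(fov) if fov is not None else None
    return image, label, fov

dataset = DRIVEDataset(root="/path/to/DRIVE", split="training", transform=joint_hflip)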
datasets/FGADR.py ADDED
@@ -0,0 +1,342 @@
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ from sklearn.model_selection import KFold
+
+
+ class FGADRDataset(Dataset):
+     """
+     FGADR Seg-set dataset for diabetic retinopathy lesion segmentation.
+
+     Expected structure:
+         Seg-set/
+         ├── DR_Seg_Grading_Label.csv
+         ├── Original_Images/
+         ├── Microaneurysms_Masks/
+         ├── Hemohedge_Masks/
+         ├── HardExudate_Masks/
+         ├── SoftExudate_Masks/
+         ├── IRMA_Masks/
+         └── Neovascularization_Masks/
+
+     CSV format, no header:
+         filename,dr_grade
+
+     Output:
+         image:   [3, H, W]
+         label:   [6, H, W]
+         grade:   scalar long tensor
+         case_id: filename stem
+
+     split:
+         "train" = all folds except selected fold
+         "val"   = selected fold
+         "all"   = full dataset
+
+     Notes:
+         If a lesion-specific mask file is absent, it is treated as an empty
+         all-zero mask, meaning no incidence of that lesion class.
+     """
+
+     lesion_dirs = {
+         "microaneurysm": "Microaneurysms_Masks",
+         "hemorrhage": "Hemohedge_Masks",
+         "hard_exudate": "HardExudate_Masks",
+         "soft_exudate": "SoftExudate_Masks",
+         "irma": "IRMA_Masks",
+         "neovascularization": "Neovascularization_Masks",
+     }
+
+     def __init__(
+         self,
+         root,
+         split="train",
+         fold=0,
+         n_folds=5,
+         seed=42,
+         transform=None,
+         csv_name="DR_Seg_Grading_Label.csv",
+         image_dir_name="Original_Images",
+         mask_suffix="",
+     ):
+         self.root = Path(root)
+         self.split = split
+         self.fold = fold
+         self.n_folds = n_folds
+         self.seed = seed
+         self.transform = transform
+         self.csv_path = self.root / csv_name
+         self.image_dir = self.root / image_dir_name
+         self.mask_suffix = mask_suffix
+
+         if split not in ["train", "val", "all"]:
+             raise ValueError("split must be one of: 'train', 'val', 'all'")
+
+         if not (0 <= fold < n_folds):
+             raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}")
+
+         if not self.image_dir.exists():
+             raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
+
+         if not self.csv_path.exists():
+             raise FileNotFoundError(f"CSV file not found: {self.csv_path}")
+
+         self.class_names = list(self.lesion_dirs.keys())
+
+         for dirname in self.lesion_dirs.values():
+             mask_dir = self.root / dirname
+             if not mask_dir.exists():
+                 raise FileNotFoundError(f"Mask directory not found: {mask_dir}")
+
+         all_samples = self._read_csv()
+
+         if len(all_samples) == 0:
+             raise RuntimeError(f"No samples found in {self.csv_path}")
+
+         if split == "all":
+             self.samples = all_samples
+         else:
+             kfold = KFold(
+                 n_splits=n_folds,
+                 shuffle=True,
+                 random_state=seed,
+             )
+
+             splits = list(kfold.split(all_samples))
+             train_indices, val_indices = splits[fold]
+
+             if split == "train":
+                 self.samples = [all_samples[i] for i in train_indices]
+             else:
+                 self.samples = [all_samples[i] for i in val_indices]
+
+     def _read_csv(self):
+         samples = []
+
+         with open(self.csv_path, "r") as f:
+             for line in f:
+                 line = line.strip()
+
+                 if not line:
+                     continue
+
+                 parts = line.split(",")
+
+                 if len(parts) < 2:
+                     continue
+
+                 filename = parts[0].strip()
+                 grade = int(parts[1].strip())
+
+                 image_path = self.image_dir / filename
+
+                 if not image_path.exists():
+                     raise FileNotFoundError(f"Image not found: {image_path}")
+
+                 samples.append(
+                     {
+                         "filename": filename,
+                         "case_id": Path(filename).stem,
+                         "image_path": image_path,
+                         "grade": grade,
+                     }
+                 )
+
+         return samples
+
+     def __len__(self):
+         return len(self.samples)
+
+     def _load_image(self, path):
+         image = Image.open(path).convert("RGB")
+         return np.array(image)
+
+     def _load_mask(self, path, shape):
+         if path.exists():
+             mask = Image.open(path).convert("L")
+             mask = np.array(mask)
+         else:
+             mask = np.zeros(shape, dtype=np.uint8)
+
+         return mask
+
+     def _get_mask_path(self, lesion_name, filename):
+         mask_dir = self.root / self.lesion_dirs[lesion_name]
+
+         if self.mask_suffix:
+             stem = Path(filename).stem
+             suffix = Path(filename).suffix
+             filename = f"{stem}{self.mask_suffix}{suffix}"
+
+         return mask_dir / filename
+
+     def __getitem__(self, idx):
+         sample_info = self.samples[idx]
+
+         filename = sample_info["filename"]
+         image_path = sample_info["image_path"]
+         case_id = sample_info["case_id"]
+         grade = sample_info["grade"]
+
+         image = self._load_image(image_path)
+         h, w = image.shape[:2]
+
+         masks = []
+         mask_paths = {}
+
+         for lesion_name in self.class_names:
+             mask_path = self._get_mask_path(lesion_name, filename)
+             mask = self._load_mask(mask_path, shape=(h, w))
+
+             masks.append(mask)
+             mask_paths[lesion_name] = str(mask_path)
+
+         if self.transform is not None:
+             transformed = self.transform(
+                 image=image,
+                 masks=masks,
+             )
+
+             image = transformed["image"]
+             masks = transformed["masks"]
+
+             masks = [
+                 m.float() if isinstance(m, torch.Tensor) else torch.from_numpy(m).float()
+                 for m in masks
+             ]
+
+             label = torch.stack(masks, dim=0)
+
+         else:
+             image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
+             label = torch.stack(
+                 [torch.from_numpy(m).float() for m in masks],
+                 dim=0,
+             )
+
+         label = (label > 0).float()
+
+         return {
+             "image": image,
+             "label": label,
+             "grade": torch.tensor(grade, dtype=torch.long),
+             "case_id": case_id,
+             "filename": filename,
+             "image_path": str(image_path),
+             "mask_paths": mask_paths,
+         }
+
+
+ if __name__ == "__main__":
+     import matplotlib.pyplot as plt
+     from tqdm import tqdm
+
+     try:
+         from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD
+     except ImportError:
+         import sys
+
+         project_root = Path(__file__).resolve().parents[1]
+         sys.path.append(str(project_root))
+
+         from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD
+
+     root = "/data/MIDS/datasets/retina/FGADR/Seg-set"
+     image_size = 512
+
+     dataset = FGADRDataset(
+         root=root,
+         split="train",
+         fold=0,
+         n_folds=5,
+         seed=42,
+         transform=get_train_transforms(image_size=image_size),
+     )
+
+     print("\nChecking all FGADR files...")
+
+     missing_images = 0
+     absent_masks = 0
+
+     for sample in tqdm(dataset.samples, desc="Checking files"):
+         filename = sample["filename"]
+
+         if not sample["image_path"].exists():
+             print(f"Missing image: {sample['image_path']}")
+             missing_images += 1
+
+         for lesion_name in dataset.class_names:
+             mask_path = dataset._get_mask_path(lesion_name, filename)
+
+             if not mask_path.exists():
+                 absent_masks += 1
+
+     print("File check complete.")
+     print(f"Missing images: {missing_images}")
+     print(f"Absent lesion masks treated as empty: {absent_masks}")
+
+     loader = DataLoader(
+         dataset,
+         batch_size=4,
+         shuffle=True,
+         num_workers=0,
+     )
+
+     batch = next(iter(loader))
+
+     print("\nSmoke test batch:")
+     print("Number of samples:", len(dataset))
+     print("Split:", dataset.split)
+     print("Fold:", dataset.fold)
+     print("Number of folds:", dataset.n_folds)
+     print("Class names:", dataset.class_names)
+     print("Batch keys:", batch.keys())
+     print("Image shape:", batch["image"].shape)
+     print("Label shape:", batch["label"].shape)
+     print("Grade shape:", batch["grade"].shape)
+     print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
+     print("Case IDs:", batch["case_id"])
+
+     image = batch["image"][0].cpu()
+     label = batch["label"][0].cpu()
+     grade = batch["grade"][0].item()
+
+     mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
+     std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
+
+     image_vis = image * std + mean
+     image_vis = image_vis.clamp(0, 1)
+     image_vis = image_vis.permute(1, 2, 0).numpy()
+
+     combined_mask = (label.sum(dim=0) > 0).float().numpy()
+
+     fig, axes = plt.subplots(2, 5, figsize=(20, 8))
+     axes = axes.flatten()
+
+     axes[0].imshow(image_vis)
+     axes[0].set_title(f"Image | Grade {grade}")
+     axes[0].axis("off")
+
+     axes[1].imshow(combined_mask, cmap="gray")
+     axes[1].set_title("Any Lesion")
+     axes[1].axis("off")
+
+     axes[2].imshow(image_vis)
+     axes[2].imshow(combined_mask, cmap="Reds", alpha=0.45)
+     axes[2].set_title("Overlay")
+     axes[2].axis("off")
+
+     for ax in axes[3:]:
+         ax.axis("off")
+
+     for i, class_name in enumerate(dataset.class_names):
+         ax = axes[i + 3]
+         ax.imshow(label[i].numpy(), cmap="gray")
+         ax.set_title(class_name)
+         ax.axis("off")
+
+     plt.tight_layout()
+     plt.show()
datasets/FIVES.py ADDED
@@ -0,0 +1,215 @@
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+
+
+ class FIVESDataset(Dataset):
+     """
+     PyTorch Dataset for FIVES retinal vessel segmentation.
+
+     Expected structure:
+         FIVES_dataset/
+         ├── train/
+         │   ├── Original/
+         │   └── Ground truth/
+         └── test/
+             ├── Original/
+             └── Ground truth/
+
+     Each image in Original/ should have a matching vessel mask
+     with the same filename in Ground truth/.
+
+     Output sample:
+         {
+             "image": Tensor [3, H, W],
+             "label": Tensor [1, H, W],
+             "case_id": str,
+             "image_path": str,
+             "label_path": str,
+         }
+
+     If transform is provided, it should be an Albumentations transform.
+     """
+
+     def __init__(
+         self,
+         root,
+         split="train",
+         transform=None,
+         image_dir_name="Original",
+         label_dir_name="Ground truth",
+     ):
+         self.root = Path(root)
+         self.split = split
+         self.transform = transform
+
+         if split not in ["train", "test"]:
+             raise ValueError("split must be either 'train' or 'test'")
+
+         self.split_dir = self.root / split
+         self.image_dir = self.split_dir / image_dir_name
+         self.label_dir = self.split_dir / label_dir_name
+
+         if not self.image_dir.exists():
+             raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
+
+         if not self.label_dir.exists():
+             raise FileNotFoundError(f"Label directory not found: {self.label_dir}")
+
+         self.image_paths = sorted(
+             [
+                 p for p in self.image_dir.glob("*.png")
+                 if not p.name.startswith(".") and p.name.lower() != "thumbs.db"
+             ]
+         )
+
+         if len(self.image_paths) == 0:
+             raise RuntimeError(f"No PNG images found in {self.image_dir}")
+
+         self.samples = []
+
+         for image_path in self.image_paths:
+             label_path = self.label_dir / image_path.name
+
+             if not label_path.exists():
+                 raise FileNotFoundError(
+                     f"Missing label for image:\n"
+                     f"image: {image_path}\n"
+                     f"label: {label_path}"
+                 )
+
+             self.samples.append(
+                 {
+                     "image_path": image_path,
+                     "label_path": label_path,
+                     "case_id": image_path.stem,
+                 }
+             )
+
+     def __len__(self):
+         return len(self.samples)
+
+     def _load_image(self, path):
+         image = Image.open(path).convert("RGB")
+         return np.array(image)
+
+     def _load_mask(self, path):
+         mask = Image.open(path).convert("L")
+         return np.array(mask)
+
+     def __getitem__(self, idx):
+         sample_info = self.samples[idx]
+
+         image_path = sample_info["image_path"]
+         label_path = sample_info["label_path"]
+         case_id = sample_info["case_id"]
+
+         image = self._load_image(image_path)
+         label = self._load_mask(label_path)
+
+         if self.transform is not None:
+             transformed = self.transform(
+                 image=image,
+                 mask=label,
+             )
+
+             image = transformed["image"]
+             label = transformed["mask"]
+
+             # Albumentations ToTensorV2 converts image to [3, H, W],
+             # but mask remains [H, W], so add channel dimension.
+             if isinstance(label, torch.Tensor):
+                 label = label.float().unsqueeze(0)
+             else:
+                 label = torch.from_numpy(label).float().unsqueeze(0)
+
+         else:
+             image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
+             label = torch.from_numpy(label).float().unsqueeze(0)
+
+         # Convert vessel mask to binary {0, 1}
+         label = (label > 0).float()
+
+         return {
+             "image": image,
+             "label": label,
+             "case_id": case_id,
+             "image_path": str(image_path),
+             "label_path": str(label_path),
+         }
+
+
+ if __name__ == "__main__":
+     import matplotlib.pyplot as plt
+
+     try:
+         from augmentations import get_train_transforms, get_val_transforms
+     except ImportError:
+         import sys
+
+         project_root = Path(__file__).resolve().parents[1]
+         sys.path.append(str(project_root))
+
+         from augmentations import get_train_transforms, get_val_transforms
+
+     root = "/data/MIDS/datasets/retina/FIVES_dataset"
+     image_size = 512
+
+     dataset = FIVESDataset(
+         root=root,
+         split="train",
+         transform=get_train_transforms(image_size=image_size),
+     )
+
+     loader = DataLoader(
+         dataset,
+         batch_size=4,
+         shuffle=True,
+         num_workers=0,
+     )
+
+     batch = next(iter(loader))
+
+     print("Number of samples:", len(dataset))
+     print("Batch keys:", batch.keys())
+     print("Image shape:", batch["image"].shape)
+     print("Label shape:", batch["label"].shape)
+     print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
+     print("Case IDs:", batch["case_id"])
+
+     # -------------------------
+     # Matplotlib visualization
+     # -------------------------
+     image = batch["image"][0]
+     label = batch["label"][0, 0]
+
+     # Undo ImageNet normalization for visualization.
+     mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+     std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+
+     image_vis = image.cpu() * std + mean
+     image_vis = image_vis.clamp(0, 1)
+     image_vis = image_vis.permute(1, 2, 0).numpy()
+
+     label_vis = label.cpu().numpy()
+
+     fig, axes = plt.subplots(1, 3, figsize=(12, 4))
+
+     axes[0].imshow(image_vis)
+     axes[0].set_title("Image")
+     axes[0].axis("off")
+
+     axes[1].imshow(label_vis, cmap="gray")
+     axes[1].set_title("Vessel Label")
+     axes[1].axis("off")
+
+     axes[2].imshow(image_vis)
+     axes[2].imshow(label_vis, cmap="Reds", alpha=0.45)
+     axes[2].set_title("Overlay")
+     axes[2].axis("off")
+
+     plt.tight_layout()
+     plt.show()
datasets/__init__.py ADDED
File without changes
losses.py ADDED
@@ -0,0 +1,135 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class DiceLoss(nn.Module):
+     """
+     Soft Dice loss for binary segmentation.
+
+     Expected shapes:
+         logits:  [B, 1, H, W]
+         targets: [B, 1, H, W]
+         mask:    [B, 1, H, W], optional FOV mask
+
+     The model should output raw logits, not sigmoid probabilities.
+     """
+
+     def __init__(self, smooth=1.0):
+         super().__init__()
+         self.smooth = smooth
+
+     def forward(self, logits, targets, mask=None):
+         probs = torch.sigmoid(logits)
+
+         if mask is not None:
+             probs = probs * mask
+             targets = targets * mask
+
+         probs = probs.flatten(1)
+         targets = targets.flatten(1)
+
+         intersection = (probs * targets).sum(dim=1)
+         denominator = probs.sum(dim=1) + targets.sum(dim=1)
+
+         dice = (2.0 * intersection + self.smooth) / (
+             denominator + self.smooth
+         )
+
+         return 1.0 - dice.mean()
+
+
+ class BCEDiceLoss(nn.Module):
+     """
+     BCEWithLogits + Dice loss for binary vessel segmentation.
+
+     The optional mask argument is intended for the DRIVE FOV mask, so that
+     background outside the retinal field of view does not dominate training.
+     """
+
+     def __init__(
+         self,
+         bce_weight=1.0,
+         dice_weight=1.0,
+         smooth=1.0,
+     ):
+         super().__init__()
+
+         self.bce_weight = bce_weight
+         self.dice_weight = dice_weight
+         self.dice = DiceLoss(smooth=smooth)
+
+     def forward(self, logits, targets, mask=None):
+         bce = F.binary_cross_entropy_with_logits(
+             logits,
+             targets,
+             reduction="none",
+         )
+
+         if mask is not None:
+             bce = bce * mask
+             bce = bce.sum() / mask.sum().clamp_min(1.0)
+         else:
+             bce = bce.mean()
+
+         dice = self.dice(logits, targets, mask)
+
+         loss = self.bce_weight * bce + self.dice_weight * dice
+
+         return loss
+
+
+ @torch.no_grad()
+ def compute_dice_score(
+     logits,
+     targets,
+     mask=None,
+     threshold=0.5,
+     eps=1e-7,
+ ):
+     """
+     Hard Dice score for monitoring.
+
+     Expected shapes:
+         logits:  [B, 1, H, W]
+         targets: [B, 1, H, W]
+         mask:    [B, 1, H, W], optional
+     """
+
+     probs = torch.sigmoid(logits)
+     preds = (probs > threshold).float()
+
+     if mask is not None:
+         preds = preds * mask
+         targets = targets * mask
+
+     preds = preds.flatten(1)
+     targets = targets.flatten(1)
+
+     intersection = (preds * targets).sum(dim=1)
+     denominator = preds.sum(dim=1) + targets.sum(dim=1)
+
+     dice = (2.0 * intersection + eps) / (denominator + eps)
+
+     return dice.mean().item()
+
+
+ if __name__ == "__main__":
+     # Smoke test:
+     #   python losses.py
+
+     logits = torch.randn(2, 1, 512, 512)
+     targets = torch.randint(0, 2, (2, 1, 512, 512)).float()
+     fov = torch.ones(2, 1, 512, 512)
+
+     criterion = BCEDiceLoss(
+         bce_weight=1.0,
+         dice_weight=1.0,
+     )
+
+     loss = criterion(logits, targets, fov)
+     dice = compute_dice_score(logits, targets, fov)
+
+     print("Loss:", loss.item())
+     print("Dice:", dice)
+     print("Smoke test passed.")
models/__init__.py ADDED
@@ -0,0 +1,66 @@
+ from .unet import build_resunet
+ from .deeplabv3 import build_deeplabv3
+ from .vit import build_vit
+
+
+ def build_model(
+     model_name="resunet",
+     num_classes=1,
+     in_channels=3,
+     image_size=512,
+     backbone="resnet50",
+     pretrained=True,
+     base_channels=32,
+     dropout=0.0,
+ ):
+     """
+     Generic model builder.
+
+     model_name options:
+         resunet
+         deeplabv3
+         vit
+
+     backbone:
+         For deeplabv3:
+             resnet50, resnet101
+
+         For vit:
+             tiny, small, base, large
+             or a timm model name
+
+         For resunet:
+             unused
+     """
+
+     model_name = model_name.lower()
+
+     if model_name == "resunet":
+         return build_resunet(
+             in_channels=in_channels,
+             num_classes=num_classes,
+             base_channels=base_channels,
+             dropout=dropout,
+         )
+
+     if model_name == "deeplabv3":
+         return build_deeplabv3(
+             backbone=backbone,
+             num_classes=num_classes,
+             pretrained_backbone=pretrained,
+         )
+
+     if model_name == "vit":
+         return build_vit(
+             variant=backbone,
+             num_classes=num_classes,
+             pretrained=pretrained,
+             in_chans=in_channels,
+             img_size=image_size,
+             dropout=dropout,
+         )
+
+     raise ValueError(
+         f"Unsupported model_name: {model_name}. "
+         "Choose from: resunet, deeplabv3, vit."
+     )
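
A short usage sketch of the builder above (arguments follow the defaults in this file; pretrained=False avoids weight downloads):

from models import build_model

resunet = build_model(model_name="resunet", num_classes=1, base_channels=32)
deeplab = build_model(model_name="deeplabv3", backbone="resnet50", pretrained=False)
vit = build_model(model_name="vit", backbone="base", image_size=512, pretrained=False)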
models/deeplabv3.py ADDED
@@ -0,0 +1,113 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from torchvision.models.segmentation import (
+     deeplabv3_resnet50,
+     deeplabv3_resnet101,
+ )
+ from torchvision.models.segmentation.deeplabv3 import DeepLabHead
+
+
+ class DeepLabV3Wrapper(nn.Module):
+     """
+     DeepLabV3 wrapper for retinal vessel segmentation.
+
+     Output:
+         Raw logits [B, num_classes, H, W]
+
+     For binary vessel segmentation:
+         num_classes = 1
+     """
+
+     def __init__(
+         self,
+         backbone="resnet50",
+         num_classes=1,
+         pretrained_backbone=True,
+         aux_loss=False,
+     ):
+         super().__init__()
+
+         if backbone == "resnet50":
+             model = deeplabv3_resnet50(
+                 weights=None,
+                 weights_backbone="DEFAULT" if pretrained_backbone else None,
+                 aux_loss=aux_loss,
+             )
+             in_channels = 2048
+
+         elif backbone == "resnet101":
+             model = deeplabv3_resnet101(
+                 weights=None,
+                 weights_backbone="DEFAULT" if pretrained_backbone else None,
+                 aux_loss=aux_loss,
+             )
+             in_channels = 2048
+
+         else:
+             raise ValueError(
+                 f"Unsupported backbone: {backbone}. "
+                 "Choose from: 'resnet50', 'resnet101'."
+             )
+
+         model.classifier = DeepLabHead(
+             in_channels=in_channels,
+             num_classes=num_classes,
+         )
+
+         if aux_loss and model.aux_classifier is not None:
+             model.aux_classifier[-1] = nn.Conv2d(
+                 model.aux_classifier[-1].in_channels,
+                 num_classes,
+                 kernel_size=1,
+             )
+
+         self.model = model
+
+     def forward(self, x):
+         output = self.model(x)
+
+         # torchvision segmentation models return dict:
+         #   {"out": logits, "aux": optional aux logits}
+         return output["out"]
+
+
+ def build_deeplabv3(
+     backbone="resnet50",
+     num_classes=1,
+     pretrained_backbone=True,
+     aux_loss=False,
+ ):
+     return DeepLabV3Wrapper(
+         backbone=backbone,
+         num_classes=num_classes,
+         pretrained_backbone=pretrained_backbone,
+         aux_loss=aux_loss,
+     )
+
+
+ if __name__ == "__main__":
+     # Smoke test:
+     #   python models/deeplabv3.py
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model = build_deeplabv3(
+         backbone="resnet50",
+         num_classes=1,
+         pretrained_backbone=False,
+     ).to(device)
+
+     x = torch.randn(2, 3, 512, 512).to(device)
+
+     with torch.no_grad():
+         y = model(x)
+
+     print("Input shape:", x.shape)
+     print("Output shape:", y.shape)
+     print("Output min/max:", y.min().item(), y.max().item())
+
+     assert y.shape == (2, 1, 512, 512)
+
+     print("Smoke test passed.")
models/unet.py ADDED
@@ -0,0 +1,205 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ConvBNReLU(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
+         super().__init__()
+
+         self.block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class ResidualBlock(nn.Module):
+     """
+     Basic residual block for ResUNet.
+
+     If in_channels != out_channels, the shortcut uses a 1x1 conv.
+     """
+
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+
+         self.conv1 = ConvBNReLU(in_channels, out_channels)
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+         )
+
+         if in_channels != out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+                 nn.BatchNorm2d(out_channels),
+             )
+         else:
+             self.shortcut = nn.Identity()
+
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         residual = self.shortcut(x)
+
+         x = self.conv1(x)
+         x = self.conv2(x)
+
+         x = x + residual
+         x = self.relu(x)
+
+         return x
+
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+
+         self.res_block = ResidualBlock(in_channels, out_channels)
+         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+     def forward(self, x):
+         skip = self.res_block(x)
+         pooled = self.pool(skip)
+         return skip, pooled
+
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, in_channels, skip_channels, out_channels):
+         super().__init__()
+
+         self.up = nn.ConvTranspose2d(
+             in_channels,
+             out_channels,
+             kernel_size=2,
+             stride=2,
+         )
+
+         self.res_block = ResidualBlock(
+             out_channels + skip_channels,
+             out_channels,
+         )
+
+     def forward(self, x, skip):
+         x = self.up(x)
+
+         # Handles odd image sizes, though 512/1024 should already match.
+         if x.shape[-2:] != skip.shape[-2:]:
+             x = F.interpolate(
+                 x,
+                 size=skip.shape[-2:],
+                 mode="bilinear",
+                 align_corners=False,
+             )
+
+         x = torch.cat([x, skip], dim=1)
+         x = self.res_block(x)
+
+         return x
+
+
+ class ResUNet(nn.Module):
+     """
+     ResUNet for binary or multi-class retinal segmentation.
+
+     Output:
+         Raw logits of shape [B, num_classes, H, W]
+
+     For vessel segmentation:
+         num_classes=1
+         loss=BCEWithLogits/Dice/Tversky/etc.
+     """
+
+     def __init__(
+         self,
+         in_channels=3,
+         num_classes=1,
+         base_channels=32,
+         dropout=0.0,
+     ):
+         super().__init__()
+
+         c1 = base_channels
+         c2 = base_channels * 2
+         c3 = base_channels * 4
+         c4 = base_channels * 8
+         c5 = base_channels * 16
+
+         self.enc1 = EncoderBlock(in_channels, c1)
+         self.enc2 = EncoderBlock(c1, c2)
+         self.enc3 = EncoderBlock(c2, c3)
+         self.enc4 = EncoderBlock(c3, c4)
+
+         self.bottleneck = nn.Sequential(
+             ResidualBlock(c4, c5),
+             nn.Dropout2d(dropout),
+         )
+
+         self.dec4 = DecoderBlock(c5, c4, c4)
+         self.dec3 = DecoderBlock(c4, c3, c3)
+         self.dec2 = DecoderBlock(c3, c2, c2)
+         self.dec1 = DecoderBlock(c2, c1, c1)
+
+         self.out_conv = nn.Conv2d(c1, num_classes, kernel_size=1)
+
+     def forward(self, x):
+         s1, x = self.enc1(x)
+         s2, x = self.enc2(x)
+         s3, x = self.enc3(x)
+         s4, x = self.enc4(x)
+
+         x = self.bottleneck(x)
+
+         x = self.dec4(x, s4)
+         x = self.dec3(x, s3)
+         x = self.dec2(x, s2)
+         x = self.dec1(x, s1)
+
+         logits = self.out_conv(x)
+
+         return logits
+
+
+ def build_resunet(
+     in_channels=3,
+     num_classes=1,
+     base_channels=32,
+     dropout=0.0,
+ ):
+     return ResUNet(
+         in_channels=in_channels,
+         num_classes=num_classes,
+         base_channels=base_channels,
+         dropout=dropout,
+     )
+
+
+ if __name__ == "__main__":
+     # Smoke test:
+     #   python models/unet.py
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model = build_resunet(
+         in_channels=3,
+         num_classes=1,
+         base_channels=32,
+         dropout=0.0,
+     ).to(device)
+
+     x = torch.randn(2, 3, 512, 512).to(device)
+
+     with torch.no_grad():
+         y = model(x)
+
+     print("Input shape:", x.shape)
+     print("Output shape:", y.shape)
+     print("Output min/max:", y.min().item(), y.max().item())
+
+     assert y.shape == (2, 1, 512, 512)
+
+     print("Smoke test passed.")
models/vit.py ADDED
@@ -0,0 +1,223 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     import timm
+ except ImportError as e:
+     raise ImportError(
+         "timm is required for models/vit.py. Install with: pip install timm"
+     ) from e
+
+
+ class ViTSegmentationModel(nn.Module):
+     """
+     Simple ViT segmentation model using a timm Vision Transformer backbone.
+
+     The model:
+         image -> ViT patch tokens -> reshape to feature map -> conv head -> upsample
+
+     Output:
+         logits of shape [B, num_classes, H, W]
+
+     For binary vessel segmentation:
+         num_classes = 1
+
+     For multi-class lesion segmentation:
+         num_classes = number of lesion/background classes
+     """
+
+     def __init__(
+         self,
+         model_name="vit_base_patch16_224",
+         num_classes=1,
+         pretrained=True,
+         in_chans=3,
+         img_size=512,
+         decoder_dim=256,
+         dropout=0.0,
+     ):
+         super().__init__()
+
+         self.model_name = model_name
+         self.num_classes = num_classes
+         self.img_size = img_size
+
+         self.backbone = timm.create_model(
+             model_name,
+             pretrained=pretrained,
+             num_classes=0,
+             global_pool="",
+             in_chans=in_chans,
+             img_size=img_size,
+         )
+
+         self.embed_dim = self.backbone.num_features
+         self.patch_size = self.backbone.patch_embed.patch_size
+
+         if isinstance(self.patch_size, tuple):
+             self.patch_size = self.patch_size[0]
+
+         self.decoder = nn.Sequential(
+             nn.Conv2d(self.embed_dim, decoder_dim, kernel_size=1),
+             nn.BatchNorm2d(decoder_dim),
+             nn.ReLU(inplace=True),
+             nn.Dropout2d(dropout),
+             nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1),
+             nn.BatchNorm2d(decoder_dim),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(decoder_dim, num_classes, kernel_size=1),
+         )
+
+     def forward_features_as_map(self, x):
+         """
+         Convert ViT patch tokens into a spatial feature map.
+
+         Input:
+             x: [B, C, H, W]
+
+         Output:
+             feature_map: [B, embed_dim, H // patch_size, W // patch_size]
+         """
+         b, _, h, w = x.shape
+
+         tokens = self.backbone.forward_features(x)
+
+         # Some timm models return a tuple/list. Usually the first item is token features.
+         if isinstance(tokens, (tuple, list)):
+             tokens = tokens[0]
+
+         # For standard ViT:
+         #   tokens: [B, 1 + num_patches, C], where the first token is CLS.
+         if tokens.ndim == 3:
+             expected_num_patches = (h // self.patch_size) * (w // self.patch_size)
+
+             if tokens.shape[1] == expected_num_patches + 1:
+                 tokens = tokens[:, 1:, :]  # remove CLS token
+
+             feature_h = h // self.patch_size
+             feature_w = w // self.patch_size
+
+             tokens = tokens.transpose(1, 2)
+             feature_map = tokens.reshape(b, self.embed_dim, feature_h, feature_w)
+
+         # Some backbones may already return [B, C, H, W].
+         elif tokens.ndim == 4:
+             feature_map = tokens
+
+         else:
+             raise RuntimeError(f"Unexpected ViT feature shape: {tokens.shape}")
+
+         return feature_map
+
+     def forward(self, x):
+         input_size = x.shape[-2:]
+
+         feature_map = self.forward_features_as_map(x)
+         logits = self.decoder(feature_map)
+
+         logits = F.interpolate(
+             logits,
+             size=input_size,
+             mode="bilinear",
+             align_corners=False,
+         )
+
+         return logits
+
+
+ def build_vit(
+     variant="base",
+     num_classes=1,
+     pretrained=True,
+     in_chans=3,
+     img_size=512,
+     decoder_dim=256,
+     dropout=0.0,
+ ):
+     """
+     Build a timm ViT segmentation model.
+
+     Parameters
+     ----------
+     variant:
+         One of:
+             "tiny"
+             "small"
+             "base"
+             "large"
+
+         Or directly pass a timm model name, e.g.:
+             "vit_base_patch16_224"
+             "vit_small_patch16_224"
+             "vit_large_patch16_224"
+
+     num_classes:
+         Number of output channels.
+
+         Binary segmentation:
+             num_classes=1
+
+         Multi-class segmentation:
+             num_classes=N
+
+     pretrained:
+         Whether to load ImageNet-pretrained timm weights.
+
+     img_size:
+         Input image size. For DRIVE, 512 is a reasonable default.
+
+     Returns
+     -------
+     model:
+         ViTSegmentationModel
+     """
+
+     variants = {
+         "tiny": "vit_tiny_patch16_224",
+         "small": "vit_small_patch16_224",
+         "base": "vit_base_patch16_224",
+         "large": "vit_large_patch16_224",
+     }
+
+     model_name = variants.get(variant, variant)
+
+     model = ViTSegmentationModel(
+         model_name=model_name,
+         num_classes=num_classes,
+         pretrained=pretrained,
+         in_chans=in_chans,
+         img_size=img_size,
+         decoder_dim=decoder_dim,
+         dropout=dropout,
+     )
+
+     return model
+
+
+ if __name__ == "__main__":
+     # Smoke test:
+     #   python models/vit.py
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model = build_vit(
+         variant="base",
+         num_classes=1,
+         pretrained=False,
+         img_size=512,
+     ).to(device)
+
+     x = torch.randn(2, 3, 512, 512).to(device)
+
+     with torch.no_grad():
+         y = model(x)
+
+     print("Model:", model.model_name)
+     print("Input shape:", x.shape)
+     print("Output shape:", y.shape)
+     print("Output min/max:", y.min().item(), y.max().item())
+
+     assert y.shape == (2, 1, 512, 512)
+
+     print("Smoke test passed.")
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ albumentations
+ gradio
+ huggingface_hub
+ numpy
+ opencv-python
+ pandas
+ pillow
+ pydantic
+ timm
+ torch
+ torchvision
+ torchaudio
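
Note that the list above covers the Gradio app itself; the training and dataset demo scripts in this commit additionally import tqdm, scikit-learn, and matplotlib, which would need to be installed separately.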
train.py ADDED
@@ -0,0 +1,256 @@
+ import argparse
+ from pathlib import Path
+ from tqdm import tqdm
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ from augmentations import get_train_transforms, get_val_transforms
+ from datasets.FIVES import FIVESDataset
+ from models import build_model
+ from losses import BCEDiceLoss, compute_dice_score
+
+
+ def train_one_epoch(model, loader, optimizer, scaler, criterion, device, use_amp=True):
+     model.train()
+
+     running_loss = 0.0
+     running_dice = 0.0
+
+     pbar = tqdm(loader, desc="Train", leave=False)
+
+     for batch in pbar:
+         images = batch["image"].to(device)
+         labels = batch["label"].to(device)
+
+         optimizer.zero_grad(set_to_none=True)
+
+         with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
+             logits = model(images)
+             loss = criterion(logits, labels)
+
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+
+         dice = compute_dice_score(logits.detach(), labels)
+
+         running_loss += loss.item()
+         running_dice += dice
+
+         avg_loss = running_loss / (pbar.n + 1)
+         avg_dice = running_dice / (pbar.n + 1)
+
+         pbar.set_postfix(
+             loss=f"{avg_loss:.4f}",
+             dice=f"{avg_dice:.4f}",
+         )
+
+     return running_loss / len(loader), running_dice / len(loader)
+
+
+ @torch.no_grad()
+ def validate(model, loader, criterion, device, use_amp=True):
+     model.eval()
+
+     running_loss = 0.0
+     running_dice = 0.0
+
+     pbar = tqdm(loader, desc="Val", leave=False)
+
+     for batch in pbar:
+         images = batch["image"].to(device)
+         labels = batch["label"].to(device)
+
+         with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
+             logits = model(images)
+             loss = criterion(logits, labels)
+
+         dice = compute_dice_score(logits, labels)
+
+         running_loss += loss.item()
+         running_dice += dice
+
+         avg_loss = running_loss / (pbar.n + 1)
+         avg_dice = running_dice / (pbar.n + 1)
+
+         pbar.set_postfix(
+             loss=f"{avg_loss:.4f}",
+             dice=f"{avg_dice:.4f}",
+         )
+
+     return running_loss / len(loader), running_dice / len(loader)
+
+
+ def save_checkpoint(path, model, optimizer, epoch, best_dice, args):
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+
+     torch.save(
+         {
+             "epoch": epoch,
+             "model_state_dict": model.state_dict(),
+             "optimizer_state_dict": optimizer.state_dict(),
+             "best_dice": best_dice,
+             "args": vars(args),
+         },
+         path,
+     )
+
+
+ def main(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     train_dataset = FIVESDataset(
+         root=args.data_root,
+         split="train",
+         transform=get_train_transforms(image_size=args.image_size),
+     )
+
+     val_dataset = FIVESDataset(
+         root=args.data_root,
+         split="test",
+         transform=get_val_transforms(image_size=args.image_size),
+     )
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         num_workers=args.num_workers,
+         pin_memory=True,
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=args.batch_size,
+         shuffle=False,
+         num_workers=args.num_workers,
+         pin_memory=True,
+     )
+
+     model = build_model(
+         model_name=args.model,
+         num_classes=1,
+         in_channels=3,
+         image_size=args.image_size,
+         backbone=args.backbone,
+         pretrained=not args.no_pretrained,
+         base_channels=args.base_channels,
+         dropout=args.dropout,
+     ).to(device)
+
+     criterion = BCEDiceLoss(
+         bce_weight=args.bce_weight,
+         dice_weight=args.dice_weight,
+     )
+
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         lr=args.lr,
+         weight_decay=args.weight_decay,
+     )
+
+     scaler = torch.amp.GradScaler(enabled=args.amp and device.type == "cuda")
+
+     best_dice = -1.0
+
+     print(f"Device: {device}")
+     print(f"Train samples: {len(train_dataset)}")
+     print(f"Val samples: {len(val_dataset)}")
+     print(f"Image size: {args.image_size}")
+     print(f"Batch size: {args.batch_size}")
+     print(f"Pretrained: {not args.no_pretrained}")
+
+     for epoch in range(1, args.epochs + 1):
+         print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
+
+         train_loss, train_dice = train_one_epoch(
+             model=model,
+             loader=train_loader,
+             optimizer=optimizer,
+             scaler=scaler,
+             criterion=criterion,
+             device=device,
+             use_amp=args.amp,
+         )
+
+         val_loss, val_dice = validate(
+             model=model,
+             loader=val_loader,
+             criterion=criterion,
+             device=device,
+             use_amp=args.amp,
+         )
+
+         print(
+             f"train_loss={train_loss:.4f} "
+             f"train_dice={train_dice:.4f} "
+             f"val_loss={val_loss:.4f} "
+             f"val_dice={val_dice:.4f}"
+         )
+
+         if val_dice > best_dice:
+             best_dice = val_dice
+             save_checkpoint(
+                 Path(args.output_dir) / "best.pt",
+                 model,
+                 optimizer,
+                 epoch,
+                 best_dice,
+                 args,
+             )
+             print(f"Saved best checkpoint: val_dice={best_dice:.4f}")
+
+         if epoch % args.save_every == 0:
+             save_checkpoint(
+                 Path(args.output_dir) / f"epoch_{epoch:03d}.pt",
+                 model,
+                 optimizer,
+                 epoch,
+                 best_dice,
+                 args,
+             )
+
+     save_checkpoint(
+         Path(args.output_dir) / "last.pt",
+         model,
+         optimizer,
+         args.epochs,
+         best_dice,
+         args,
+     )
+
+     print("Training complete.")
+     print(f"Best val Dice: {best_dice:.4f}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train retinal vessel segmentation model on FIVES.")
+
+     parser.add_argument("--data-root", type=str, required=True)
+     parser.add_argument("--output-dir", type=str, default="checkpoints/fives")
+     parser.add_argument("--image-size", type=int, default=512)
+     parser.add_argument("--epochs", type=int, default=100)
+     parser.add_argument("--batch-size", type=int, default=4)
+     parser.add_argument("--num-workers", type=int, default=4)
+
+     parser.add_argument("--model", type=str, default="resunet", choices=["resunet", "deeplabv3", "vit"])
+     parser.add_argument("--backbone", type=str, default="resnet50")
+     parser.add_argument("--base-channels", type=int, default=32)
+     parser.add_argument("--dropout", type=float, default=0.0)
+     parser.add_argument("--no-pretrained", action="store_true")
+
+     parser.add_argument("--lr", type=float, default=1e-4)
+     parser.add_argument("--weight-decay", type=float, default=1e-4)
+     parser.add_argument("--bce-weight", type=float, default=1.0)
+     parser.add_argument("--dice-weight", type=float, default=1.0)
+     parser.add_argument("--save-every", type=int, default=25)
+     parser.add_argument("--amp", action="store_true")
+
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
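
As a usage sketch, a run that writes checkpoints where app.py's defaults expect them might look like the following (the FIVES root path is a placeholder; all flags exist in parse_args above):

    python train.py --data-root /path/to/FIVES_dataset --output-dir checkpoints/fives_resunet --model resunet --image-size 1024 --batch-size 4 --epochs 100 --amp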