himanshuch8055 commited on
Commit
4971505
·
1 Parent(s): e9bd06c

Implement DeepLabV3+ with EfficientNet-B3 for fibril segmentation; add GPU selection, data preparation, and training loop

Browse files
Files changed (1) hide show
  1. training-model/train_fibril_segment.py +1011 -0
training-model/train_fibril_segment.py ADDED
@@ -0,0 +1,1011 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============== Fibril Segmentation — DeepLabV3+ with EfficientNet-B3 ===============
2
+
3
+ import os, random, subprocess
4
+ from glob import glob
5
+ import numpy as np
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
10
+
11
+ import torch
12
+ torch.cuda.empty_cache()
13
+ import torch.nn as nn
14
+ from torch.utils.data import Dataset, DataLoader
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+ import segmentation_models_pytorch as smp
18
+
19
+ import json
20
+ from sklearn.utils import shuffle
21
+ import os
22
+ import subprocess
23
+
24
+ # ─── GPU Selection Function ───────────────────────────────
25
+ def get_free_gpu(threshold_mb=1000):
26
+ try:
27
+ result = subprocess.run(
28
+ ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,nounits,noheader"],
29
+ stdout=subprocess.PIPE, text=True
30
+ )
31
+ for idx, line in enumerate(result.stdout.strip().split("\n")):
32
+ used, total = map(int, line.split(","))
33
+ if total - used > threshold_mb:
34
+ return str(idx)
35
+ except Exception as e:
36
+ print("GPU check failed:", e)
37
+ return None
38
+
39
+ # ─── Find Free GPU BEFORE Defining Config ────────────────
40
+ free_gpu_id = get_free_gpu()
41
+
42
+ # ─── Configurations ───────────────────────────────────────
43
+ config = {
44
+ "seed": 42,
45
+ "img_size": 512,
46
+ "batch_size": 2,
47
+ "num_workers": 4,
48
+ "epochs": 100,
49
+ "lr": 1e-4,
50
+ "train_img_dir": "./alldataset/images",
51
+ "train_mask_dir": "./alldataset/masks",
52
+ "save_path": "./trained-models/encoder_resnest101e_decoder_UnetPlusPlus_fibril_seg_model.pth",
53
+ "gpu_id": free_gpu_id,
54
+ }
55
+
56
+ # ─── GPU Setup ────────────────────────────────────────────
57
+ if config["gpu_id"] is not None:
58
+ os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu_id"]
59
+ print(f"✅ Using GPU ID: {config['gpu_id']}")
60
+ else:
61
+ print("⚠️ No free GPU detected — training may use default device or fail")
62
+
63
+ # ─── Reproducibility ───────────────────────────────────────
64
+ def seed_everything(seed=42):
65
+ random.seed(seed)
66
+ np.random.seed(seed)
67
+ torch.manual_seed(seed)
68
+ torch.cuda.manual_seed_all(seed)
69
+ torch.backends.cudnn.deterministic = True
70
+ torch.backends.cudnn.benchmark = False
71
+
72
+ seed_everything(config["seed"])
73
+
74
+ # ─── Dataset ───────────────────────────────────────────────
75
+ class FibrilSegmentationDataset(torch.utils.data.Dataset):
76
+ def __init__(self, image_paths, mask_paths, transform=None):
77
+ self.image_paths = image_paths
78
+ self.mask_paths = mask_paths
79
+ self.transform = transform
80
+
81
+ def __len__(self): return len(self.image_paths)
82
+
83
+ def __getitem__(self, idx):
84
+ image = np.array(Image.open(self.image_paths[idx]).convert("L"))
85
+ mask = (np.array(Image.open(self.mask_paths[idx]).convert("L")) > 127).astype(np.float32)
86
+ if self.transform:
87
+ aug = self.transform(image=image, mask=mask)
88
+ image, mask = aug['image'], aug['mask']
89
+ return image, mask.unsqueeze(0)
90
+
91
+ # ─── Image-Mask Matcher ────────────────────────────────────
92
+ def match_images_and_masks(img_dir, mask_dir, img_exts=("jpg", "jpeg", "png"), mask_exts=("jpg", "png")):
93
+ image_paths, mask_paths = [], []
94
+ for ext in img_exts:
95
+ for img_path in glob(f"{img_dir}/*.{ext}"):
96
+ base = os.path.splitext(os.path.basename(img_path))[0]
97
+ for mext in mask_exts:
98
+ mask_path = os.path.join(mask_dir, f"{base}-vectors.{mext}")
99
+ if os.path.exists(mask_path):
100
+ image_paths.append(img_path)
101
+ mask_paths.append(mask_path)
102
+ break
103
+ return image_paths, mask_paths
104
+
105
+ # ─── Loss Function ─────────────────────────────────────────
106
+ class DiceBCELoss(nn.Module):
107
+ def __init__(self):
108
+ super().__init__()
109
+ self.bce = nn.BCEWithLogitsLoss()
110
+
111
+ # def forward(self, inputs, targets):
112
+ # inputs = torch.sigmoid(inputs)
113
+ # intersection = (inputs * targets).sum()
114
+ # dice = (2. * intersection + 1e-6) / (inputs.sum() + targets.sum() + 1e-6)
115
+ # return 1 - dice + self.bce(inputs, targets)
116
+
117
+ def forward(self, inputs, targets):
118
+ bce_loss = self.bce(inputs, targets) # Raw logits
119
+ inputs = torch.sigmoid(inputs) # Probabilities for Dice
120
+ intersection = (inputs * targets).sum()
121
+ dice_loss = 1 - (2. * intersection + 1e-6) / (inputs.sum() + targets.sum() + 1e-6)
122
+ return dice_loss + bce_loss
123
+
124
+
125
+ # ─── Metrics ───────────────────────────────────────────────
126
+ @torch.no_grad()
127
+ def dice_coeff(pred, target, smooth=1e-6):
128
+ pred = (torch.sigmoid(pred) > 0.5).float()
129
+ intersection = (pred * target).sum()
130
+ return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
131
+
132
+ @torch.no_grad()
133
+ def iou_score(pred, target, smooth=1e-6):
134
+ pred = (torch.sigmoid(pred) > 0.5).float()
135
+ intersection = (pred * target).sum()
136
+ union = pred.sum() + target.sum() - intersection
137
+ return (intersection + smooth) / (union + smooth)
138
+
139
+ # ─── Data Preparation ──────────────────────────────────────
140
+ # image_paths, mask_paths = match_images_and_masks(config["train_img_dir"], config["train_mask_dir"])
141
+ # split = int(0.8 * len(image_paths))
142
+ # train_imgs, val_imgs = image_paths[:split], image_paths[split:]
143
+ # train_masks, val_masks = mask_paths[:split], mask_paths[split:]
144
+
145
+ # ─── Data Preparation with persistent train/val split ──────
146
+ split_path = "train_val_split.json"
147
+
148
+ if os.path.exists(split_path):
149
+ print(f"Loading saved train/val split from {split_path}")
150
+ with open(split_path, "r") as f:
151
+ split_data = json.load(f)
152
+
153
+ train_imgs = split_data["train_images"]
154
+ train_masks = split_data["train_masks"]
155
+ val_imgs = split_data["val_images"]
156
+ val_masks = split_data["val_masks"]
157
+
158
+ else:
159
+ print("Creating new train/val split and saving it...")
160
+ image_paths, mask_paths = match_images_and_masks(config["train_img_dir"], config["train_mask_dir"])
161
+
162
+ # Shuffle dataset to randomize
163
+ train_val = list(zip(image_paths, mask_paths))
164
+ random.seed(config["seed"])
165
+ random.shuffle(train_val)
166
+ image_paths, mask_paths = zip(*train_val)
167
+
168
+ split = int(0.8 * len(image_paths))
169
+ train_imgs = list(image_paths[:split])
170
+ train_masks = list(mask_paths[:split])
171
+ val_imgs = list(image_paths[split:])
172
+ val_masks = list(mask_paths[split:])
173
+
174
+ split_data = {
175
+ "train_images": train_imgs,
176
+ "train_masks": train_masks,
177
+ "val_images": val_imgs,
178
+ "val_masks": val_masks
179
+ }
180
+
181
+ with open(split_path, "w") as f:
182
+ json.dump(split_data, f, indent=2)
183
+
184
+
185
+ common_norm = A.Normalize(mean=(0.5,), std=(0.5,))
186
+ train_tf = A.Compose([
187
+ A.Resize(config["img_size"], config["img_size"]), A.HorizontalFlip(0.5), A.VerticalFlip(0.5), A.RandomRotate90(0.5),
188
+ A.Affine(scale=(0.9, 1.1), translate_percent=0.05, rotate=(-30, 30), shear=(-5, 5), p=0.5),
189
+ A.RandomBrightnessContrast(0.3), A.ElasticTransform(alpha=1.0, sigma=50.0, approximate=True, p=0.2),
190
+ A.Blur(3, p=0.2), common_norm, ToTensorV2()
191
+ ])
192
+ val_tf = A.Compose([A.Resize(config["img_size"], config["img_size"]), common_norm, ToTensorV2()])
193
+
194
+ train_loader = DataLoader(FibrilSegmentationDataset(train_imgs, train_masks, train_tf),
195
+ batch_size=config["batch_size"], shuffle=True, num_workers=config["num_workers"])
196
+ val_loader = DataLoader(FibrilSegmentationDataset(val_imgs, val_masks, val_tf),
197
+ batch_size=1, shuffle=False, num_workers=config["num_workers"])
198
+
199
+ print(f"Train samples: {len(train_imgs)} | Batch size: {config['batch_size']}")
200
+ print(f"Steps/epoch: {int(np.ceil(len(train_imgs) / config['batch_size']))}")
201
+
202
+ # ─── Model Setup ──────────────────────────────────────────
203
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204
+ # device = torch.device("cpu")
205
+
206
+ # model = smp.Unet(
207
+ # encoder_name="resnet34",
208
+ # encoder_weights="imagenet",
209
+ # in_channels=1, # grayscale
210
+ # classes=1 # binary segmentation
211
+ # ).to(device)
212
+
213
+ # model = smp.Unet(
214
+ # encoder_name="efficientnet-b3",
215
+ # encoder_weights="imagenet",
216
+ # in_channels=1,
217
+ # classes=1
218
+ # ).to(device)
219
+
220
+ # model = smp.DeepLabV3Plus(
221
+ # encoder_name='efficientnet-b3',
222
+ # encoder_depth=5,
223
+ # encoder_weights='imagenet',
224
+ # decoder_use_norm='batchnorm',
225
+ # decoder_channels=(256, 128, 64, 32, 16),
226
+ # decoder_attention_type=None,
227
+ # decoder_interpolation='nearest',
228
+ # in_channels=1,
229
+ # classes=1,
230
+ # activation=None,
231
+ # aux_params=None
232
+ # ).to(device)
233
+
234
+ # model = smp.Unet(
235
+ # encoder_name="mobilenet_v2", # much lighter than resnet34
236
+ # encoder_weights="imagenet",
237
+ # in_channels=1, # grayscale input
238
+ # classes=1 # binary mask
239
+ # ).to(device)
240
+
241
+ # model = smp.UnetPlusPlus(
242
+ # encoder_name='resnet34',
243
+ # encoder_depth=5,
244
+ # encoder_weights='imagenet',
245
+ # decoder_use_norm='batchnorm',
246
+ # decoder_channels=(256, 128, 64, 32, 16),
247
+ # decoder_attention_type=None,
248
+ # decoder_interpolation='nearest',
249
+ # in_channels=1,
250
+ # classes=1,
251
+ # activation=None,
252
+ # aux_params=None
253
+ # ).to(device)
254
+
255
+ model = smp.UnetPlusPlus(
256
+ encoder_name='resnest101e',
257
+ encoder_depth=5,
258
+ encoder_weights='imagenet',
259
+ decoder_use_norm='batchnorm',
260
+ decoder_channels=(256, 128, 64, 32, 16),
261
+ decoder_attention_type=None,
262
+ decoder_interpolation='nearest',
263
+ in_channels=1,
264
+ classes=1,
265
+ activation=None,
266
+ aux_params=None
267
+ ).to(device)
268
+
269
+ # model = smp.UnetPlusPlus(
270
+ # encoder_name='efficientnet-b3', # Lightweight, solid performance
271
+ # encoder_depth=5, # Standard depth
272
+ # encoder_weights='imagenet', # Useful even for grayscale (see note below)
273
+ # decoder_use_norm='batchnorm', # Recommended for stability
274
+ # decoder_channels=(256, 128, 64, 32, 16), # Deep decoder, good for details
275
+ # decoder_attention_type=None, # Optional, can add SE or SCSE for boost
276
+ # decoder_interpolation='nearest', # Good, avoids checkerboard artifacts
277
+ # in_channels=1, # Correct for grayscale (e.g., EM images)
278
+ # classes=1, # Binary segmentation (fibrils vs background)
279
+ # activation=None, # No activation for logits output
280
+ # aux_params=None # No classification head
281
+ # ).to(device)
282
+
283
+
284
+ loss_fn = DiceBCELoss()
285
+ optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
286
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
287
+
288
+ # ─── Training Loop ─────────────────────────────────────────
289
+ best_dice = 0.0
290
+ os.makedirs(os.path.dirname(config["save_path"]), exist_ok=True)
291
+
292
+ for epoch in range(1, config["epochs"] + 1):
293
+ model.train()
294
+ total_loss, total_dice = 0, 0
295
+
296
+ for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch} - Train"):
297
+ imgs, masks = imgs.to(device), masks.to(device)
298
+ preds = model(imgs)
299
+ loss = loss_fn(preds, masks)
300
+
301
+ optimizer.zero_grad()
302
+ loss.backward()
303
+ nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
304
+ optimizer.step()
305
+
306
+ total_loss += loss.item()
307
+ total_dice += dice_coeff(preds, masks).item()
308
+
309
+ avg_loss = total_loss / len(train_loader)
310
+ avg_dice = total_dice / len(train_loader)
311
+ print(f"[Train] Epoch {epoch} | Loss: {avg_loss:.4f} | Dice: {avg_dice:.4f}")
312
+
313
+ # ─── Validation ────────────────────────────────────────
314
+ model.eval()
315
+ val_loss, val_dice, val_iou = 0, 0, 0
316
+ with torch.no_grad():
317
+ for imgs, masks in val_loader:
318
+ imgs, masks = imgs.to(device), masks.to(device)
319
+ preds = model(imgs)
320
+ val_loss += loss_fn(preds, masks).item()
321
+ val_dice += dice_coeff(preds, masks).item()
322
+ val_iou += iou_score(preds, masks).item()
323
+
324
+ val_loss /= len(val_loader)
325
+ val_dice /= len(val_loader)
326
+ val_iou /= len(val_loader)
327
+ scheduler.step(val_loss)
328
+
329
+ print(f"[Val] Epoch {epoch} | Loss: {val_loss:.4f} | Dice: {val_dice:.4f} | IoU: {val_iou:.4f}")
330
+
331
+ if val_dice > best_dice:
332
+ best_dice = val_dice
333
+ torch.save(model.state_dict(), config["save_path"])
334
+ print(f"✅ Saved Best Model (Epoch {epoch} - Dice: {val_dice:.4f})")
335
+
336
+
337
+
338
+
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+
352
+
353
+
354
+
355
+
356
+
357
+
358
+
359
+ # import os
360
+ # import random
361
+ # import subprocess
362
+ # from glob import glob
363
+
364
+ # import numpy as np
365
+ # from PIL import Image
366
+ # from tqdm import tqdm
367
+
368
+ # import torch
369
+ # import torch.nn as nn
370
+ # from torch.utils.data import Dataset, DataLoader
371
+ # from torch.cuda.amp import autocast, GradScaler
372
+
373
+ # import albumentations as A
374
+ # from albumentations.pytorch import ToTensorV2
375
+ # import segmentation_models_pytorch as smp
376
+
377
+ # # ─── Select Free GPU ──────────────────────────────────────
378
+ # def get_free_gpu(threshold_mb=500):
379
+ # try:
380
+ # result = subprocess.run(
381
+ # ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,nounits,noheader"],
382
+ # stdout=subprocess.PIPE, text=True
383
+ # )
384
+ # for idx, line in enumerate(result.stdout.strip().split("\n")):
385
+ # used, total = map(int, line.strip().split(","))
386
+ # if total - used > threshold_mb:
387
+ # return str(idx)
388
+ # except Exception as e:
389
+ # print("GPU check failed:", e)
390
+ # return None
391
+
392
+ # free_gpu = get_free_gpu()
393
+ # if free_gpu is not None:
394
+ # os.environ["CUDA_VISIBLE_DEVICES"] = free_gpu
395
+ # print(f"Using GPU {free_gpu}")
396
+ # else:
397
+ # print("No free GPU found — training may fail due to lack of memory")
398
+
399
+ # # ─── Seed Everything ──────────────────────────────────────
400
+ # def seed_everything(seed=42):
401
+ # random.seed(seed)
402
+ # np.random.seed(seed)
403
+ # torch.manual_seed(seed)
404
+ # torch.cuda.manual_seed_all(seed)
405
+ # torch.backends.cudnn.deterministic = True
406
+ # torch.backends.cudnn.benchmark = False
407
+
408
+ # seed_everything()
409
+
410
+ # # ─── Dataset ──────────────────────────────────────────────
411
+ # class FibrilSegmentationDataset(Dataset):
412
+ # def __init__(self, image_paths, mask_paths, transform=None):
413
+ # self.image_paths = image_paths
414
+ # self.mask_paths = mask_paths
415
+ # self.transform = transform
416
+
417
+ # def __len__(self):
418
+ # return len(self.image_paths)
419
+
420
+ # def __getitem__(self, idx):
421
+ # image = Image.open(self.image_paths[idx]).convert("L")
422
+ # mask = Image.open(self.mask_paths[idx]).convert("L")
423
+
424
+ # image = np.array(image)
425
+ # mask = (np.array(mask) > 127).astype(np.float32)
426
+
427
+ # if self.transform:
428
+ # augmented = self.transform(image=image, mask=mask)
429
+ # image = augmented['image']
430
+ # mask = augmented['mask']
431
+
432
+ # return image, mask.unsqueeze(0) # [1, H, W]
433
+
434
+ # # ─── Match Image-Mask ─────────────────────────────────────
435
+ # def match_images_and_masks(img_dir, mask_dir, img_exts=("jpg", "jpeg", "png"), mask_exts=("jpg", "png")):
436
+ # image_paths, mask_paths = [], []
437
+ # for ext in img_exts:
438
+ # for img_path in glob(f"{img_dir}/*.{ext}"):
439
+ # base_name = os.path.splitext(os.path.basename(img_path))[0]
440
+ # for mask_ext in mask_exts:
441
+ # possible_mask = os.path.join(mask_dir, f"{base_name}-vectors.{mask_ext}")
442
+ # if os.path.exists(possible_mask):
443
+ # image_paths.append(img_path)
444
+ # mask_paths.append(possible_mask)
445
+ # break
446
+ # return image_paths, mask_paths
447
+
448
+ # # ─── Loss Function ────────────────────────────────────────
449
+ # class DiceBCELoss(nn.Module):
450
+ # def __init__(self):
451
+ # super().__init__()
452
+ # self.bce = nn.BCEWithLogitsLoss()
453
+
454
+ # def forward(self, inputs, targets):
455
+ # smooth = 1e-6
456
+ # inputs = torch.sigmoid(inputs)
457
+ # intersection = (inputs * targets).sum()
458
+ # dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
459
+ # return 1 - dice + self.bce(inputs, targets)
460
+
461
+ # # ─── Data ─────────────────────────────────────────────────
462
+ # image_paths, mask_paths = match_images_and_masks("./dataset4/images", "./dataset4/masks")
463
+
464
+ # split = int(0.8 * len(image_paths))
465
+ # train_imgs, val_imgs = image_paths[:split], image_paths[split:]
466
+ # train_masks, val_masks = mask_paths[:split], mask_paths[split:]
467
+
468
+ # common_normalization = A.Normalize(mean=(0.5,), std=(0.5,))
469
+ # train_transform = A.Compose([
470
+ # A.Resize(512, 512),
471
+ # A.HorizontalFlip(p=0.5),
472
+ # A.VerticalFlip(p=0.5),
473
+ # A.RandomRotate90(p=0.5),
474
+ # A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.05), rotate=(-30, 30), shear=(-5, 5), p=0.5),
475
+ # A.RandomBrightnessContrast(p=0.3),
476
+ # A.ElasticTransform(alpha=1.0, sigma=50.0, approximate=True, p=0.2),
477
+ # A.Blur(blur_limit=3, p=0.2),
478
+ # common_normalization,
479
+ # ToTensorV2()
480
+ # ])
481
+
482
+ # val_transform = A.Compose([
483
+ # A.Resize(512, 512),
484
+ # common_normalization,
485
+ # ToTensorV2()
486
+ # ])
487
+
488
+ # train_ds = FibrilSegmentationDataset(train_imgs, train_masks, train_transform)
489
+ # val_ds = FibrilSegmentationDataset(val_imgs, val_masks, val_transform)
490
+
491
+ # train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
492
+ # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)
493
+
494
+ # # ─── Model ────────────────────────────────────────────────
495
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
496
+
497
+ # model = smp.DeepLabV3Plus(
498
+ # encoder_name="efficientnet-b3",
499
+ # encoder_weights="imagenet",
500
+ # in_channels=1,
501
+ # classes=1
502
+ # ).to(device)
503
+
504
+ # loss_fn = DiceBCELoss()
505
+ # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
506
+ # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
507
+ # scaler = GradScaler()
508
+
509
+ # # ─── Metrics ───────────────────────────────────────────────
510
+ # def dice_coeff(pred, target, smooth=1e-6):
511
+ # pred = torch.sigmoid(pred)
512
+ # pred = (pred > 0.5).float()
513
+ # intersection = (pred * target).sum()
514
+ # return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
515
+
516
+ # def iou_score(pred, target, smooth=1e-6):
517
+ # pred = torch.sigmoid(pred)
518
+ # pred = (pred > 0.5).float()
519
+ # intersection = (pred * target).sum()
520
+ # union = pred.sum() + target.sum() - intersection
521
+ # return (intersection + smooth) / (union + smooth)
522
+
523
+ # # ─── Training ──────────────────────────────────────────────
524
+ # best_dice = 0.0
525
+ # os.makedirs("./trained-models", exist_ok=True)
526
+
527
+ # for epoch in range(1, 101):
528
+ # model.train()
529
+ # total_loss, total_dice = 0, 0
530
+
531
+ # for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch} - Train"):
532
+ # imgs, masks = imgs.to(device), masks.to(device)
533
+
534
+ # optimizer.zero_grad()
535
+ # with autocast():
536
+ # preds = model(imgs)
537
+ # loss = loss_fn(preds, masks)
538
+
539
+ # scaler.scale(loss).backward()
540
+ # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
541
+ # scaler.step(optimizer)
542
+ # scaler.update()
543
+
544
+ # total_loss += loss.item()
545
+ # total_dice += dice_coeff(preds, masks).item()
546
+
547
+ # avg_loss = total_loss / len(train_loader)
548
+ # avg_dice = total_dice / len(train_loader)
549
+ # print(f"[Train] Epoch {epoch} | Loss: {avg_loss:.4f} | Dice: {avg_dice:.4f}")
550
+
551
+ # model.eval()
552
+ # val_loss, val_dice, val_iou = 0, 0, 0
553
+ # with torch.no_grad():
554
+ # for imgs, masks in val_loader:
555
+ # imgs, masks = imgs.to(device), masks.to(device)
556
+ # preds = model(imgs)
557
+ # val_loss += loss_fn(preds, masks).item()
558
+ # val_dice += dice_coeff(preds, masks).item()
559
+ # val_iou += iou_score(preds, masks).item()
560
+
561
+ # val_loss /= len(val_loader)
562
+ # val_dice /= len(val_loader)
563
+ # val_iou /= len(val_loader)
564
+ # scheduler.step(val_loss)
565
+
566
+ # print(f"[Val] Epoch {epoch} | Loss: {val_loss:.4f} | Dice: {val_dice:.4f} | IoU: {val_iou:.4f}")
567
+
568
+ # if val_dice > best_dice:
569
+ # best_dice = val_dice
570
+ # torch.save(model.state_dict(), f"./trained-models/fibril_epoch{epoch}_dice{val_dice:.4f}.pth")
571
+ # print(f"✅ Saved Best Model (Epoch {epoch} - Dice: {val_dice:.4f})")
572
+
573
+
574
+
575
+
576
+
577
+
578
+
579
+
580
+
581
+ # # # =============== Working fine with Gary images (UNet model with ResNet34 as the encoder ===================
582
+ # # # =============== Encoder (ResNet34) and Decoder (UNet)==============
583
+
584
+
585
+ # import os
586
+ # import random
587
+ # from glob import glob
588
+ # import numpy as np
589
+ # from PIL import Image
590
+ # from tqdm import tqdm
591
+ # from itertools import chain
592
+
593
+ # import torch
594
+ # import torch.nn as nn
595
+ # from torch.utils.data import Dataset, DataLoader
596
+
597
+ # import albumentations as A
598
+ # from albumentations.pytorch import ToTensorV2
599
+ # import segmentation_models_pytorch as smp
600
+
601
+ # import subprocess
602
+ # import os
603
+
604
+ # # Force GPU selection if available
605
+ # # import os
606
+ # # os.environ["CUDA_VISIBLE_DEVICES"] = "3" # Change '3' to any free GPU ID
607
+
608
+ # def get_free_gpu(threshold_mb=500):
609
+ # try:
610
+ # result = subprocess.run(
611
+ # ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,nounits,noheader"],
612
+ # stdout=subprocess.PIPE, text=True
613
+ # )
614
+ # for idx, line in enumerate(result.stdout.strip().split("\n")):
615
+ # used, total = map(int, line.strip().split(","))
616
+ # if total - used > threshold_mb:
617
+ # return str(idx)
618
+ # except Exception as e:
619
+ # print("GPU check failed:", e)
620
+ # return None
621
+
622
+ # # free_gpu = get_free_gpu()
623
+ # free_gpu = "5"
624
+ # if free_gpu is not None:
625
+ # os.environ["CUDA_VISIBLE_DEVICES"] = free_gpu
626
+ # print(f"Using GPU {free_gpu}")
627
+ # else:
628
+ # print("No free GPU found — training may fail due to lack of memory")
629
+
630
+
631
+ # # ─── Seed for Reproducibility ─────────────────────────────
632
+ # def seed_everything(seed=42):
633
+ # random.seed(seed)
634
+ # np.random.seed(seed)
635
+ # torch.manual_seed(seed)
636
+ # torch.cuda.manual_seed_all(seed)
637
+ # torch.backends.cudnn.deterministic = True
638
+ # torch.backends.cudnn.benchmark = False
639
+
640
+ # seed_everything()
641
+
642
+ # # ─── Dataset ──────────────────────────────────────────────
643
+ # class FibrilSegmentationDataset(Dataset):
644
+ # def __init__(self, image_paths, mask_paths, transform=None):
645
+ # self.image_paths = image_paths
646
+ # self.mask_paths = mask_paths
647
+ # self.transform = transform
648
+
649
+ # def __len__(self):
650
+ # return len(self.image_paths)
651
+
652
+ # def __getitem__(self, idx):
653
+ # image = Image.open(self.image_paths[idx]).convert("L")
654
+ # mask = Image.open(self.mask_paths[idx]).convert("L")
655
+
656
+ # image = np.array(image)
657
+ # mask = (np.array(mask) > 127).astype(np.float32)
658
+
659
+ # if self.transform:
660
+ # augmented = self.transform(image=image, mask=mask)
661
+ # image = augmented['image']
662
+ # mask = augmented['mask']
663
+
664
+ # return image, mask.unsqueeze(0) # [1, H, W]
665
+
666
+ # # ─── Utility to Match Image-Mask Pairs ─────────────────────
667
+ # def match_images_and_masks(img_dir, mask_dir, img_exts=("jpg", "jpeg", "png"), mask_exts=("jpg", "png")):
668
+ # image_paths, mask_paths = [], []
669
+
670
+ # for ext in img_exts:
671
+ # for img_path in glob(f"{img_dir}/*.{ext}"):
672
+ # base_name = os.path.splitext(os.path.basename(img_path))[0]
673
+ # for mask_ext in mask_exts:
674
+ # # possible_mask = os.path.join(mask_dir, f"{base_name}_mask.{mask_ext}")
675
+ # possible_mask = os.path.join(mask_dir, f"{base_name}-vectors.{mask_ext}")
676
+ # if os.path.exists(possible_mask):
677
+ # image_paths.append(img_path)
678
+ # mask_paths.append(possible_mask)
679
+ # break # Stop after first match
680
+
681
+ # return image_paths, mask_paths
682
+
683
+
684
+ # class DiceBCELoss(nn.Module):
685
+ # def __init__(self):
686
+ # super().__init__()
687
+ # self.bce = nn.BCEWithLogitsLoss()
688
+
689
+ # def forward(self, inputs, targets):
690
+ # smooth = 1e-6
691
+ # inputs = torch.sigmoid(inputs)
692
+ # intersection = (inputs * targets).sum()
693
+ # dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
694
+ # return 1 - dice + self.bce(inputs, targets)
695
+
696
+
697
+ # # ─── Load Dataset ──────────────────────────────────────────
698
+ # image_paths, mask_paths = match_images_and_masks("./dataset4/images", "./dataset4/masks")
699
+
700
+ # split = int(0.8 * len(image_paths))
701
+ # train_imgs, val_imgs = image_paths[:split], image_paths[split:]
702
+ # train_masks, val_masks = mask_paths[:split], mask_paths[split:]
703
+
704
+ # # ─── Transformations ──────────────────────────────────────
705
+ # common_normalization = A.Normalize(mean=(0.5,), std=(0.5,))
706
+ # train_transform = A.Compose([
707
+ # A.Resize(512, 512),
708
+ # A.HorizontalFlip(p=0.5),
709
+ # A.VerticalFlip(p=0.5),
710
+ # A.RandomRotate90(p=0.5),
711
+ # A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.05), rotate=(-30, 30), shear=(-5, 5), p=0.5),
712
+ # A.RandomBrightnessContrast(p=0.3),
713
+ # A.ElasticTransform(alpha=1.0, sigma=50.0, approximate=True, p=0.2),
714
+ # A.Blur(blur_limit=3, p=0.2),
715
+ # common_normalization,
716
+ # ToTensorV2()
717
+ # ])
718
+
719
+ # val_transform = A.Compose([
720
+ # A.Resize(512, 512),
721
+ # common_normalization,
722
+ # ToTensorV2()
723
+ # ])
724
+
725
+ # # ─── Datasets & Loaders ───────────────────────────────────
726
+ # train_ds = FibrilSegmentationDataset(train_imgs, train_masks, train_transform)
727
+ # val_ds = FibrilSegmentationDataset(val_imgs, val_masks, val_transform)
728
+
729
+ # # train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)
730
+ # # train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
731
+ # # For training (20 samples):
732
+ # train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
733
+
734
+ # print(f"Train samples: {len(train_ds)}")
735
+ # print(f"Batch size: {train_loader.batch_size}")
736
+ # print(f"Expected steps per epoch: {int(np.ceil(len(train_ds)/train_loader.batch_size))}")
737
+
738
+ # # val_loader = DataLoader(val_ds, batch_size=8, num_workers=4)
739
+ # # For validation (5 samples):
740
+ # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)
741
+
742
+ # # ─── Model Setup ──────────────────────────────────────────
743
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
744
+ # # device = torch.device("cpu")
745
+
746
+ # # model = smp.Unet(
747
+ # # encoder_name="resnet34",
748
+ # # encoder_weights="imagenet",
749
+ # # in_channels=1, # grayscale
750
+ # # classes=1 # binary segmentation
751
+ # # ).to(device)
752
+
753
+ # # model = smp.Unet(
754
+ # # encoder_name="efficientnet-b3",
755
+ # # encoder_weights="imagenet",
756
+ # # in_channels=1,
757
+ # # classes=1
758
+ # # ).to(device)
759
+
760
+ # model = smp.DeepLabV3Plus(
761
+ # encoder_name="efficientnet-b3",
762
+ # encoder_weights="imagenet",
763
+ # in_channels=1,
764
+ # classes=1
765
+ # ).to(device)
766
+
767
+ # # model = smp.Unet(
768
+ # # encoder_name="mobilenet_v2", # much lighter than resnet34
769
+ # # encoder_weights="imagenet",
770
+ # # in_channels=1, # grayscale input
771
+ # # classes=1 # binary mask
772
+ # # ).to(device)
773
+
774
+ # # loss_fn = nn.BCEWithLogitsLoss()
775
+ # loss_fn = DiceBCELoss()
776
+ # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
777
+ # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
778
+
779
+ # # ─── Metrics ───────────────────────────────────────────────
780
+ # def dice_coeff(pred, target, smooth=1e-6):
781
+ # pred = torch.sigmoid(pred)
782
+ # pred = (pred > 0.5).float()
783
+ # intersection = (pred * target).sum()
784
+ # return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
785
+
786
+ # def iou_score(pred, target, smooth=1e-6):
787
+ # pred = torch.sigmoid(pred)
788
+ # pred = (pred > 0.5).float()
789
+ # intersection = (pred * target).sum()
790
+ # union = pred.sum() + target.sum() - intersection
791
+ # return (intersection + smooth) / (union + smooth)
792
+
793
+ # # ─── Training Loop ─────────────────────────────────────────
794
+ # best_dice = 0.0
795
+ # os.makedirs("./trained-models", exist_ok=True)
796
+
797
+ # for epoch in range(1, 101):
798
+ # model.train()
799
+ # total_loss, total_dice = 0, 0
800
+
801
+ # for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch} - Train"):
802
+ # imgs, masks = imgs.to(device), masks.to(device)
803
+
804
+ # preds = model(imgs)
805
+ # loss = loss_fn(preds, masks)
806
+
807
+ # optimizer.zero_grad()
808
+ # loss.backward()
809
+ # nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
810
+ # optimizer.step()
811
+
812
+ # total_loss += loss.item()
813
+ # total_dice += dice_coeff(preds, masks).item()
814
+
815
+ # avg_loss = total_loss / len(train_loader)
816
+ # avg_dice = total_dice / len(train_loader)
817
+ # print(f"[Train] Epoch {epoch} | Loss: {avg_loss:.4f} | Dice: {avg_dice:.4f}")
818
+
819
+ # # Validation
820
+ # model.eval()
821
+ # val_loss, val_dice, val_iou = 0, 0, 0
822
+ # with torch.no_grad():
823
+ # for imgs, masks in val_loader:
824
+ # imgs, masks = imgs.to(device), masks.to(device)
825
+ # preds = model(imgs)
826
+ # val_loss += loss_fn(preds, masks).item()
827
+ # val_dice += dice_coeff(preds, masks).item()
828
+ # val_iou += iou_score(preds, masks).item()
829
+
830
+ # val_loss /= len(val_loader)
831
+ # val_dice /= len(val_loader)
832
+ # val_iou /= len(val_loader)
833
+ # scheduler.step(val_loss)
834
+
835
+ # print(f"[Val] Epoch {epoch} | Loss: {val_loss:.4f} | Dice: {val_dice:.4f} | IoU: {val_iou:.4f}")
836
+
837
+ # # Save best model
838
+ # if val_dice > best_dice:
839
+ # best_dice = val_dice
840
+ # torch.save(model.state_dict(), "./trained-models/amalesh_encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth")
841
+ # print(f"✅ Saved Best Model (Epoch {epoch} - Dice: {val_dice:.4f})")
842
+
843
+
844
+
845
+
846
+
847
+
848
+
849
+
850
+ # # Working on the gray images fine
851
+
852
+ # # =============== Working fine with Gary images (UNet model with ResNet34 as the encoder ===================
853
+ # # =============== Encoder (ResNet34) and Decoder (UNet)==============
854
+
855
+
856
+ # import os
857
+ # from glob import glob
858
+ # import numpy as np
859
+ # from PIL import Image
860
+ # from tqdm import tqdm
861
+
862
+ # import torch
863
+ # import torch.nn as nn
864
+ # from torch.utils.data import Dataset, DataLoader
865
+
866
+ # import albumentations as A
867
+ # from albumentations.pytorch import ToTensorV2
868
+ # import segmentation_models_pytorch as smp
869
+
870
+ # # ─── Dataset ────────────────────────────
871
+ # class FibrilSegmentationDataset(Dataset):
872
+ # def __init__(self, image_paths, mask_paths, transform=None):
873
+ # self.image_paths = image_paths
874
+ # self.mask_paths = mask_paths
875
+ # self.transform = transform
876
+
877
+ # def __len__(self):
878
+ # return len(self.image_paths)
879
+
880
+ # def __getitem__(self, idx):
881
+ # # Load grayscale image and mask
882
+ # image = Image.open(self.image_paths[idx]).convert("L")
883
+ # mask = Image.open(self.mask_paths[idx]).convert("L")
884
+
885
+ # image = image.resize((512, 512))
886
+ # mask = mask.resize((512, 512))
887
+
888
+ # image = np.array(image)
889
+ # mask = np.array(mask)
890
+
891
+ # # Binarize mask
892
+ # mask = (mask > 127).astype(np.float32)
893
+
894
+ # if self.transform:
895
+ # augmented = self.transform(image=image, mask=mask)
896
+ # image = augmented["image"]
897
+ # mask = augmented["mask"]
898
+
899
+ # # image shape: [1, H, W], mask shape: [H, W]
900
+ # return image, mask.unsqueeze(0)
901
+
902
+ # # ─── Paths ─────────────────────────────
903
+ # image_paths = sorted(glob("./dataset/images/*.jpg"))
904
+ # mask_paths = sorted(glob("./dataset/masks/*.jpg"))
905
+
906
+ # split = int(0.8 * len(image_paths))
907
+ # train_imgs, val_imgs = image_paths[:split], image_paths[split:]
908
+ # train_masks, val_masks = mask_paths[:split], mask_paths[split:]
909
+
910
+ # # ─── Augmentations ─────────────────────
911
+ # train_transform = A.Compose([
912
+ # A.Resize(512, 512),
913
+ # A.HorizontalFlip(p=0.5),
914
+ # A.VerticalFlip(p=0.5),
915
+ # A.RandomRotate90(p=0.5),
916
+ # A.Affine(
917
+ # scale=(0.9, 1.1),
918
+ # translate_percent=(0.05, 0.05),
919
+ # rotate=(-30, 30),
920
+ # shear=(-5, 5),
921
+ # p=0.5
922
+ # ),
923
+ # A.RandomBrightnessContrast(
924
+ # brightness_limit=0.2,
925
+ # contrast_limit=0.2,
926
+ # p=0.3
927
+ # ),
928
+ # A.ElasticTransform(
929
+ # alpha=1.0,
930
+ # sigma=50.0,
931
+ # approximate=True,
932
+ # p=0.2
933
+ # ),
934
+ # A.Blur(blur_limit=3, p=0.2),
935
+ # A.Normalize(mean=(0.5,), std=(0.5,)),
936
+ # ToTensorV2()
937
+ # ])
938
+
939
+ # val_transform = A.Compose([
940
+ # A.Resize(512, 512),
941
+ # A.Normalize(mean=(0.5,), std=(0.5,)),
942
+ # ToTensorV2()
943
+ # ])
944
+
945
+ # train_ds = FibrilSegmentationDataset(train_imgs, train_masks, transform=train_transform)
946
+ # val_ds = FibrilSegmentationDataset(val_imgs, val_masks, transform=val_transform)
947
+
948
+ # train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
949
+ # val_loader = DataLoader(val_ds, batch_size=4, num_workers=4)
950
+
951
+ # # ─── Model ───────────────────────────────
952
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
953
+
954
+ # model = smp.Unet(
955
+ # encoder_name="resnet34",
956
+ # encoder_weights="imagenet",
957
+ # in_channels=1, # grayscale input
958
+ # classes=1 # binary segmentation
959
+ # ).to(device)
960
+
961
+ # loss_fn = nn.()
962
+ # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
963
+
964
+ # # ─── Metrics ─────────────────────────────
965
+ # def dice_coeff(pred, target, smooth=1e-6):
966
+ # pred = torch.sigmoid(pred)
967
+ # pred = (pred > 0.5).float()
968
+ # intersection = (pred * target).sum()
969
+ # return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
970
+
971
+ # # ─── Train Loop ──────────────────────────
972
+ # for epoch in range(1, 100):
973
+ # model.train()
974
+ # total_loss = 0
975
+ # total_dice = 0
976
+
977
+ # for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch} - Train"):
978
+ # imgs, masks = imgs.to(device), masks.to(device)
979
+
980
+ # preds = model(imgs)
981
+ # loss = loss_fn(preds, masks)
982
+
983
+ # optimizer.zero_grad()
984
+ # loss.backward()
985
+ # optimizer.step()
986
+
987
+ # total_loss += loss.item()
988
+ # total_dice += dice_coeff(preds, masks).item()
989
+
990
+ # avg_loss = total_loss / len(train_loader)
991
+ # avg_dice = total_dice / len(train_loader)
992
+ # print(f"Epoch {epoch} - Train Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}")
993
+
994
+ # # Validation
995
+ # model.eval()
996
+ # val_loss = 0
997
+ # val_dice = 0
998
+ # with torch.no_grad():
999
+ # for imgs, masks in val_loader:
1000
+ # imgs, masks = imgs.to(device), masks.to(device)
1001
+ # preds = model(imgs)
1002
+ # loss = loss_fn(preds, masks)
1003
+ # val_loss += loss.item()
1004
+ # val_dice += dice_coeff(preds, masks).item()
1005
+
1006
+ # val_loss /= len(val_loader)
1007
+ # val_dice /= len(val_loader)
1008
+ # print(f"Epoch {epoch} - Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")
1009
+
1010
+ # torch.save(model.state_dict(), "./trained-models/fibril_seg_model.pth")
1011
+ # print("✅ Model saved as fibril_seg_model.pth")