nithishbasireddy committed on
Commit
1bac9b8
·
verified ·
1 Parent(s): 60c79ca

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +482 -0
train.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
EL Defect Detection — Training Script for RTX 4060 (8GB VRAM)

Model: U-Net++ with EfficientNet-B4 encoder + scSE attention
Dataset: E-SCDD (snt-ubix/e-scdd) — 903 images, 512x512
Loss: 0.5 * Dice + 0.5 * Focal (handles severe class imbalance)
Classes: 0=background, 1=busbar, 2=crack, 3=dark/inactive, 4=other_defects

Usage:
    pip install torch torchvision segmentation-models-pytorch albumentations \
        huggingface-hub scikit-image scipy opencv-python-headless pillow
    python train.py
"""
14
+
15
+ import os
16
+ import sys
17
+ import json
18
+ import time
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.utils.data import Dataset, DataLoader
23
+ from torch.optim import AdamW
24
+ from torch.optim.lr_scheduler import CosineAnnealingLR
25
+ from pathlib import Path
26
+ from PIL import Image
27
+
28
+ import segmentation_models_pytorch as smp
29
+ import albumentations as A
30
+ from albumentations.pytorch import ToTensorV2
31
+
32
+
33
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
36
+
37
class Config:
    """Static settings for the training run: paths, model, optimisation, loss, Hub.

    All values are class attributes; ``train()`` instantiates the class and
    reads them as ``cfg.<NAME>``.
    """

    # Data
    DATA_DIR = "./data"  # Will download here
    OUTPUT_DIR = "./output"

    # Model — U-Net++ with EfficientNet-B4 is SOTA for thin-crack segmentation
    # Dense skip connections preserve fine details that plain U-Net misses
    ARCHITECTURE = "UnetPlusPlus"  # UnetPlusPlus > Unet for thin structures
    ENCODER = "efficientnet-b4"  # Best accuracy/size ratio, 20.9M params
    ENCODER_WEIGHTS = "imagenet"
    IN_CHANNELS = 1  # EL images are grayscale
    NUM_CLASSES = 5  # bg, busbar, crack, dark, other_defects

    # Training — tuned for RTX 4060 (8GB VRAM)
    IMG_SIZE = 512  # E-SCDD native resolution
    BATCH_SIZE = 4  # Safe for 8GB with AMP
    NUM_EPOCHS = 100
    ENCODER_LR = 1e-4  # Lower LR for pretrained encoder
    DECODER_LR = 5e-4  # Higher LR for random decoder
    WEIGHT_DECAY = 1e-4
    USE_AMP = True  # Mixed precision — halves VRAM usage
    NUM_WORKERS = 4
    GRADIENT_CLIP = 1.0

    # Loss
    DICE_WEIGHT = 0.5
    FOCAL_WEIGHT = 0.5
    FOCAL_GAMMA = 2.0

    # Hub
    HUB_MODEL_ID = None  # Set to "username/model-name" to push
    PUSH_TO_HUB = False

    # Class names (index in this list == class id after remapping)
    CLASS_NAMES = ["background", "busbar", "crack", "dark", "other_defect"]
72
+
73
+
74
# ═══════════════════════════════════════════════════════════════
# CLASS MAPPING: E-SCDD 30 classes → 5 classes
# ═══════════════════════════════════════════════════════════════
77
+
78
+ # Mask pixel values in E-SCDD are integers 0-29 (Label column in CSV)
79
+ # We remap to 5 meaningful classes:
80
+ # 0 = background (all spacing, borders, padding, text, clamp, frame, jbox)
81
+ # 1 = busbar (label 9)
82
+ # 2 = crack (label 14=crack, label 10=crack_rbn_edge)
83
+ # 3 = dark/inactive (label 11=inactive, label 17=dead_cell, label 20=edge_dark)
84
+ # 4 = other_defect (rings, material, gridline, splice, corrosion, belt_mark, etc.)
85
+
86
# Lookup table: index = raw E-SCDD label (0-29), value = remapped class id (0-4).
LABEL_REMAP = np.zeros(30, dtype=np.uint8)  # default: everything → 0 (background)

# Background features (labels 0-8, 21-24, 29)
# Already 0 by default

# Busbar
LABEL_REMAP[9] = 1  # busbars → busbar

# Crack (HIGH IMPORTANCE)
LABEL_REMAP[10] = 2  # crack_rbn_edge → crack
LABEL_REMAP[14] = 2  # crack → crack

# Dark/Inactive (HIGH IMPORTANCE)
LABEL_REMAP[11] = 3  # inactive → dark
LABEL_REMAP[17] = 3  # dead_cell → dark
LABEL_REMAP[20] = 3  # edge_dark → dark

# Other defects (MEDIUM IMPORTANCE)
LABEL_REMAP[12] = 4  # rings
LABEL_REMAP[13] = 4  # material
LABEL_REMAP[15] = 4  # gridline defect
LABEL_REMAP[16] = 4  # splice
LABEL_REMAP[18] = 4  # corrosion_rbn
LABEL_REMAP[19] = 4  # belt_mark
LABEL_REMAP[25] = 4  # scuff
LABEL_REMAP[26] = 4  # corrosion_cell
LABEL_REMAP[27] = 4  # brightening
LABEL_REMAP[28] = 4  # star
114
+
115
+
116
# ═══════════════════════════════════════════════════════════════
# DATASET
# ═══════════════════════════════════════════════════════════════
119
+
120
class ESCDDDataset(Dataset):
    """
    E-SCDD dataset: 512x512 EL images (RGBA) + grayscale masks (L, values 0-29).

    Images and masks are paired by file stem (``*.png`` in each directory);
    files without a counterpart are ignored. Each item is ``(image, mask)``
    where ``mask`` holds the remapped 5-class ids as an int64 tensor.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        """Index all image/mask pairs.

        Args:
            img_dir: Directory containing the ``*.png`` EL images.
            mask_dir: Directory containing the ``*.png`` label masks.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` returning a dict with
                ``"image"`` and ``"mask"`` keys (e.g. an albumentations
                ``Compose``).
        """
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform

        # Match images to masks by filename stem; order follows the sorted images.
        masks_by_stem = {p.stem: p for p in self.mask_dir.glob("*.png")}
        self.pairs = [
            (img_path, masks_by_stem[img_path.stem])
            for img_path in sorted(self.img_dir.glob("*.png"))
            if img_path.stem in masks_by_stem
        ]

        print(f" {img_dir}: {len(self.pairs)} image-mask pairs")

    def __len__(self):
        """Return the number of matched image/mask pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load, remap and (optionally) augment the pair at ``idx``.

        Returns:
            ``(img, mask)``. Without a transform, ``img`` is a ``(1, H, W)``
            float tensor scaled to [0, 1]. With a transform, ``img`` is
            whatever the transform emits. ``mask`` is always cast to int64.
        """
        img_path, mask_path = self.pairs[idx]

        # Load image as 8-bit grayscale. It is deliberately kept as uint8:
        # albumentations treats float32 images as already scaled to [0, 1], so
        # feeding 0-255 floats makes brightness/contrast/noise transforms clip
        # the image. Scaling to [0, 1] is left to Normalize / the branch below.
        with Image.open(img_path) as im:
            img = np.array(im.convert("L"), dtype=np.uint8)

        # Load mask — pixel value = class label (0-29)
        with Image.open(mask_path) as mk:
            mask = np.array(mk, dtype=np.uint8)

        # Remap 30 → 5 classes using the lookup table. Values above 29 are
        # clipped to 29, which maps to background.
        mask = LABEL_REMAP[np.clip(mask, 0, 29)]

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented["image"]
            # ToTensorV2 keeps the mask's uint8 dtype; the segmentation losses
            # need int64 class indices, so cast explicitly.
            mask = torch.as_tensor(augmented["mask"]).long()
        else:
            img = torch.from_numpy(img).float().unsqueeze(0) / 255.0
            mask = torch.from_numpy(mask).long()

        return img, mask
166
+
167
+
168
def get_train_transforms(img_size=512):
    """Build the albumentations pipeline used for training samples.

    Args:
        img_size: Side length of the square random crop.

    Returns:
        An ``A.Compose`` that crops, applies random flips / 90-degree
        rotations, brightness-contrast jitter, Gaussian noise and grid
        distortion, then divides pixel values by 255 (mean 0, std 1,
        ``max_pixel_value=255``) and converts to tensors.
    """
    return A.Compose([
        A.RandomCrop(img_size, img_size, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.GaussNoise(std_range=(0.02, 0.1), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ])
180
+
181
+
182
def get_val_transforms(img_size=512):
    """Build the deterministic albumentations pipeline used for validation.

    Args:
        img_size: Side length of the square center crop.

    Returns:
        An ``A.Compose`` that center-crops, divides pixel values by 255
        (mean 0, std 1, ``max_pixel_value=255``) and converts to tensors.
    """
    return A.Compose([
        A.CenterCrop(img_size, img_size, p=1.0),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ])
188
+
189
+
190
# ═══════════════════════════════════════════════════════════════
# DOWNLOAD DATASET
# ═══════════════════════════════════════════════════════════════
193
+
194
def download_dataset(data_dir):
    """Download E-SCDD from HuggingFace Hub.

    The download is skipped when ``<data_dir>/el_images_train`` already exists
    and holds more than 100 entries.

    Args:
        data_dir: Local directory the dataset snapshot is written to.
    """
    train_img = os.path.join(data_dir, "el_images_train")
    # isdir (not exists) so a stray file with that name cannot break listdir.
    if os.path.isdir(train_img) and len(os.listdir(train_img)) > 100:
        print("Dataset already downloaded.")
        return

    print("Downloading E-SCDD dataset from HuggingFace Hub...")
    # Imported lazily so the module can be imported without huggingface_hub
    # when the data is already on disk.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="snt-ubix/e-scdd",
        repo_type="dataset",
        local_dir=data_dir,
    )
    print(f"Downloaded to {data_dir}")
209
+
210
+
211
# ═══════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════
214
+
215
def compute_metrics(pred_logits, target, num_classes=5):
    """Compute per-class IoU and Dice for one batch.

    Predictions are the argmax of ``pred_logits`` over dim 1 and are compared
    with ``target`` class by class, pooled over the whole batch.

    Note: both scores are smoothed with 1e-6 in numerator and denominator, so
    a class absent from both prediction and target scores 1.0 for that batch.

    Args:
        pred_logits: Class scores with the class axis at dim 1.
        target: Integer class ids, same shape as ``pred_logits`` without dim 1.
        num_classes: Number of class ids (0..num_classes-1) to evaluate.

    Returns:
        Dict with ``mean_iou`` and ``mean_dice`` (floats), plus
        ``per_class_iou`` and ``per_class_dice`` keyed by
        ``Config.CLASS_NAMES`` (zipped, so extra classes beyond the name list
        are dropped from the per-class dicts).
    """
    eps = 1e-6
    pred = torch.argmax(pred_logits, dim=1)  # (B, H, W)

    ious, dices = [], []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c

        intersection = (pred_c & target_c).float().sum()
        union = (pred_c | target_c).float().sum()
        total = pred_c.float().sum() + target_c.float().sum()

        ious.append(((intersection + eps) / (union + eps)).item())
        dices.append(((2 * intersection + eps) / (total + eps)).item())

    return {
        "mean_iou": float(np.mean(ious)),
        "mean_dice": float(np.mean(dices)),
        "per_class_iou": dict(zip(Config.CLASS_NAMES, ious)),
        "per_class_dice": dict(zip(Config.CLASS_NAMES, dices)),
    }
239
+
240
+
241
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
244
+
245
def _checkpoint_meta(cfg):
    """Return the model/config metadata stored with every full checkpoint."""
    return {
        "architecture": cfg.ARCHITECTURE,
        "encoder": cfg.ENCODER,
        "num_classes": cfg.NUM_CLASSES,
        "img_size": cfg.IMG_SIZE,
        "class_names": cfg.CLASS_NAMES,
        "label_remap": LABEL_REMAP.tolist(),
    }


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, grad_clip):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0

    for images, masks in loader:
        images = images.to(device)
        # The losses index with the mask, so make sure it is int64.
        masks = masks.to(device).long()

        optimizer.zero_grad()

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, masks)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    return running_loss / len(loader)


def _validate(model, loader, criterion, device, use_amp, num_classes, class_names):
    """Evaluate ``model`` on ``loader`` in a single pass.

    Returns:
        ``(mean_loss, mean_dice, mean_iou, per_class_dice)`` where the means
        are averages of the per-batch values and ``per_class_dice`` maps each
        class name to its batch-averaged Dice.
    """
    model.eval()
    running_loss = 0.0
    batch_ious, batch_dices = [], []
    per_class = {name: [] for name in class_names}

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device).long()

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                logits = model(images)
                loss = criterion(logits, masks)

            running_loss += loss.item()
            metrics = compute_metrics(logits, masks, num_classes)
            batch_ious.append(metrics["mean_iou"])
            batch_dices.append(metrics["mean_dice"])
            for name in class_names:
                per_class[name].append(metrics["per_class_dice"][name])

    return (
        running_loss / len(loader),
        float(np.mean(batch_dices)),
        float(np.mean(batch_ious)),
        {name: float(np.mean(values)) for name, values in per_class.items()},
    )


def train():
    """Train the segmentation model end to end.

    Downloads the dataset if needed, builds the data loaders, model, loss and
    optimizer from ``Config``, trains for ``Config.NUM_EPOCHS`` epochs and
    writes ``best_model.pth``, periodic checkpoints, ``final_model.pth`` and
    ``history.json`` into ``Config.OUTPUT_DIR``. Optionally uploads that
    folder to the Hugging Face Hub.

    Raises:
        RuntimeError: If the training set cannot fill one batch or the
            validation set is empty.
    """
    cfg = Config()
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    # Mixed precision (and gradient scaling) is only enabled on CUDA.
    use_amp = cfg.USE_AMP and on_cuda
    print(f"Device: {device}")
    if on_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ── Download data ────────────────────────────────────────
    download_dataset(cfg.DATA_DIR)

    # ── Create datasets ──────────────────────────────────────
    print("\nLoading datasets...")
    train_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_train"),
        os.path.join(cfg.DATA_DIR, "el_masks_train"),
        transform=get_train_transforms(cfg.IMG_SIZE),
    )
    val_ds = ESCDDDataset(
        os.path.join(cfg.DATA_DIR, "el_images_val"),
        os.path.join(cfg.DATA_DIR, "el_masks_val"),
        transform=get_val_transforms(cfg.IMG_SIZE),
    )

    # drop_last=True below means fewer than BATCH_SIZE samples yields zero
    # batches; fail early with a readable message instead of a division error.
    if len(train_ds) < cfg.BATCH_SIZE or len(val_ds) == 0:
        raise RuntimeError(
            f"Not enough data under {cfg.DATA_DIR}: "
            f"{len(train_ds)} training pairs, {len(val_ds)} validation pairs"
        )

    train_loader = DataLoader(
        train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True,
        num_workers=cfg.NUM_WORKERS, pin_memory=on_cuda, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False,
        num_workers=cfg.NUM_WORKERS, pin_memory=on_cuda,
    )

    # ── Report class distribution of the training data ───────
    # Informational only: the frequencies are printed, not fed into the loss.
    print("\nComputing class distribution...")
    class_pixels = np.zeros(cfg.NUM_CLASSES, dtype=np.float64)
    for i in range(min(len(train_ds), 200)):  # Sample 200 images
        _, mask = train_ds[i]
        if isinstance(mask, torch.Tensor):
            mask = mask.numpy()
        for c in range(cfg.NUM_CLASSES):
            class_pixels[c] += (mask == c).sum()

    class_freq = class_pixels / class_pixels.sum()
    print("Class distribution:")
    for i, name in enumerate(cfg.CLASS_NAMES):
        print(f" {name}: {class_freq[i]*100:.2f}% ({int(class_pixels[i]):,} px)")

    # ── Create model ─────────────────────────────────────────
    print(f"\nCreating {cfg.ARCHITECTURE} + {cfg.ENCODER}...")
    ModelClass = getattr(smp, cfg.ARCHITECTURE)
    model = ModelClass(
        encoder_name=cfg.ENCODER,
        encoder_weights=cfg.ENCODER_WEIGHTS,
        in_channels=cfg.IN_CHANNELS,
        classes=cfg.NUM_CLASSES,
        decoder_attention_type="scse",
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")

    # ── Loss: Dice + Focal (handles class imbalance) ─────────
    dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True, smooth=1.0)
    focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=cfg.FOCAL_GAMMA)

    def criterion(pred, target):
        return cfg.DICE_WEIGHT * dice_loss(pred, target) + cfg.FOCAL_WEIGHT * focal_loss(pred, target)

    # ── Optimizer with differential LR ───────────────────────
    optimizer = AdamW([
        {"params": model.encoder.parameters(), "lr": cfg.ENCODER_LR},
        {"params": model.decoder.parameters(), "lr": cfg.DECODER_LR},
        {"params": model.segmentation_head.parameters(), "lr": cfg.DECODER_LR},
    ], weight_decay=cfg.WEIGHT_DECAY)

    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.NUM_EPOCHS, eta_min=1e-6)
    scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

    # ── Training loop ────────────────────────────────────────
    best_val_dice = 0.0
    history = {"train_loss": [], "val_loss": [], "val_dice": [], "val_iou": []}
    meta = _checkpoint_meta(cfg)

    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.NUM_EPOCHS} epochs")
    print(f"{'='*60}\n")

    for epoch in range(cfg.NUM_EPOCHS):
        t_start = time.time()

        # ── Train ────────────────────────────────────────────
        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler,
            device, use_amp, cfg.GRADIENT_CLIP,
        )
        scheduler.step()

        # ── Validate ─────────────────────────────────────────
        val_loss, val_dice, val_iou, per_class_dice = _validate(
            model, val_loader, criterion, device, use_amp,
            cfg.NUM_CLASSES, cfg.CLASS_NAMES,
        )

        t_elapsed = time.time() - t_start
        lr_enc = optimizer.param_groups[0]["lr"]

        print(f"Epoch {epoch+1:3d}/{cfg.NUM_EPOCHS} | "
              f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
              f"val_dice={val_dice:.4f} | val_iou={val_iou:.4f} | "
              f"lr_enc={lr_enc:.6f} | {t_elapsed:.1f}s")

        # Per-class dice every 10 epochs (collected in the validation pass above)
        if (epoch + 1) % 10 == 0:
            print(" Per-class Dice:")
            for name in cfg.CLASS_NAMES:
                print(f" {name:20s}: {per_class_dice[name]:.4f}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_dice"].append(val_dice)
        history["val_iou"].append(val_iou)

        # ── Save best model ──────────────────────────────────
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            save_path = os.path.join(cfg.OUTPUT_DIR, "best_model.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_dice": val_dice,
                "val_iou": val_iou,
                **meta,
            }, save_path)
            print(f" → Best model saved (dice={val_dice:.4f})")

        # Periodic checkpoint every 25 epochs
        if (epoch + 1) % 25 == 0:
            ckpt_path = os.path.join(cfg.OUTPUT_DIR, f"checkpoint_ep{epoch+1}.pth")
            torch.save({"epoch": epoch+1, "model_state_dict": model.state_dict()}, ckpt_path)

    # ── Save final model + history ───────────────────────────
    final_path = os.path.join(cfg.OUTPUT_DIR, "final_model.pth")
    torch.save({
        "epoch": cfg.NUM_EPOCHS,
        "model_state_dict": model.state_dict(),
        "val_dice": history["val_dice"][-1],
        "val_iou": history["val_iou"][-1],
        **meta,
        "history": history,
    }, final_path)

    with open(os.path.join(cfg.OUTPUT_DIR, "history.json"), "w") as f:
        json.dump(history, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training complete! Best val dice: {best_val_dice:.4f}")
    print(f"Models saved to {cfg.OUTPUT_DIR}/")
    print(f"{'='*60}")

    # ── Push to Hub ──────────────────────────────────────────
    if cfg.PUSH_TO_HUB and cfg.HUB_MODEL_ID:
        # Best-effort: a failed upload must not discard a finished training
        # run, so the error is reported rather than raised.
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(cfg.HUB_MODEL_ID, exist_ok=True)
            api.upload_folder(
                folder_path=cfg.OUTPUT_DIR,
                repo_id=cfg.HUB_MODEL_ID,
                commit_message=f"Trained model (dice={best_val_dice:.4f})",
            )
            print(f"Pushed to hub: {cfg.HUB_MODEL_ID}")
        except Exception as e:
            print(f"Hub push failed: {e}")
479
+
480
+
481
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    train()