Fix: total_mem → total_memory, NUM_WORKERS=0 for Windows
Browse files
train.py
CHANGED
|
@@ -55,7 +55,7 @@ class Config:
|
|
| 55 |
DECODER_LR = 5e-4 # Higher LR for random decoder
|
| 56 |
WEIGHT_DECAY = 1e-4
|
| 57 |
USE_AMP = True # Mixed precision — halves VRAM usage
|
| 58 |
-
NUM_WORKERS =
|
| 59 |
GRADIENT_CLIP = 1.0
|
| 60 |
|
| 61 |
# Loss
|
|
@@ -75,18 +75,7 @@ class Config:
|
|
| 75 |
# CLASS MAPPING: E-SCDD 30 classes → 5 classes
|
| 76 |
# ═══════════════════════════════════════════════════════════════
|
| 77 |
|
| 78 |
-
|
| 79 |
-
# We remap to 5 meaningful classes:
|
| 80 |
-
# 0 = background (all spacing, borders, padding, text, clamp, frame, jbox)
|
| 81 |
-
# 1 = busbar (label 9)
|
| 82 |
-
# 2 = crack (label 14=crack, label 10=crack_rbn_edge)
|
| 83 |
-
# 3 = dark/inactive (label 11=inactive, label 17=dead_cell, label 20=edge_dark)
|
| 84 |
-
# 4 = other_defect (rings, material, gridline, splice, corrosion, belt_mark, etc.)
|
| 85 |
-
|
| 86 |
-
LABEL_REMAP = np.zeros(30, dtype=np.uint8) # default: everything → 0 (background)
|
| 87 |
-
|
| 88 |
-
# Background features (labels 0-8, 21-24, 29)
|
| 89 |
-
# Already 0 by default
|
| 90 |
|
| 91 |
# Busbar
|
| 92 |
LABEL_REMAP[9] = 1 # busbars → busbar
|
|
@@ -250,7 +239,7 @@ def train():
|
|
| 250 |
print(f"Device: {device}")
|
| 251 |
if device.type == "cuda":
|
| 252 |
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 253 |
-
print(f"VRAM: {torch.cuda.get_device_properties(0).
|
| 254 |
|
| 255 |
# ── Download data ────────────────────────────────────────
|
| 256 |
download_dataset(cfg.DATA_DIR)
|
|
@@ -395,7 +384,6 @@ def train():
|
|
| 395 |
|
| 396 |
# Per-class dice every 10 epochs
|
| 397 |
if (epoch + 1) % 10 == 0:
|
| 398 |
-
# Run full validation for per-class metrics
|
| 399 |
all_per_class = {name: [] for name in cfg.CLASS_NAMES}
|
| 400 |
with torch.no_grad():
|
| 401 |
for images, masks in val_loader:
|
|
|
|
| 55 |
DECODER_LR = 5e-4 # Higher LR for random decoder
|
| 56 |
WEIGHT_DECAY = 1e-4
|
| 57 |
USE_AMP = True # Mixed precision — halves VRAM usage
|
| 58 |
+
NUM_WORKERS = 0 # 0 for Windows compatibility
|
| 59 |
GRADIENT_CLIP = 1.0
|
| 60 |
|
| 61 |
# Loss
|
|
|
|
| 75 |
# CLASS MAPPING: E-SCDD 30 classes → 5 classes
|
| 76 |
# ═══════════════════════════════════════════════════════════════
|
| 77 |
|
| 78 |
+
LABEL_REMAP = np.zeros(30, dtype=np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# Busbar
|
| 81 |
LABEL_REMAP[9] = 1 # busbars → busbar
|
|
|
|
| 239 |
print(f"Device: {device}")
|
| 240 |
if device.type == "cuda":
|
| 241 |
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 242 |
+
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
| 243 |
|
| 244 |
# ── Download data ────────────────────────────────────────
|
| 245 |
download_dataset(cfg.DATA_DIR)
|
|
|
|
| 384 |
|
| 385 |
# Per-class dice every 10 epochs
|
| 386 |
if (epoch + 1) % 10 == 0:
|
|
|
|
| 387 |
all_per_class = {name: [] for name in cfg.CLASS_NAMES}
|
| 388 |
with torch.no_grad():
|
| 389 |
for images, masks in val_loader:
|