Fix: mask dtype → LongTensor for DiceLoss one_hot compatibility
Browse files
train.py
CHANGED
|
@@ -145,8 +145,8 @@ class ESCDDDataset(Dataset):
|
|
| 145 |
# Apply augmentations
|
| 146 |
if self.transform:
|
| 147 |
augmented = self.transform(image=img, mask=mask)
|
| 148 |
-
img = augmented["image"]
|
| 149 |
-
mask = augmented["mask"] # (H, W)
|
| 150 |
else:
|
| 151 |
img = torch.from_numpy(img).unsqueeze(0) / 255.0
|
| 152 |
mask = torch.from_numpy(mask).long()
|
|
@@ -331,7 +331,7 @@ def train():
|
|
| 331 |
|
| 332 |
for batch_idx, (images, masks) in enumerate(train_loader):
|
| 333 |
images = images.to(device)
|
| 334 |
-
masks = masks.to(device)
|
| 335 |
|
| 336 |
optimizer.zero_grad()
|
| 337 |
|
|
@@ -358,7 +358,7 @@ def train():
|
|
| 358 |
with torch.no_grad():
|
| 359 |
for images, masks in val_loader:
|
| 360 |
images = images.to(device)
|
| 361 |
-
masks = masks.to(device)
|
| 362 |
|
| 363 |
with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
|
| 364 |
logits = model(images)
|
|
@@ -387,7 +387,7 @@ def train():
|
|
| 387 |
all_per_class = {name: [] for name in cfg.CLASS_NAMES}
|
| 388 |
with torch.no_grad():
|
| 389 |
for images, masks in val_loader:
|
| 390 |
-
images, masks = images.to(device), masks.to(device)
|
| 391 |
with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
|
| 392 |
logits = model(images)
|
| 393 |
m = compute_metrics(logits, masks, cfg.NUM_CLASSES)
|
|
|
|
| 145 |
# Apply augmentations
|
| 146 |
if self.transform:
|
| 147 |
augmented = self.transform(image=img, mask=mask)
|
| 148 |
+
img = augmented["image"] # (1, H, W) float tensor
|
| 149 |
+
mask = augmented["mask"].long() # (H, W) LongTensor — required by DiceLoss one_hot
|
| 150 |
else:
|
| 151 |
img = torch.from_numpy(img).unsqueeze(0) / 255.0
|
| 152 |
mask = torch.from_numpy(mask).long()
|
|
|
|
| 331 |
|
| 332 |
for batch_idx, (images, masks) in enumerate(train_loader):
|
| 333 |
images = images.to(device)
|
| 334 |
+
masks = masks.to(device).long() # Ensure LongTensor on GPU
|
| 335 |
|
| 336 |
optimizer.zero_grad()
|
| 337 |
|
|
|
|
| 358 |
with torch.no_grad():
|
| 359 |
for images, masks in val_loader:
|
| 360 |
images = images.to(device)
|
| 361 |
+
masks = masks.to(device).long() # Ensure LongTensor on GPU
|
| 362 |
|
| 363 |
with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
|
| 364 |
logits = model(images)
|
|
|
|
| 387 |
all_per_class = {name: [] for name in cfg.CLASS_NAMES}
|
| 388 |
with torch.no_grad():
|
| 389 |
for images, masks in val_loader:
|
| 390 |
+
images, masks = images.to(device), masks.to(device).long()
|
| 391 |
with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
|
| 392 |
logits = model(images)
|
| 393 |
m = compute_metrics(logits, masks, cfg.NUM_CLASSES)
|