nithishbasireddy commited on
Commit
d3f1e7d
·
verified ·
1 Parent(s): 2ab499c

Fix: mask dtype → LongTensor for DiceLoss one_hot compatibility

Browse files
Files changed (1) hide show
  1. train.py +5 -5
train.py CHANGED
@@ -145,8 +145,8 @@ class ESCDDDataset(Dataset):
145
  # Apply augmentations
146
  if self.transform:
147
  augmented = self.transform(image=img, mask=mask)
148
- img = augmented["image"] # (1, H, W) float tensor
149
- mask = augmented["mask"] # (H, W) long tensor
150
  else:
151
  img = torch.from_numpy(img).unsqueeze(0) / 255.0
152
  mask = torch.from_numpy(mask).long()
@@ -331,7 +331,7 @@ def train():
331
 
332
  for batch_idx, (images, masks) in enumerate(train_loader):
333
  images = images.to(device)
334
- masks = masks.to(device)
335
 
336
  optimizer.zero_grad()
337
 
@@ -358,7 +358,7 @@ def train():
358
  with torch.no_grad():
359
  for images, masks in val_loader:
360
  images = images.to(device)
361
- masks = masks.to(device)
362
 
363
  with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
364
  logits = model(images)
@@ -387,7 +387,7 @@ def train():
387
  all_per_class = {name: [] for name in cfg.CLASS_NAMES}
388
  with torch.no_grad():
389
  for images, masks in val_loader:
390
- images, masks = images.to(device), masks.to(device)
391
  with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
392
  logits = model(images)
393
  m = compute_metrics(logits, masks, cfg.NUM_CLASSES)
 
145
  # Apply augmentations
146
  if self.transform:
147
  augmented = self.transform(image=img, mask=mask)
148
+ img = augmented["image"] # (1, H, W) float tensor
149
+ mask = augmented["mask"].long() # (H, W) LongTensor — required by DiceLoss one_hot
150
  else:
151
  img = torch.from_numpy(img).unsqueeze(0) / 255.0
152
  mask = torch.from_numpy(mask).long()
 
331
 
332
  for batch_idx, (images, masks) in enumerate(train_loader):
333
  images = images.to(device)
334
+ masks = masks.to(device).long() # Ensure LongTensor on GPU
335
 
336
  optimizer.zero_grad()
337
 
 
358
  with torch.no_grad():
359
  for images, masks in val_loader:
360
  images = images.to(device)
361
+ masks = masks.to(device).long() # Ensure LongTensor on GPU
362
 
363
  with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
364
  logits = model(images)
 
387
  all_per_class = {name: [] for name in cfg.CLASS_NAMES}
388
  with torch.no_grad():
389
  for images, masks in val_loader:
390
+ images, masks = images.to(device), masks.to(device).long()
391
  with torch.amp.autocast(device_type="cuda", enabled=cfg.USE_AMP):
392
  logits = model(images)
393
  m = compute_metrics(logits, masks, cfg.NUM_CLASSES)