Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| import numpy as np | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| from tqdm import tqdm | |
| import segmentation_models_pytorch as smp | |
| import cv2 | |
| # --- 1. Configuration --- | |
class CFG:
    """Static settings for the RGB LULC segmentation training run.

    Everything is a class attribute, so values are read as ``CFG.NAME``
    without instantiating the class.
    """

    # --- Dataset layout ---
    DATA_DIR = r"SEN-2_LULC_preprocessed"
    TRAIN_IMG_DIR = os.path.join(DATA_DIR, "train_images")
    TRAIN_MASK_DIR = os.path.join(DATA_DIR, "train_masks")
    VAL_IMG_DIR = os.path.join(DATA_DIR, "val_images")
    VAL_MASK_DIR = os.path.join(DATA_DIR, "val_masks")

    # --- Output artefacts ---
    OUTPUT_DIR = "./outputs_rgb_optimized"
    # Weights of the best-scoring model, kept for later inference.
    MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "best_model_optimized.pth")
    # Full training state used to resume an interrupted run.
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "checkpoint.pth")
    PREDICTION_SAVE_PATH = os.path.join(OUTPUT_DIR, "predictions_optimized")

    # --- Hardware ---
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

    # --- Model and loss selection ---
    MODEL_NAME = "CustomDeepLabV3+"
    ENCODER_NAME = "timm-efficientnet-b2"
    LOSS_FN_NAME = "DiceFocal"

    # --- Data shape ---
    IN_CHANNELS = 3
    NUM_CLASSES = 8
    IMG_SIZE = 256

    # --- Optimisation schedule ---
    BATCH_SIZE = 4
    # Gradients from this many batches are accumulated per optimizer step.
    ACCUMULATION_STEPS = 4
    NUM_WORKERS = 8
    LEARNING_RATE = 1e-4
    EPOCHS = 50

    # --- Reproducibility and sampling ---
    SEED = 42
    # Fraction of the sorted file list each dataset keeps.
    SUBSET_FRACTION = 0.75
| # --- ARCHITECTURE and LOSS CLASSES (Unchanged) --- | |
class SELayer(nn.Module):
    """Squeeze-and-Excitation gate that rescales each channel of a 4-D input.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck ending in a sigmoid, and the resulting per-channel
    weights multiply the original input.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor applied to ``channel`` to size the bottleneck.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # With channel < reduction the integer division yields 0, which would
        # create zero-width Linear layers and a constant sigmoid(0) = 0.5 gate.
        # Keep at least one hidden unit; for channel >= reduction the layer
        # shapes are unchanged.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        b, c, _, _ = x.size()
        # Squeeze: one scalar per (sample, channel).
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel weights in (0, 1), reshaped to broadcast over H, W.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class CustomDeepLabV3Plus(nn.Module):
    """smp DeepLabV3+ with an SELayer inserted ahead of the segmentation head.

    The wrapped model's own ``segmentation_head`` is moved onto this module
    and replaced by ``nn.Identity`` inside the wrapped model, so the wrapped
    model's output can be gated before the head is applied.
    NOTE(review): this presumes the smp forward pass ends by calling
    ``segmentation_head`` on the decoder output -- confirm for the installed
    smp version.

    Args:
        encoder_name: Encoder identifier passed to ``smp.DeepLabV3Plus``.
        in_channels: Number of input image channels.
        classes: Number of output classes.
    """

    def __init__(self, encoder_name, in_channels, classes):
        super().__init__()
        self.smp_model = smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=classes,
        )
        original_head = self.smp_model.segmentation_head
        # The head's first layer tells us how many channels reach it.
        self.se_layer = SELayer(original_head[0].in_channels)
        self.segmentation_head = original_head
        self.smp_model.segmentation_head = nn.Identity()

    def forward(self, x):
        """Run the wrapped model, apply the SE gate, then the segmentation head."""
        features = self.smp_model(x)
        gated = self.se_layer(features)
        return self.segmentation_head(gated)
class CombinedLoss(nn.Module):
    """Weighted sum of two losses: ``alpha * loss1 + (1 - alpha) * loss2``.

    Args:
        loss1: Callable taking ``(prediction, target)``; weighted by ``alpha``.
        loss2: Callable taking ``(prediction, target)``; weighted by ``1 - alpha``.
        alpha: Weight of ``loss1``.
    """

    def __init__(self, loss1, loss2, alpha=0.5):
        super().__init__()
        self.loss1 = loss1
        self.loss2 = loss2
        self.alpha = alpha
        # Human-readable description built from the two loss class names.
        first_label = type(self.loss1).__name__
        second_label = type(self.loss2).__name__
        self.name = f"{alpha}*{first_label} + {1-alpha}*{second_label}"

    def forward(self, prediction, target):
        """Evaluate both losses on the same inputs and blend them."""
        first_term = self.alpha * self.loss1(prediction, target)
        second_term = (1 - self.alpha) * self.loss2(prediction, target)
        return first_term + second_term
| # --- DATASET and TRANSFORMS (Unchanged) --- | |
class LULCDataset(Dataset):
    """Image/mask pairs read from two flat directories.

    Images are the ``.png`` files in ``image_dir`` and masks are the ``.tif``
    files in ``mask_dir``. Both listings are sorted by filename and paired by
    position; only the first ``subset_fraction`` of the sorted lists is kept.

    NOTE(review): pairing is positional -- filenames are not compared, so the
    two directories must sort into the same order. Confirm the naming scheme.

    Args:
        image_dir: Directory containing the ``.png`` images.
        mask_dir: Directory containing the ``.tif`` masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``'image'`` and ``'mask'`` keys.
        subset_fraction: Fraction of the sorted file list to keep.

    Raises:
        ValueError: If the directories hold different numbers of images and masks.
    """

    def __init__(self, image_dir, mask_dir, transform=None, subset_fraction=1.0):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        all_images = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        all_masks = sorted(f for f in os.listdir(mask_dir) if f.endswith('.tif'))

        # Compare the full listings: once both are sliced to the same length a
        # surplus of masks would go unnoticed and shift the positional pairing.
        # A real exception is used because asserts are stripped under ``-O``.
        if len(all_images) != len(all_masks):
            raise ValueError(
                f"Mismatch: {len(all_images)} images in {image_dir} "
                f"but {len(all_masks)} masks in {mask_dir}"
            )

        num_samples = int(len(all_images) * subset_fraction)
        self.images = all_images[:num_samples]
        self.masks = all_masks[:num_samples]
        print(f"Found {len(all_images)} total images, USING {len(self.images)} samples ({subset_fraction*100}%) from {image_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as float32 arrays and apply the transform, if any."""
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Context managers close the underlying file handles deterministically,
        # which matters when many loader workers open files concurrently.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"), dtype=np.float32)
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        return image, mask
def get_transforms(img_size):
    """Build the training and validation albumentations pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, normalise, and convert
    to tensors; the training pipeline additionally applies random rotation and
    horizontal/vertical flips between the resize and the normalisation.

    Args:
        img_size: Target height and width in pixels.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """
    # These values match the widely used ImageNet channel statistics.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def _resize():
        return A.Resize(img_size, img_size)

    def _normalize_and_tensorize():
        return [A.Normalize(mean=channel_mean, std=channel_std), ToTensorV2()]

    train_steps = [
        _resize(),
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        *_normalize_and_tensorize(),
    ]
    val_steps = [_resize(), *_normalize_and_tensorize()]

    train_transform = A.Compose(train_steps)
    val_transform = A.Compose(val_steps)
    return train_transform, val_transform
| # --- GET MODEL AND LOSS (Unchanged) --- | |
def get_model():
    """Create the network selected by ``CFG.MODEL_NAME`` on ``CFG.DEVICE``.

    ``"CustomDeepLabV3+"`` selects ``CustomDeepLabV3Plus``; any other value
    falls back to a plain ``smp.DeepLabV3Plus`` with ImageNet encoder weights.
    """
    common_kwargs = dict(
        encoder_name=CFG.ENCODER_NAME,
        in_channels=CFG.IN_CHANNELS,
        classes=CFG.NUM_CLASSES,
    )
    if CFG.MODEL_NAME != "CustomDeepLabV3+":
        network = smp.DeepLabV3Plus(encoder_weights="imagenet", **common_kwargs)
    else:
        network = CustomDeepLabV3Plus(**common_kwargs)
    return network.to(CFG.DEVICE)
def get_loss_fn():
    """Return the loss selected by ``CFG.LOSS_FN_NAME``.

    ``"DiceFocal"`` yields an equal-weight ``CombinedLoss`` with focal loss as
    the first term and dice loss as the second; any other value yields a plain
    multiclass dice loss.
    """
    if CFG.LOSS_FN_NAME != "DiceFocal":
        return smp.losses.DiceLoss(mode='multiclass')

    dice_term = smp.losses.DiceLoss(mode='multiclass')
    focal_term = smp.losses.FocalLoss(mode='multiclass')
    return CombinedLoss(focal_term, dice_term, alpha=0.5)
| # --- Training and Evaluation Functions (Unchanged) --- | |
def train_one_epoch(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Each batch's loss is divided by ``CFG.ACCUMULATION_STEPS`` and its gradients
    are accumulated; the optimizer steps once per full accumulation window.
    Gradients left over from an incomplete final window are applied after the
    loop instead of being discarded.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches. ``data`` must be
            4-D, since it is moved using the ``channels_last`` memory format.
        model: Network to train; switched to training mode here.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Callable taking ``(predictions, targets)``.
        scaler: ``torch.amp.GradScaler`` used for backward and step.
    """

    def _apply_accumulated_gradients():
        # One optimizer update followed by clearing the gradient buffers.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    model.train()
    optimizer.zero_grad()
    use_amp = CFG.DEVICE == "cuda"
    # True while gradients have been accumulated but not yet applied.
    has_pending_grads = False

    loop = tqdm(loader, desc="Training")
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
        targets = targets.long().to(CFG.DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=CFG.DEVICE, dtype=torch.bfloat16, enabled=use_amp):
            predictions = model(data)
            # Scale down so the summed gradients of a full window average out.
            loss = loss_fn(predictions, targets) / CFG.ACCUMULATION_STEPS

        scaler.scale(loss).backward()
        has_pending_grads = True

        if (batch_idx + 1) % CFG.ACCUMULATION_STEPS == 0:
            _apply_accumulated_gradients()
            has_pending_grads = False

        # Undo the accumulation scaling so the bar shows the per-batch loss.
        loop.set_postfix(loss=loss.item() * CFG.ACCUMULATION_STEPS)

    # When the number of batches is not a multiple of ACCUMULATION_STEPS the
    # trailing batches never reached a step and would be wiped by the next
    # epoch's zero_grad(); apply them now. Their gradient is proportionally
    # smaller because the loss was still divided by the full window size.
    if has_pending_grads:
        _apply_accumulated_gradients()
def evaluate_model(loader, model, loss_fn):
    """Run one validation pass, print the metrics, and return the mean IoU.

    Accumulates per-class intersection and union counts over all batches, plus
    pixel accuracy and the average per-batch loss. A small epsilon is added to
    numerator and denominator, so a class absent from both predictions and
    labels scores an IoU of about 1.0.

    Args:
        loader: Iterable yielding ``(x, y)`` batches; must support ``len()``.
        model: Network to evaluate; switched to eval mode here.
        loss_fn: Callable taking ``(predictions, labels)``.

    Returns:
        The mean of the per-class IoU values as a 0-dim tensor.
    """
    model.eval()
    num_classes = CFG.NUM_CLASSES
    inter_sum = torch.zeros(num_classes, device=CFG.DEVICE)
    union_sum = torch.zeros(num_classes, device=CFG.DEVICE)
    correct_pixels = 0
    seen_pixels = 0
    loss_sum = 0
    amp_enabled = CFG.DEVICE == "cuda"

    with torch.no_grad():
        for x, y in tqdm(loader, desc="Evaluating"):
            x = x.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
            y = y.to(CFG.DEVICE, non_blocking=True).long()

            with torch.amp.autocast(device_type=CFG.DEVICE, dtype=torch.bfloat16, enabled=amp_enabled):
                preds = model(x)
                loss_sum += loss_fn(preds, y).item()

            pred_labels = torch.argmax(preds, dim=1)
            correct_pixels += (pred_labels == y).sum()
            seen_pixels += torch.numel(y)

            # Per-class overlap bookkeeping for IoU.
            for cls in range(num_classes):
                predicted_here = pred_labels == cls
                labelled_here = y == cls
                inter_sum[cls] += (predicted_here & labelled_here).sum()
                union_sum[cls] += (predicted_here | labelled_here).sum()

    pixel_acc = (correct_pixels / seen_pixels) * 100
    iou_per_class = (inter_sum + 1e-6) / (union_sum + 1e-6)
    mean_iou = iou_per_class.mean()
    avg_loss = loss_sum / len(loader)

    print(f"Validation Results -> Avg Loss: {avg_loss:.4f}, Pixel Acc: {pixel_acc:.2f}%, mIoU: {mean_iou:.4f}")
    for class_id, class_iou in enumerate(iou_per_class):
        print(f" Class {class_id} IoU: {class_iou:.4f}")
    return mean_iou
def save_predictions_as_images(loader, model):
    """Placeholder for exporting predictions; currently a no-op.

    The body contains no export logic in this file: both arguments are accepted
    and ignored, and the function returns ``None``. It is not part of the
    training loop.
    """
    return None
| # --- NEW: Helper function to save a checkpoint --- | |
def save_checkpoint(state, filename="checkpoint.pth"):
    """Serialize ``state`` to ``filename`` without ever leaving a partial file.

    The data is first written to ``<filename>.tmp`` and then moved over the
    target with ``os.replace``, so an interruption during the write leaves any
    previous checkpoint intact instead of corrupting the resumable state.

    Args:
        state: Object to serialize with ``torch.save``.
        filename: Destination path. A non-path target (e.g. an open file
            object) is passed straight to ``torch.save``.
    """
    print("=> Saving checkpoint")

    # File-like targets cannot be renamed; keep the direct write for them.
    if not isinstance(filename, (str, os.PathLike)):
        torch.save(state, filename)
        return

    tmp_path = f"{os.fspath(filename)}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, filename)
    finally:
        # After a successful replace the temp path is gone; this only removes
        # a leftover partial file when the save itself failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def main():
    """Train the segmentation model, resuming from a checkpoint when one exists.

    Seeds the RNGs, builds datasets, loaders, model, loss, optimizer, scaler and
    scheduler, then runs the epoch loop. After every epoch the full training
    state is checkpointed; whenever validation mIoU improves, the model weights
    alone are also written to ``CFG.MODEL_SAVE_PATH``. At the end the best
    weights are reloaded and passed to ``save_predictions_as_images``.
    """
    torch.manual_seed(CFG.SEED)
    np.random.seed(CFG.SEED)
    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    train_transform, val_transform = get_transforms(CFG.IMG_SIZE)
    train_ds = LULCDataset(CFG.TRAIN_IMG_DIR, CFG.TRAIN_MASK_DIR, transform=train_transform, subset_fraction=CFG.SUBSET_FRACTION)
    val_ds = LULCDataset(CFG.VAL_IMG_DIR, CFG.VAL_MASK_DIR, transform=val_transform, subset_fraction=CFG.SUBSET_FRACTION)

    # DataLoader rejects persistent_workers=True when num_workers is 0, so only
    # request it when worker processes are actually used.
    keep_workers = CFG.NUM_WORKERS > 0
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, num_workers=CFG.NUM_WORKERS, pin_memory=True, shuffle=True, persistent_workers=keep_workers)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE, num_workers=CFG.NUM_WORKERS, pin_memory=True, shuffle=False, persistent_workers=keep_workers)

    model = get_model()
    model = model.to(memory_format=torch.channels_last)
    loss_fn = get_loss_fn()
    optimizer = optim.AdamW(model.parameters(), lr=CFG.LEARNING_RATE)
    scaler = torch.amp.GradScaler(enabled=(CFG.DEVICE == "cuda"))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS, eta_min=1e-6)

    # --- Resume from the last checkpoint if one exists ---
    start_epoch = 0
    best_val_miou = -1.0
    if os.path.exists(CFG.CHECKPOINT_PATH):
        print(f"=> Loading checkpoint '{CFG.CHECKPOINT_PATH}'")
        checkpoint = torch.load(CFG.CHECKPOINT_PATH, map_location=CFG.DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # Older checkpoints may hold the score as a 0-dim tensor; normalise it.
        best_val_miou = float(checkpoint['best_val_miou'])
        print(f"=> Resuming training from epoch {start_epoch}")
    else:
        print("=> No checkpoint found, starting new training session.")

    # --- Epoch loop, starting from the resumed epoch ---
    for epoch in range(start_epoch, CFG.EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{CFG.EPOCHS} ---")
        train_one_epoch(train_loader, model, optimizer, loss_fn, scaler)
        # Store the score as a plain float so the checkpoint does not carry a
        # device tensor for a single scalar.
        current_miou = float(evaluate_model(val_loader, model, loss_fn))
        scheduler.step()

        if current_miou > best_val_miou:
            best_val_miou = current_miou
            print(f"🎉 New best mIoU: {best_val_miou:.4f}! Saving best model to {CFG.MODEL_SAVE_PATH}")
            # Weights only, for easy inference later.
            torch.save(model.state_dict(), CFG.MODEL_SAVE_PATH)

        # Full training state, written after every epoch so a run can resume.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_miou': best_val_miou,
        }
        save_checkpoint(checkpoint, filename=CFG.CHECKPOINT_PATH)

    print("\n--- Training Complete. Saving final predictions. ---")
    # Reload the best-scoring weights for the final predictions. map_location
    # keeps the load working if the file was written on a different device.
    if os.path.exists(CFG.MODEL_SAVE_PATH):
        model.load_state_dict(torch.load(CFG.MODEL_SAVE_PATH, map_location=CFG.DEVICE))
    else:
        print("=> No best-model file found; using the current weights for predictions.")
    # Note: a separate test loader would give an unbiased final evaluation.
    save_predictions_as_images(val_loader, model)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()