import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import glob
from PIL import Image
import numpy as np
import wandb
import pandas as pd
import os
import matplotlib.pyplot as plt
import opendatasets as opd
import zipfile

# Fix seeds for reproducibility (note: DataLoader workers have their own RNG
# state; full determinism would also need worker_init_fn / generator args).
torch.manual_seed(42)
np.random.seed(42)

# wandb.login(key="your_wandb_api_key_here")

# Hyperparameters / run configuration.
EPOCHS = 25
BATCH_SIZE = 8
LR = 1e-3
NUM_CLASSES = 32  # CamVid has 32 labelled classes (see class_dict.csv)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# wandb.init(project="segnet-efficientnet-camvid", config={
#     "epochs": EPOCHS,
#     "batch_size": BATCH_SIZE,
#     "learning_rate": LR,
#     "architecture": "SegNet-EfficientNet",
#     "dataset": "CamVid"
# })


class SegNetEfficientNet(nn.Module):
    """SegNet-style encoder/decoder with an EfficientNet-B0 encoder.

    The pretrained EfficientNet-B0 feature extractor downsamples the input
    by a factor of 32 (to 1280 channels); a stack of five stride-2
    transposed convolutions upsamples back, and a 1x1 conv produces
    per-pixel class logits. The output is finally resized to the input's
    spatial resolution.
    """

    def __init__(self, num_classes=32):
        super(SegNetEfficientNet, self).__init__()
        # torchvision >= 0.13 replaced `pretrained=True` with the `weights=`
        # enum; fall back to the legacy keyword on older versions.
        try:
            base_model = models.efficientnet_b0(
                weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        except (AttributeError, TypeError):
            base_model = models.efficientnet_b0(pretrained=True)
        features = list(base_model.features.children())

        # EfficientNet-B0 backbone (output channels gradually increase to 1280)
        self.encoder = nn.Sequential(*features)  # Output: [B, 1280, H/32, W/32]

        # Decoder blocks (mirroring encoder with ConvTranspose2d).
        # Five stride-2 upsamples recover (approximately) the input scale;
        # any rounding mismatch is fixed by the final interpolation.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits of shape [B, num_classes, H, W].

        H and W match the input tensor, so the model now works for any
        input resolution (previously hard-coded to 360x480).
        """
        in_size = x.shape[2:]          # remember input H, W
        x = self.encoder(x)            # Downsampled features from EfficientNet
        x = self.decoder(x)            # Upsampled (may be off by rounding)
        x = self.classifier(x)
        # Resize logits back to the exact input resolution.
        x = F.interpolate(x, size=in_size, mode='bilinear', align_corners=False)
        return x


class CamVidDataset(Dataset):
    """
    CamVid dataset loader with RGB mask to class index conversion.

    Expects directory structure:
        camvid/
            train/
            train_labels/
            val/
            val_labels/
            test/
            test_labels/

    Images are resized bilinearly; label masks are resized with nearest
    neighbor so class colors are never interpolated, then converted from
    RGB to a 2D LongTensor of class indices via class_dict.csv.
    """

    def __init__(self, root, split='train', transform=None,
                 image_size=(360, 480), target_transform=None,
                 class_dict_path='camvid/CamVid/class_dict.csv'):
        """
        Args:
            root: dataset root directory containing the split folders.
            split: one of 'train', 'val', 'test'.
            transform: optional transform applied to the (resized) image.
            image_size: (H, W) both images and labels are resized to.
            target_transform: accepted for API compatibility.
                NOTE(review): currently never applied in __getitem__ —
                labels always go through rgb_to_class instead.
            class_dict_path: CSV with columns r, g, b mapping colors to
                class indices (row order defines the index).
        """
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform  # stored but unused; see NOTE above
        self.image_dir = os.path.join(root, split)
        self.label_dir = os.path.join(root, f"{split}_labels")
        # Sorted globs keep image/label pairs aligned by filename.
        self.image_paths = sorted(glob.glob(os.path.join(self.image_dir, '*.png')))
        self.label_paths = sorted(glob.glob(os.path.join(self.label_dir, '*.png')))
        # NEAREST for labels (no color blending), BILINEAR for images.
        self.label_resize = transforms.Resize(image_size, interpolation=Image.NEAREST)
        self.image_resize = transforms.Resize(image_size, interpolation=Image.BILINEAR)
        # Raise (not assert) so the check survives `python -O`.
        if len(self.image_paths) != len(self.label_paths):
            raise RuntimeError("Mismatch between images and labels.")

        # Load class_dict.csv and build color-to-class mapping.
        df = pd.read_csv(class_dict_path)
        self.color_to_class = {
            (row['r'], row['g'], row['b']): idx
            for idx, row in df.iterrows()
        }

    def __len__(self):
        return len(self.image_paths)

    def rgb_to_class(self, mask):
        """Convert an RGB mask (PIL.Image) to a 2D class index mask.

        Pixels whose color is not in class_dict.csv remain 0.
        """
        mask_np = np.array(mask)
        h, w, _ = mask_np.shape
        class_mask = np.zeros((h, w), dtype=np.uint8)
        # One vectorized comparison per class color (32 passes total).
        for rgb, class_idx in self.color_to_class.items():
            matches = (mask_np == rgb).all(axis=2)
            class_mask[matches] = class_idx
        return class_mask

    def __getitem__(self, idx):
        """Return (image_tensor, label_tensor) for sample `idx`.

        The label is a LongTensor of class indices with shape image_size.
        """
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = Image.open(self.label_paths[idx]).convert('RGB')

        # Resize both to image_size (default 360x480).
        image = self.image_resize(image)
        label = self.label_resize(label)

        if self.transform:
            image = self.transform(image)

        label = self.rgb_to_class(label)
        label = torch.from_numpy(label).long()

        return image, label


if __name__ == "__main__":
    # Download CamVid from Kaggle (requires kaggle.json credentials).
    dataset_url = "https://www.kaggle.com/datasets/carlolepelaars/camvid"
    opd.download(dataset_url)

    # Set dataset folder (adjust path if needed).
    dataset_folder = "camvid"
    print("Dataset directory contents:")
    print(os.listdir(dataset_folder))

    input_transform = transforms.Compose([
        transforms.Resize((360, 480)),  # Or larger if needed
        transforms.ToTensor(),
        # ImageNet statistics — required because the encoder is pretrained.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    def label_transform(label):
        # Resize using nearest neighbor so that labels are not interpolated.
        # NOTE(review): CamVidDataset never calls target_transform, so this
        # is currently unused; kept for API compatibility.
        label = label.resize((480, 360), Image.NEAREST)
        label = np.array(label, dtype=np.int64)
        return torch.from_numpy(label)

    num_classes = NUM_CLASSES
    data_root = 'camvid/CamVid/'  # make sure this matches your structure

    # Load datasets and dataloaders.
    train_dataset = CamVidDataset(root=data_root, split='train',
                                  transform=input_transform,
                                  target_transform=label_transform)
    val_dataset = CamVidDataset(root=data_root, split='val',
                                transform=input_transform,
                                target_transform=label_transform)
    test_dataset = CamVidDataset(root=data_root, split='test',
                                 transform=input_transform,
                                 target_transform=label_transform)

    # Use the configured BATCH_SIZE (was hard-coded to 4, ignoring the
    # constant declared at the top of the file).
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=4)
    # Evaluation order is irrelevant to metrics; shuffle=False is the
    # conventional (and previously accidental) choice for a test loader.
    test_loader = DataLoader(test_dataset, batch_size=1,
                             shuffle=False, num_workers=4)