# NOTE: removed non-Python page-scrape artifacts that preceded this file
# (hosting-page header, "Build error" banners, file size, commit hash, and a
# line-number gutter). The actual module starts at the imports below.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import glob
from PIL import Image
import numpy as np
import wandb
import pandas as pd
import os
import matplotlib.pyplot as plt
import opendatasets as opd
import zipfile
# Fix RNG seeds for reproducible weight init, shuffling, and augmentation.
torch.manual_seed(42)
np.random.seed(42)
# wandb.login(key="your_wandb_api_key_here")
# Training hyperparameters.
EPOCHS = 25
BATCH_SIZE = 8  # NOTE(review): the DataLoaders below hard-code batch_size=4 — confirm which is intended.
LR = 1e-3
NUM_CLASSES = 32  # CamVid has 32 classes (one row each in class_dict.csv).
# Prefer the first CUDA device when available; fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# wandb.init(project="segnet-efficientnet-camvid", config={
# "epochs": EPOCHS,
# "batch_size": BATCH_SIZE,
# "learning_rate": LR,
# "architecture": "SegNet-EfficientNet",
# "dataset": "CamVid"
# })
class SegNetEfficientNet(nn.Module):
    """SegNet-style encoder-decoder for semantic segmentation.

    Encoder: the EfficientNet-B0 feature extractor (downsamples by 32x,
    ending at 1280 channels). Decoder: five ConvTranspose2d stages, each
    upsampling 2x, reducing channels 1280 -> 512 -> 256 -> 128 -> 64 -> 32,
    followed by a 1x1 conv producing per-pixel class logits.

    Args:
        num_classes: number of output channels / segmentation classes.
    """

    def __init__(self, num_classes=32):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases; switch to `weights=EfficientNet_B0_Weights.IMAGENET1K_V1`
        # when upgrading — confirm the installed torchvision version.
        base_model = models.efficientnet_b0(pretrained=True)
        features = list(base_model.features.children())
        # EfficientNet-B0 backbone; output: [B, 1280, H/32, W/32].
        self.encoder = nn.Sequential(*features)
        # Decoder mirrors the 32x downsampling with five 2x ConvTranspose2d blocks.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits of shape [B, num_classes, H, W]."""
        # Remember the input resolution so the logits are resized back to it.
        # This generalizes the original hard-coded (360, 480) target — identical
        # for the CamVid pipeline (which feeds 360x480), but now the model also
        # handles other input sizes and rounding from the /32 downsampling.
        out_size = x.shape[-2:]
        x = self.encoder(x)      # Downsampled features from EfficientNet
        x = self.decoder(x)      # Upsampled back toward input resolution
        x = self.classifier(x)   # 1x1 conv -> class logits
        x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=False)
        return x
class CamVidDataset(Dataset):
    """CamVid dataset loader with RGB-mask-to-class-index conversion.

    Expects directory structure:
        camvid/
            train/      train_labels/
            val/        val_labels/
            test/       test_labels/

    Label images are RGB masks; each color maps to the class index given by
    its row position in class_dict.csv. Pixels whose color matches no table
    entry silently fall back to class 0.

    Args:
        root: dataset root containing `<split>/` and `<split>_labels/`.
        split: one of 'train', 'val', 'test'.
        transform: optional callable applied to the input image (e.g. ToTensor).
        image_size: (height, width) both images and labels are resized to.
        target_transform: accepted for API compatibility but currently unused —
            labels always go through rgb_to_class (NOTE(review): confirm intent).
        class_dict_path: CSV with integer columns 'r', 'g', 'b', one row per class.

    Raises:
        ValueError: if the number of images and labels differ.
    """

    def __init__(self, root, split='train', transform=None, image_size=(360, 480), target_transform=None, class_dict_path='camvid/CamVid/class_dict.csv'):
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform  # NOTE(review): stored but never applied
        self.image_dir = os.path.join(root, split)
        self.label_dir = os.path.join(root, f"{split}_labels")
        # Sorted so image/label pairs line up by filename order.
        self.image_paths = sorted(glob.glob(os.path.join(self.image_dir, '*.png')))
        self.label_paths = sorted(glob.glob(os.path.join(self.label_dir, '*.png')))
        # Labels must use NEAREST so class colors are never blended.
        self.label_resize = transforms.Resize(image_size, interpolation=Image.NEAREST)
        self.image_resize = transforms.Resize(image_size, interpolation=Image.BILINEAR)
        # Raise instead of assert: asserts are stripped under `python -O`,
        # and a mismatch here would corrupt every (image, label) pair.
        if len(self.image_paths) != len(self.label_paths):
            raise ValueError(
                f"Mismatch between images ({len(self.image_paths)}) and "
                f"labels ({len(self.label_paths)}) for split '{split}'."
            )
        # Load class_dict.csv and build the color -> class-index mapping.
        # Cast to plain Python ints so keys are ordinary hashable tuples
        # rather than numpy scalars from pandas.
        df = pd.read_csv(class_dict_path)
        self.color_to_class = {
            (int(row['r']), int(row['g']), int(row['b'])): idx
            for idx, row in df.iterrows()
        }

    def __len__(self):
        """Number of (image, label) pairs in this split."""
        return len(self.image_paths)

    def rgb_to_class(self, mask):
        """Convert an RGB mask (PIL.Image or HxWx3 array) to a 2D class-index mask."""
        mask_np = np.array(mask)
        h, w, _ = mask_np.shape
        # Unmatched pixels stay 0 (class 0) — see class docstring.
        class_mask = np.zeros((h, w), dtype=np.uint8)
        for rgb, class_idx in self.color_to_class.items():
            matches = (mask_np == rgb).all(axis=2)
            class_mask[matches] = class_idx
        return class_mask

    def __getitem__(self, idx):
        """Return (image, label): transformed image and LongTensor class mask."""
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = Image.open(self.label_paths[idx]).convert('RGB')
        # Resize both to the configured size (default 360x480).
        image = self.image_resize(image)
        label = self.label_resize(label)
        if self.transform:
            image = self.transform(image)
        label = self.rgb_to_class(label)
        label = torch.from_numpy(label).long()
        return image, label
if __name__ == "__main__":
    # Download CamVid from Kaggle (requires Kaggle API credentials).
    dataset_url = "https://www.kaggle.com/datasets/carlolepelaars/camvid"
    opd.download(dataset_url)
    # Set dataset folder (adjust path if needed)
    dataset_folder = "camvid"
    print("Dataset directory contents:")
    print(os.listdir(dataset_folder))
    # ImageNet normalization to match the pretrained EfficientNet encoder.
    input_transform = transforms.Compose([
        transforms.Resize((360, 480)),  # Or larger if needed
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    def label_transform(label):
        """Resize a label with nearest neighbor so classes are not interpolated.

        NOTE(review): CamVidDataset never applies target_transform, so this is
        currently dead code; kept for API compatibility — confirm intent.
        """
        label = label.resize((480, 360), Image.NEAREST)
        label = np.array(label, dtype=np.int64)
        return torch.from_numpy(label)
    num_classes = 32
    data_root = 'camvid/CamVid/'  # make sure this matches your structure
    # Build datasets for each split.
    train_dataset = CamVidDataset(root=data_root, split='train',
                                  transform=input_transform, target_transform=label_transform)
    val_dataset = CamVidDataset(root=data_root, split='val',
                                transform=input_transform, target_transform=label_transform)
    test_dataset = CamVidDataset(root=data_root, split='test',
                                 transform=input_transform, target_transform=label_transform)
    # NOTE(review): batch_size=4 here ignores BATCH_SIZE=8 defined at the top
    # of the file — confirm which value is intended.
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)
    # Fixed: the test loader was shuffle=True, which makes evaluation order
    # non-deterministic for no benefit; evaluation should be deterministic.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)