Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms | |
| from torchvision.datasets import VOCSegmentation | |
| from torch.utils.data import DataLoader | |
| from PIL import Image | |
| import numpy as np | |
| import wandb | |
| import os | |
| import matplotlib.pyplot as plt | |
# --- Reproducibility: fix the NumPy and PyTorch global RNG seeds. ---------
np.random.seed(42)
torch.manual_seed(42)

# --- Training hyperparameters ---------------------------------------------
EPOCHS = 25
BATCH_SIZE = 8
LR = 0.001                # learning rate
NUM_CLASSES = 21          # Pascal VOC: 21 classes including background
IMAGE_SIZE = (256, 256)   # (height, width) that images and masks are resized to

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Optional Weights & Biases tracking (disabled). To enable, uncomment below.
# Prefer supplying the key via the WANDB_API_KEY environment variable over
# hard-coding it here, so it never ends up in source control.
# wandb.login(key="your_wandb_api_key_here")
# wandb.init(project="segnet-efficientnet-voc", config={
#     "epochs": EPOCHS,
#     "batch_size": BATCH_SIZE,
#     "learning_rate": LR,
#     "architecture": "SegNet-EfficientNet",
#     "dataset": "PascalVOC2012"
# })
class SegNetEfficientNet(nn.Module):
    """Encoder-decoder segmentation network on a torchvision EfficientNet-B0 backbone.

    Architecture:
    - The encoder is the backbone's ``features`` stack.
    - The decoder has four stride-2 transposed convolutions (each doubling
      height and width, so 16x upsampling overall) with ReLUs in between.
    - A final 1x1 transposed convolution produces ``num_classes`` logit maps.
    - The logits are resized bilinearly to the input's spatial size.

    Args:
        num_classes: Number of output channels (one logit map per class).
        pretrained: If True (default, same as before), load the backbone's
            pretrained weights. Pass False for a randomly initialised backbone,
            e.g. when working offline or in tests.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        base_model = self._build_backbone(pretrained)
        # Encoder: the EfficientNet feature blocks, re-wrapped in a Sequential
        # so state_dict keys stay "encoder.0" ... "encoder.N".
        self.encoder = nn.Sequential(*base_model.features.children())
        # Decoder: up-convolutions. The first layer expects 1280 input
        # channels, i.e. the width of the backbone's last feature block.
        # Layer shapes are unchanged so existing checkpoints still load.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=1),
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Return a torchvision EfficientNet-B0, preferring the ``weights=`` API.

        torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        ``weights=`` (its ``DEFAULT`` weights are what ``pretrained=True``
        loaded). Older releases only understand ``pretrained=``, so fall back.
        """
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is None:  # torchvision < 0.13
            return models.efficientnet_b0(pretrained=pretrained)
        return models.efficientnet_b0(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``.

        Args:
            x: Image batch, presumably (N, 3, H, W) normalised RGB; TODO confirm.

        Returns:
            Logits of shape (N, num_classes, H, W).
        """
        # Remember the input resolution: the decoder's fixed 16x upsampling
        # need not undo the backbone's downsampling, so the final resize
        # targets the input size. The old code used the global IMAGE_SIZE,
        # which gave mismatched logits for any other input size.
        out_size = x.shape[-2:]
        x = self.encoder(x)
        x = self.decoder(x)
        return F.interpolate(x, size=out_size, mode="bilinear", align_corners=False)
class VOCSegmentationDataset(VOCSegmentation):
    """Pascal VOC segmentation split that yields ``(image, mask)`` pairs.

    Transforms are applied in ``__getitem__`` rather than being handed to the
    torchvision base class. ``image`` is whatever ``transform`` returns (the
    raw image when it is None). After the optional ``target_transform``, the
    mask is always converted with ``np.array`` to a ``torch.long`` tensor of
    class indices.

    NOTE(review): VOC masks presumably contain 255 for "void"/boundary
    pixels; the loss likely needs ``ignore_index=255``. Confirm in the
    training code.

    Args:
        root: Directory holding (or receiving) the VOC data.
        image_set: Split name forwarded to ``VOCSegmentation`` (e.g. 'train', 'val').
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the mask before tensor
            conversion. It should keep integer labels (e.g. a nearest-neighbour
            resize of the PIL mask).
        year: VOC release year forwarded to ``VOCSegmentation`` (default '2012').
        download: Forwarded to ``VOCSegmentation`` (default True, as before).
            NOTE(review): consider passing False once the data is on disk, to
            skip torchvision's download/verify step on each construction.
    """

    def __init__(self, root, image_set='train', transform=None, target_transform=None,
                 year='2012', download=True):
        # Transforms are intentionally not forwarded to the base class; they are
        # applied explicitly below so the tensor conversion always runs last.
        super().__init__(root=root, year=year, image_set=image_set, download=download)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, mask)`` where ``mask`` is a ``torch.long`` tensor."""
        img, target = super().__getitem__(index)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # A palettised mask converts to an array of class indices; long dtype
        # is what classification losses expect for targets.
        target = torch.as_tensor(np.array(target), dtype=torch.long)
        return img, target
if __name__ == "__main__":
    # Image pipeline: resize to IMAGE_SIZE, convert to a float tensor, then
    # normalise per channel (these appear to be the standard ImageNet
    # mean/std values).
    image_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # Masks hold class indices, so they must be resized with nearest-neighbour
    # sampling; an interpolating mode would invent labels at region borders.
    # torchvision's Resize documents an InterpolationMode here, not a PIL
    # integer constant.
    mask_transform = transforms.Resize(
        IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST
    )

    train_dataset = VOCSegmentationDataset("voc_data", 'train', image_transform, mask_transform)
    val_dataset = VOCSegmentationDataset("voc_data", 'val', image_transform, mask_transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)