Spaces:
Build error
Build error
File size: 3,231 Bytes
4bb934b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import wandb
import os
import matplotlib.pyplot as plt
# Seed the torch and NumPy RNGs so repeated runs start from the same state.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Weights & Biases logging is currently disabled. To enable it, provide the API
# key through the WANDB_API_KEY environment variable instead of hard-coding it
# in source, then uncomment the lines below.
# wandb.login()

# Training hyper-parameters.
EPOCHS = 25
BATCH_SIZE = 8
LR = 1e-3
NUM_CLASSES = 21  # Pascal VOC has 21 classes including background
IMAGE_SIZE = (256, 256)  # (height, width) that images and masks are resized to
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# wandb.init(project="segnet-efficientnet-voc", config={
#     "epochs": EPOCHS,
#     "batch_size": BATCH_SIZE,
#     "learning_rate": LR,
#     "architecture": "SegNet-EfficientNet",
#     "dataset": "PascalVOC2012"
# })
class SegNetEfficientNet(nn.Module):
    """Encoder-decoder segmentation network with an EfficientNet-B0 encoder.

    Despite the name, this is not a classic SegNet: the decoder does not reuse
    max-pooling indices. It is a plain stack of stride-2 transposed
    convolutions followed by bilinear upsampling back to the input size.

    Args:
        num_classes: Number of output channels (one logit map per class).

    ``forward`` takes an image batch (presumably (N, 3, H, W), since the
    EfficientNet stem consumes 3 channels) and returns logits of shape
    (N, num_classes, H, W), i.e. at the same spatial size as the input.
    """

    def __init__(self, num_classes):
        super().__init__()
        try:
            # torchvision >= 0.13: the weights enum replaces the deprecated
            # `pretrained` flag. IMAGENET1K_V1 is what pretrained=True loaded.
            weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
        except AttributeError:
            # Older torchvision has no weights enums; use the legacy flag.
            base_model = models.efficientnet_b0(pretrained=True)
        else:
            base_model = models.efficientnet_b0(weights=weights)

        # Encoder: the convolutional feature extractor only; the classifier
        # head (avgpool + linear) of the base model is discarded.
        self.encoder = base_model.features

        # Decoder: each stride-2 transposed conv doubles the spatial size.
        # 1280 is the channel count of EfficientNet-B0's final feature map.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            # 1x1 projection to per-class logits. It is kept as ConvTranspose2d
            # (not Conv2d) so existing checkpoints' weight shapes still load.
            nn.ConvTranspose2d(64, num_classes, kernel_size=1),
        )

    def forward(self, x):
        # Upsample back to the caller's own resolution (not a fixed global
        # size) so logits align pixel-for-pixel with the input image.
        input_size = x.shape[-2:]
        x = self.encoder(x)
        x = self.decoder(x)
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
class VOCSegmentationDataset(VOCSegmentation):
    """Pascal VOC 2012 segmentation split yielding ``(image, mask)`` pairs.

    Args:
        root: Directory that holds (or will receive) the ``VOCdevkit`` tree.
        image_set: Split name forwarded to ``VOCSegmentation`` (e.g. 'train', 'val').
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the PIL mask before it
            is converted to a tensor.
        download: If True (default), download and extract the archive, but only
            when the extracted ``VOCdevkit/VOC2012`` directory is not already
            present under ``root``.

    The mask is returned as a ``torch.long`` tensor of per-pixel values.
    NOTE(review): VOC masks use 255 for "void"/boundary pixels. They are passed
    through unchanged here, so the loss presumably needs ``ignore_index=255``.
    """

    def __init__(self, root, image_set='train', transform=None, target_transform=None,
                 download=True):
        # torchvision's download path re-extracts the tarball on every call,
        # even when it is already unpacked. Skip it if the tree already exists.
        voc_root = os.path.join(root, "VOCdevkit", "VOC2012")
        download = download and not os.path.isdir(voc_root)
        super().__init__(root=root, year='2012', image_set=image_set, download=download)
        # Transforms are deliberately NOT given to the base class. They are
        # applied in __getitem__ below, which also tensorizes the mask.
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # The raw mask pixel values become the class-index tensor.
        target = torch.as_tensor(np.array(target), dtype=torch.long)
        return img, target
if __name__ == "__main__":
    # Images: resize, convert to a float tensor, then normalise with the
    # ImageNet mean/std used by the pretrained encoder.
    image_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # Masks: nearest-neighbour resize so class indices are never blended.
    # Use the InterpolationMode enum; passing PIL integer constants as the
    # `interpolation` argument is deprecated in torchvision.
    mask_transform = transforms.Resize(
        IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST
    )

    train_dataset = VOCSegmentationDataset(
        "voc_data", image_set='train',
        transform=image_transform, target_transform=mask_transform,
    )
    val_dataset = VOCSegmentationDataset(
        "voc_data", image_set='val',
        transform=image_transform, target_transform=mask_transform,
    )

    # Shuffle only the training split; validation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
|