# Uploaded by AJain1234 using huggingface_hub
# (commit 4bb934b, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import wandb
import os
import matplotlib.pyplot as plt
# --- Reproducibility -------------------------------------------------------
# Seed the NumPy and PyTorch global RNGs with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)

# --- Experiment tracking (disabled) ----------------------------------------
# Uncomment and supply a key to log in to Weights & Biases.
# wandb.login(key="your_wandb_api_key_here")

# --- Hyperparameters -------------------------------------------------------
NUM_CLASSES = 21  # Pascal VOC has 21 classes including background
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 8
EPOCHS = 25
LR = 0.001

# --- Compute device --------------------------------------------------------
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# Uncomment to start a Weights & Biases run that records the settings above.
# wandb.init(project="segnet-efficientnet-voc", config={
#     "epochs": EPOCHS,
#     "batch_size": BATCH_SIZE,
#     "learning_rate": LR,
#     "architecture": "SegNet-EfficientNet",
#     "dataset": "PascalVOC2012"
# })
class SegNetEfficientNet(nn.Module):
    """Encoder-decoder segmentation network on an EfficientNet-B0 backbone.

    The encoder is the ``features`` stack of torchvision's ImageNet-pretrained
    EfficientNet-B0. The decoder is four transposed convolutions (kernel 2,
    stride 2, each doubling the spatial size) with ReLU activations, followed
    by a 1x1 projection to ``num_classes`` channels. The resulting logits are
    resized bilinearly to the spatial size of the input batch, so the model
    works for any input resolution rather than a single hard-coded one.

    Args:
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, num_classes):
        super().__init__()
        base_model = self._load_backbone()
        features = list(base_model.features.children())
        # Encoder: Use EfficientNet blocks
        self.encoder = nn.Sequential(*features)
        # Decoder: Up-convolutions. The layer order (and therefore the
        # "decoder.<idx>" parameter names) is kept stable so that previously
        # saved state dicts still load.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=1),
        )

    @staticmethod
    def _load_backbone():
        """Return torchvision's ImageNet-pretrained EfficientNet-B0."""
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only know the boolean flag.
            return models.efficientnet_b0(pretrained=True)
        # Newer releases deprecate ``pretrained``; IMAGENET1K_V1 is the
        # checkpoint that the legacy flag maps to.
        return models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, x):
        """Compute per-pixel class logits for a batch of images.

        Args:
            x: Image batch whose last two dimensions are height and width.

        Returns:
            Logits with ``num_classes`` channels and the same height and
            width as ``x``.
        """
        # Remember the input resolution before the encoder shrinks it, so the
        # logits line up pixel-for-pixel with the input (and its mask).
        target_size = x.shape[-2:]
        x = self.encoder(x)
        x = self.decoder(x)
        return F.interpolate(
            x, size=target_size, mode='bilinear', align_corners=False
        )
class VOCSegmentationDataset(VOCSegmentation):
    """Pascal VOC 2012 segmentation split yielding ``(image, mask)`` pairs.

    The image and mask are fetched from the parent class, passed through the
    optional per-item transforms, and the mask is converted to a
    ``torch.long`` tensor.

    Args:
        root: Dataset root directory forwarded to ``VOCSegmentation``.
        image_set: Split name forwarded to ``VOCSegmentation``.
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the mask before it is
            converted to a tensor.
        download: Forwarded to ``VOCSegmentation``. Defaults to ``True`` to
            keep the previous behaviour; pass ``False`` when the data is
            already on disk or no network is available.
    """

    def __init__(self, root, image_set='train', transform=None,
                 target_transform=None, *, download=True):
        # The transforms are stored here instead of being forwarded to the
        # parent constructor — presumably so the parent's __getitem__ hands
        # back the untransformed pair; confirm against torchvision.
        super().__init__(root=root, year='2012', image_set=image_set,
                         download=download)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the transformed image and its mask as a long tensor."""
        img, target = super().__getitem__(index)
        # Explicit None checks: a transform object that defines __len__ or
        # __bool__ must not be skipped just because it evaluates as falsy.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # NOTE(review): VOC masks presumably carry a void/border label (255)
        # in addition to the class ids — confirm the loss ignores it.
        target = torch.as_tensor(np.array(target), dtype=torch.long)
        return img, target
if __name__ == "__main__":
image_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
mask_transform = transforms.Resize(IMAGE_SIZE, interpolation=Image.NEAREST)
train_dataset = VOCSegmentationDataset("voc_data", 'train', image_transform, mask_transform)
val_dataset = VOCSegmentationDataset("voc_data", 'val', image_transform, mask_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)