# NOTE(review): scraped page chrome (build-status text, file size, commit hash,
# line-number gutter) removed from the top of this file — it was not valid Python
# and prevented the module from importing.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import wandb
import os
import matplotlib.pyplot as plt
from .architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE, LR, EPOCHS, train_loader, val_loader, IMAGE_SIZE
from tqdm import tqdm
# Model, optimizer, and loss for the segmentation training run below.
model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# ignore_index=255 skips the VOC "void"/border label when computing the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
# NOTE(review): `device` is never used below — the training loop uses the
# imported DEVICE constant instead; consider removing this line.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def pixel_accuracy(preds, labels, ignore_index=255):
    """Return the fraction of correctly classified pixels.

    Pixels whose label equals *ignore_index* are excluded, matching the
    ``ignore_index=255`` used by the CrossEntropyLoss criterion — previously
    void/border pixels were counted as wrong, biasing the metric downward.

    Args:
        preds: logits of shape (N, C, H, W).
        labels: ground-truth class indices of shape (N, H, W).
        ignore_index: label value to exclude from the computation.

    Returns:
        0-dim float tensor in [0, 1]; 0.0 if every pixel is ignored.
    """
    _, preds = torch.max(preds, 1)
    valid = labels != ignore_index
    total = valid.sum()
    # Guard against a batch where every pixel carries the ignore label.
    if total == 0:
        return torch.zeros((), device=preds.device)
    correct = ((preds == labels) & valid).float().sum()
    return correct / total.float()
# def mean_iou(preds, labels, num_classes=NUM_CLASSES):
# _, preds = torch.max(preds, 1)
# ious = []
# for cls in range(num_classes):
# intersection = ((preds == cls) & (labels == cls)).float().sum()
# union = ((preds == cls) | (labels == cls)).float().sum()
# if union > 0:
# ious.append(intersection / union)
# return sum(ious) / len(ious) if ious else 0
for epoch in tqdm(range(EPOCHS)):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    for batch_imgs, batch_masks in train_loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE)
        optimizer.zero_grad()
        logits = model(batch_imgs)
        batch_loss = criterion(logits, batch_masks)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
        running_acc += pixel_accuracy(logits, batch_masks).item()
    train_loss = running_loss / len(train_loader)
    train_acc = running_acc / len(train_loader)

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    running_loss = 0.0
    running_acc = 0.0
    with torch.no_grad():
        for batch_imgs, batch_masks in val_loader:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_masks = batch_masks.to(DEVICE)
            logits = model(batch_imgs)
            running_loss += criterion(logits, batch_masks).item()
            running_acc += pixel_accuracy(logits, batch_masks).item()
    val_loss = running_loss / len(val_loader)
    val_acc = running_acc / len(val_loader)

    # wandb.log({
    #     "epoch": epoch + 1,
    #     "train_loss": train_loss,
    #     "train_accuracy": train_acc,
    #     "val_loss": val_loss,
    #     "val_accuracy": val_acc
    # })
    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

# Persist final weights after the last epoch.
torch.save(model.state_dict(), "segnet_efficientnet_camvid.pth")
# wandb.finish()