# Offroad_segmentation / train_segmentation_optimized.py
# Uploaded with huggingface_hub (revision 57440be).
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import cv2
import os
import torchvision
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Enable cudnn benchmark for fixed input sizes: cuDNN auto-tunes its
# convolution algorithms, which pays off when tensor shapes do not change
# between iterations.
torch.backends.cudnn.benchmark = True
# Use the non-interactive Agg backend so figures are rendered straight to
# files without needing a display.
plt.switch_backend('Agg')
# ========== Utility Functions ==========
def save_image(img, filename):
    """Undo per-channel normalisation on an image and write it to disk.

    Args:
        img: Array-like image with the three colour channels on axis 0
            (the channel axis is moved to the end before saving). It is
            expected to have been normalised with the mean/std constants
            below.
        filename: Destination path handed to ``cv2.imwrite``; the extension
            selects the file format.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    channels_last = np.moveaxis(np.array(img), 0, -1)
    # Invert ``(x - mean) / std`` and scale to the 0-255 range.
    restored = (channels_last * std + mean) * 255

    # cv2.imwrite is documented for 8-bit images in common formats, so do the
    # rounding/saturation explicitly instead of relying on an implicit
    # float -> uint8 conversion inside OpenCV.
    as_uint8 = np.clip(np.rint(restored), 0, 255).astype(np.uint8)

    # Reverse the channel order (RGB -> BGR, which OpenCV expects) and
    # materialise the negatively-strided view as a contiguous buffer.
    cv2.imwrite(filename, np.ascontiguousarray(as_uint8[:, :, ::-1]))
# ========== Mask Conversion ==========
# Raw label value found in the mask files -> contiguous class index.
value_map = {
    0: 0,
    100: 1,
    200: 2,
    300: 3,
    500: 4,
    550: 5,
    700: 6,
    800: 7,
    7100: 8,
    10000: 9,
}
# Number of distinct classes the model has to predict.
n_classes = len(value_map)
def convert_mask(mask):
    """Remap a raw label mask to contiguous class indices.

    Every pixel whose raw value is a key of ``value_map`` receives the
    corresponding class index; any other raw value is left at 0.

    Args:
        mask: Image (or array-like) holding the raw label values.

    Returns:
        A PIL image backed by a ``uint8`` array of class indices.
    """
    raw = np.array(mask)
    remapped = np.zeros_like(raw, dtype=np.uint8)
    for raw_id, class_id in value_map.items():
        hits = raw == raw_id
        remapped[hits] = class_id
    return Image.fromarray(remapped)
# ========== Dataset ==========
class MaskDataset(Dataset):
    """Image / segmentation-mask pairs read from a dataset directory.

    ``data_dir`` must contain a ``Color_Images`` folder and a ``Segmentation``
    folder; an image and its mask share the same file name.
    """

    def __init__(self, data_dir, transform=None):
        """Index the samples found under ``data_dir``.

        Args:
            data_dir: Directory holding ``Color_Images`` and ``Segmentation``.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries.
        """
        self.image_dir = os.path.join(data_dir, 'Color_Images')
        self.masks_dir = os.path.join(data_dir, 'Segmentation')
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same sample across runs and machines.
        self.data_ids = sorted(os.listdir(self.image_dir))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_ids)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        With a transform, both values are whatever the transform produces.
        Without one, the image is a float tensor in channel-first layout
        scaled to [0, 1] and the mask is a ``long`` tensor of class indices.
        """
        data_id = self.data_ids[idx]
        img_path = os.path.join(self.image_dir, data_id)
        mask_path = os.path.join(self.masks_dir, data_id)

        # Context managers close the underlying files deterministically
        # instead of leaving that to garbage collection.
        with Image.open(img_path) as img_pil:
            image = np.array(img_pil.convert("RGB"))
        with Image.open(mask_path) as mask_pil:
            mask = np.array(convert_mask(mask_pil))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask).long()
        return image, mask
# ========== Model ==========
class SegmentationHeadConvNeXt(nn.Module):
    """Convolutional segmentation head applied to a grid of patch tokens.

    Layers: a 7x7 convolution stem to 128 channels, a depthwise 7x7
    convolution followed by a 1x1 convolution (GELU after every
    convolution), and a final 1x1 convolution to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, tokenW, tokenH):
        """
        Args:
            in_channels: Channel size of each incoming token.
            out_channels: Number of output channels (one per class).
            tokenW: Width of the token grid.
            tokenH: Height of the token grid.
        """
        super().__init__()
        hidden = 128
        self.H = tokenH
        self.W = tokenW

        stem_conv = nn.Conv2d(in_channels, hidden, kernel_size=7, padding=3)
        self.stem = nn.Sequential(stem_conv, nn.GELU())

        # groups == channels makes the 7x7 convolution depthwise.
        depthwise = nn.Conv2d(hidden, hidden, kernel_size=7, padding=3, groups=hidden)
        pointwise = nn.Conv2d(hidden, hidden, kernel_size=1)
        self.block = nn.Sequential(depthwise, nn.GELU(), pointwise, nn.GELU())

        self.classifier = nn.Conv2d(hidden, out_channels, 1)

    def forward(self, x):
        """Map tokens of shape (B, N, C) to logits of shape (B, out, H, W).

        ``N`` must equal ``H * W`` for the reshape to succeed.
        """
        batch, _, channels = x.shape
        grid = x.reshape(batch, self.H, self.W, channels)
        feats = grid.permute(0, 3, 1, 2)  # channels-last -> channels-first
        for stage in (self.stem, self.block, self.classifier):
            feats = stage(feats)
        return feats
# ========== Metrics ==========
def compute_iou(pred, target, num_classes=10, ignore_index=255):
    """Return the mean intersection-over-union over classes.

    Args:
        pred: Class scores with the class axis at dim 1; the arg-max over
            that axis is taken as the predicted label.
        target: Integer labels with as many elements as the arg-maxed
            prediction.
        num_classes: Class ids ``0 .. num_classes - 1`` are scored.
        ignore_index: Label value marking pixels to exclude. Those pixels
            are removed from both prediction and target before scoring.

    Returns:
        The mean IoU over the classes that occur in the prediction or the
        target (classes absent from both are skipped). NaN when no class
        can be scored.
    """
    pred = torch.argmax(pred, dim=1).reshape(-1)
    target = target.reshape(-1)

    # Actually honour ignore_index: without this mask, predictions made on
    # ignored pixels would still be counted in every class union.
    keep = target != ignore_index
    pred, target = pred[keep], target[keep]

    iou_per_class = []
    for class_id in range(num_classes):
        if class_id == ignore_index:
            continue
        pred_inds = pred == class_id
        target_inds = target == class_id
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            # Class absent from both prediction and target: no score.
            iou_per_class.append(float('nan'))
            continue
        intersection = (pred_inds & target_inds).sum().item()
        iou_per_class.append(intersection / union)

    scores = np.asarray(iou_per_class, dtype=np.float64)
    if np.isnan(scores).all():
        # np.nanmean would emit a RuntimeWarning for an all-NaN input.
        return float('nan')
    return np.nanmean(scores)
def compute_dice(pred, target, num_classes=10, smooth=1e-6):
    """Return the mean smoothed Dice score over all classes.

    Args:
        pred: Class scores with the class axis at dim 1; the arg-max over
            that axis is taken as the predicted label.
        target: Integer labels with as many elements as the arg-maxed
            prediction.
        num_classes: Class ids ``0 .. num_classes - 1`` are scored.
        smooth: Constant added to numerator and denominator. As a
            consequence a class absent from both prediction and target
            scores 1.0.

    Returns:
        The arithmetic mean of the per-class Dice scores.
    """
    labels_hat = torch.argmax(pred, dim=1)
    labels_hat, labels_ref = labels_hat.view(-1), target.view(-1)

    def _class_dice(class_id):
        hat = labels_hat == class_id
        ref = labels_ref == class_id
        overlap = (hat & ref).sum().float()
        total = hat.sum().float() + ref.sum().float()
        return ((2. * overlap + smooth) / (total + smooth)).cpu().numpy()

    return np.mean([_class_dice(class_id) for class_id in range(num_classes)])
def compute_pixel_accuracy(pred, target):
    """Return the fraction of positions whose arg-max class equals ``target``.

    ``pred`` holds class scores with the class axis at dim 1. The result is
    a zero-dimensional NumPy value.
    """
    hits = torch.argmax(pred, dim=1).eq(target)
    accuracy = hits.float().mean()
    return accuracy.cpu().numpy()
def evaluate_metrics(model, backbone, data_loader, device, num_classes=10, show_progress=True):
    """Score ``model`` on every batch of ``data_loader``.

    For each batch the backbone's normalised patch tokens are fed to
    ``model``, the logits are bilinearly upsampled to the input resolution,
    and IoU, Dice and pixel accuracy are computed.

    Args:
        model: Segmentation head mapping patch tokens to logits.
        backbone: Feature extractor exposing ``forward_features`` whose
            result contains ``"x_norm_patchtokens"``.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        num_classes: Number of classes passed on to the metric functions.
        show_progress: Wrap the loader in a tqdm progress bar when true.

    Returns:
        ``(mean_iou, mean_dice, mean_pixel_accuracy)``, each the mean of the
        per-batch values (not a dataset-level aggregate). An empty loader
        yields NaN for all three.
    """
    iou_scores, dice_scores, pixel_accuracies = [], [], []

    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    model.eval()
    loader = tqdm(data_loader, desc="Evaluating", leave=False, unit="batch") if show_progress else data_loader
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                tokens = backbone.forward_features(imgs)["x_norm_patchtokens"]
                logits = model(tokens)
                outputs = F.interpolate(logits, size=imgs.shape[2:], mode="bilinear", align_corners=False)
                labels = labels.squeeze(dim=1).long()
                iou_scores.append(compute_iou(outputs, labels, num_classes=num_classes))
                dice_scores.append(compute_dice(outputs, labels, num_classes=num_classes))
                pixel_accuracies.append(compute_pixel_accuracy(outputs, labels))
    finally:
        # Restore the original mode even if a batch raised.
        model.train(was_training)
    return np.mean(iou_scores), np.mean(dice_scores), np.mean(pixel_accuracies)
# ========== Plotting Functions ==========
def save_training_plots(history, output_dir):
    """Render the training curves in ``history`` to PNG files.

    Writes ``training_curves.png`` (loss and pixel accuracy, train vs. val)
    and ``iou_curves.png`` (train IoU and val IoU side by side) into
    ``output_dir``, creating the directory if needed.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Figure 1: one panel per metric, train and val curves overlaid.
    paired_panels = (
        ('Loss', 'train_loss', 'val_loss'),
        ('Pixel Acc', 'train_pixel_acc', 'val_pixel_acc'),
    )
    plt.figure(figsize=(12, 5))
    for slot, (title, train_key, val_key) in enumerate(paired_panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(history[train_key], label='train')
        plt.plot(history[val_key], label='val')
        plt.title(title)
        plt.legend()
        plt.grid()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()

    # Figure 2: a single IoU curve per panel.
    single_panels = (
        ('train_iou', 'Train IoU'),
        ('val_iou', 'Val IoU'),
    )
    plt.figure(figsize=(12, 5))
    for slot, (key, caption) in enumerate(single_panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(history[key], label=caption)
        plt.title(caption)
        plt.grid()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'iou_curves.png'))
    plt.close()

    print("Plots saved.")
def save_history_to_file(history, output_dir):
    """Write the final and best validation IoU to ``evaluation_metrics.txt``.

    The file is created inside ``output_dir`` (which is created if missing)
    and the values are taken from ``history['val_iou']``.
    """
    os.makedirs(output_dir, exist_ok=True)
    report_path = os.path.join(output_dir, 'evaluation_metrics.txt')
    with open(report_path, 'w') as report:
        val_iou = history['val_iou']
        final_iou = val_iou[-1]
        best_iou = max(val_iou)
        report.write(f"Final Val IoU: {final_iou:.4f}\nBest Val IoU: {best_iou:.4f}\n")
    print(f"Metrics saved to {report_path}")
# ========== Main ==========
def _build_transforms(height, width):
    """Return the ``(train, val)`` albumentations pipelines for one image size.

    Both pipelines resize, normalise with the same mean/std and convert to
    tensors; only the training pipeline adds random flip, rotation and
    colour jitter.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_transform = A.Compose([
        A.Resize(height=height, width=width),
        A.HorizontalFlip(p=0.5),
        # NOTE(review): constant-border rotation also pads the mask; the
        # padded label value is the albumentations default - confirm it does
        # not collide with a real class.
        A.Rotate(limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(height=height, width=width),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    return train_transform, val_transform


def _upsampled_logits(backbone, head, imgs):
    """Run the backbone (without gradients) and the head, then upsample.

    Returns logits bilinearly resized to the spatial size of ``imgs``.
    """
    with torch.no_grad():
        tokens = backbone.forward_features(imgs)["x_norm_patchtokens"]
    logits = head(tokens)
    return F.interpolate(logits, size=imgs.shape[2:], mode="bilinear", align_corners=False)


def main():
    """Train the segmentation head on top of a frozen DINOv2 backbone.

    Reads the train/val splits from
    ``/content/Offroad_Segmentation_Training_Dataset``, trains for a fixed
    number of epochs with SGD and cross-entropy, records loss / IoU / Dice /
    pixel accuracy per epoch, and writes plots, a metrics summary and the
    head's weights to ``/content/train_stats``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # INCREASE BATCH SIZE (try 8, then 16 if memory allows)
    batch_size = 8
    # INCREASE NUM_WORKERS (try 4, 8)
    num_workers = 4
    patch_size = 14
    # Half of 960x540, rounded down to a multiple of the patch size
    # (476 x 266) so the image splits into whole patches.
    w = int(((960 / 2) // patch_size) * patch_size)
    h = int(((540 / 2) // patch_size) * patch_size)
    lr = 1e-4
    n_epochs = 20

    # Local dataset
    base_path = '/content'
    data_dir = os.path.join(base_path, 'Offroad_Segmentation_Training_Dataset', 'train')
    val_dir = os.path.join(base_path, 'Offroad_Segmentation_Training_Dataset', 'val')
    output_dir = os.path.join(base_path, 'train_stats')
    os.makedirs(output_dir, exist_ok=True)

    train_transform, val_transform = _build_transforms(h, w)

    # persistent_workers keeps worker processes alive between epochs; the
    # DataLoader rejects it when num_workers is 0, hence the guard.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    trainset = MaskDataset(data_dir=data_dir, transform=train_transform)
    train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
    valset = MaskDataset(data_dir=val_dir, transform=val_transform)
    val_loader = DataLoader(valset, shuffle=False, **loader_kwargs)
    print(f"Train samples: {len(trainset)}, Val samples: {len(valset)}")
    print(f"Batch size: {batch_size}, Workers: {num_workers}")

    # Backbone (you can try 'base' if you have enough GPU memory)
    backbone_model = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model="dinov2_vits14")
    backbone_model.eval()
    backbone_model.to(device)

    # Get embedding dimension from a single dummy image. Only the token
    # channel size is needed, so there is no reason to decode and augment a
    # real batch through the DataLoader for it.
    with torch.no_grad():
        probe = torch.zeros(1, 3, h, w, device=device)
        n_embedding = backbone_model.forward_features(probe)["x_norm_patchtokens"].shape[2]

    classifier = SegmentationHeadConvNeXt(
        in_channels=n_embedding,
        out_channels=n_classes,
        tokenW=w // patch_size,
        tokenH=h // patch_size,
    ).to(device)
    loss_fct = torch.nn.CrossEntropyLoss()
    # Only the head's parameters are optimised; the backbone stays frozen.
    optimizer = optim.SGD(classifier.parameters(), lr=lr, momentum=0.9)

    history = {k: [] for k in ['train_loss', 'val_loss', 'train_iou', 'val_iou',
                               'train_dice', 'val_dice', 'train_pixel_acc', 'val_pixel_acc']}

    for epoch in range(n_epochs):
        classifier.train()
        train_losses = []
        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Train"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = _upsampled_logits(backbone_model, classifier, imgs)
            labels = labels.squeeze(dim=1).long()
            loss = loss_fct(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        classifier.eval()
        val_losses = []
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} Val"):
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = _upsampled_logits(backbone_model, classifier, imgs)
                labels = labels.squeeze(dim=1).long()
                loss = loss_fct(outputs, labels)
                val_losses.append(loss.item())

        train_iou, train_dice, train_acc = evaluate_metrics(
            classifier, backbone_model, train_loader, device,
            num_classes=n_classes, show_progress=False)
        val_iou, val_dice, val_acc = evaluate_metrics(
            classifier, backbone_model, val_loader, device,
            num_classes=n_classes, show_progress=False)

        history['train_loss'].append(np.mean(train_losses))
        history['val_loss'].append(np.mean(val_losses))
        history['train_iou'].append(train_iou)
        history['val_iou'].append(val_iou)
        history['train_dice'].append(train_dice)
        history['val_dice'].append(val_dice)
        history['train_pixel_acc'].append(train_acc)
        history['val_pixel_acc'].append(val_acc)
        print(f"Epoch {epoch+1}: Train Loss={history['train_loss'][-1]:.4f}, Val Loss={history['val_loss'][-1]:.4f}, Val IoU={val_iou:.4f}")

    save_training_plots(history, output_dir)
    save_history_to_file(history, output_dir)
    torch.save(classifier.state_dict(), os.path.join(output_dir, 'segmentation_head.pth'))
    print("Training complete. Model saved.")


if __name__ == "__main__":
    main()