|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
| from torch import nn |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import torch.optim as optim |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import cv2 |
| import os |
| import torchvision |
| from tqdm import tqdm |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
| |
# Turn on cuDNN benchmark mode (lets cuDNN auto-select convolution algorithms).
torch.backends.cudnn.benchmark = True


# Switch Matplotlib to the 'Agg' backend; this script only writes figures
# to files via plt.savefig and never shows them interactively.
plt.switch_backend('Agg')
|
|
| |
def save_image(img, filename):
    """Undo the mean/std normalisation of an image and write it to disk.

    Args:
        img: Array-like convertible with ``np.array`` whose first axis is the
            channel axis (it is moved to the last axis before saving).
            Assumed to be RGB normalised with the mean/std below -- confirm
            against the caller.
        filename: Destination path handed to ``cv2.imwrite``.
    """
    img = np.array(img)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = np.moveaxis(img, 0, -1)
    img = (img * std + mean) * 255
    # Convert explicitly to 8-bit instead of handing OpenCV a float64 array:
    # round, then clamp to the valid range so out-of-range values saturate
    # rather than wrap when cast.
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    # Reverse the channel axis (RGB -> BGR, the order OpenCV writes). The
    # reversed slice is a negative-stride view, so materialise it first.
    cv2.imwrite(filename, np.ascontiguousarray(img[:, :, ::-1]))
|
|
| |
# Raw pixel value found in the mask files -> contiguous class index.
value_map = {
    0: 0,
    100: 1,
    200: 2,
    300: 3,
    500: 4,
    550: 5,
    700: 6,
    800: 7,
    7100: 8,
    10000: 9,
}
n_classes = len(value_map)


def convert_mask(mask):
    """Remap raw mask values to class indices.

    Args:
        mask: Image-like object convertible with ``np.array``.

    Returns:
        A PIL image of dtype uint8 in which every pixel equal to a key of
        ``value_map`` holds the corresponding class index; pixels matching
        no key stay 0.
    """
    raw_pixels = np.array(mask)
    class_pixels = np.zeros_like(raw_pixels, dtype=np.uint8)
    for raw_value in value_map:
        selected = np.equal(raw_pixels, raw_value)
        class_pixels[selected] = value_map[raw_value]
    return Image.fromarray(class_pixels)
|
|
| |
class MaskDataset(Dataset):
    """Paired image / segmentation-mask dataset read from one directory.

    ``data_dir`` must contain a ``Color_Images`` and a ``Segmentation``
    sub-directory; an image and its mask share the same file name.

    Args:
        data_dir: Root directory holding the two sub-directories.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries. When omitted, the image is
            returned as a float CHW tensor scaled to [0, 1].
    """

    def __init__(self, data_dir, transform=None):
        self.image_dir = os.path.join(data_dir, 'Color_Images')
        self.masks_dir = os.path.join(data_dir, 'Segmentation')
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so a given
        # index refers to the same sample on every run and every machine.
        self.data_ids = sorted(os.listdir(self.image_dir))

    def __len__(self):
        """Return the number of files found in the image directory."""
        return len(self.data_ids)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The mask is always returned as an integer (long) tensor of class
        indices, whichever branch produced it.
        """
        data_id = self.data_ids[idx]
        img_path = os.path.join(self.image_dir, data_id)
        mask_path = os.path.join(self.masks_dir, data_id)

        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(convert_mask(mask_file))

        # Compare against None: a transform pipeline object may define
        # __len__ and therefore be falsy even though it was supplied.
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        # Give both branches the same mask type. The transform may hand back
        # either a NumPy array or a tensor.
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        return image, mask.long()
|
|
| |
class SegmentationHeadConvNeXt(nn.Module):
    """Convolutional segmentation head operating on a grid of patch tokens.

    Args:
        in_channels: Feature size of each incoming token.
        out_channels: Number of output channels produced per grid cell.
        tokenW: Width of the token grid.
        tokenH: Height of the token grid.
    """

    def __init__(self, in_channels, out_channels, tokenW, tokenH):
        super().__init__()
        self.H = tokenH
        self.W = tokenW

        hidden = 128
        # 7x7 projection from the token feature size to the hidden width.
        stem_layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=7, padding=3),
            nn.GELU(),
        ]
        self.stem = nn.Sequential(*stem_layers)

        # Depthwise 7x7 convolution (one group per channel) followed by a
        # 1x1 pointwise convolution, each with a GELU.
        block_layers = [
            nn.Conv2d(hidden, hidden, kernel_size=7, padding=3, groups=hidden),
            nn.GELU(),
            nn.Conv2d(hidden, hidden, kernel_size=1),
            nn.GELU(),
        ]
        self.block = nn.Sequential(*block_layers)

        self.classifier = nn.Conv2d(hidden, out_channels, 1)

    def forward(self, x):
        """Map tokens of shape (B, N, C) to logits of shape (B, out, H, W)."""
        batch, _, channels = x.shape
        # Lay the token sequence out as an H x W grid, channels first.
        grid = x.reshape(batch, self.H, self.W, channels)
        grid = grid.permute(0, 3, 1, 2)
        features = self.block(self.stem(grid))
        return self.classifier(features)
|
|
| |
def compute_iou(pred, target, num_classes=10, ignore_index=255):
    """Mean intersection-over-union across classes.

    Args:
        pred: Score tensor with the class dimension at index 1; the
            predicted class is the argmax over that dimension.
        target: Integer label tensor with one entry per predicted position.
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.
        ignore_index: Label value marking positions that are excluded from
            the metric entirely. Pass ``None`` to disable.

    Returns:
        The mean IoU over the classes that occur in the prediction or the
        target, or NaN when no evaluated class occurs at all.
    """
    # reshape (not view) so non-contiguous targets are accepted as well.
    pred = torch.argmax(pred, dim=1).reshape(-1)
    target = target.reshape(-1)

    if ignore_index is not None:
        # Drop ignored positions before counting; otherwise whatever class
        # is predicted there would enlarge that class's union.
        keep = target != ignore_index
        pred, target = pred[keep], target[keep]

    iou_per_class = []
    for class_id in range(num_classes):
        if class_id == ignore_index:
            continue
        pred_inds = pred == class_id
        target_inds = target == class_id
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            # Class absent from both prediction and target: no contribution.
            continue
        intersection = (pred_inds & target_inds).sum().item()
        iou_per_class.append(intersection / union)

    if not iou_per_class:
        return np.float64('nan')
    return np.mean(iou_per_class)
|
|
def compute_dice(pred, target, num_classes=10, smooth=1e-6):
    """Mean smoothed Dice coefficient across classes.

    Args:
        pred: Score tensor with the class dimension at index 1; the
            predicted class is the argmax over that dimension.
        target: Integer label tensor with one entry per predicted position.
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.
        smooth: Constant added to numerator and denominator.

    Returns:
        The plain mean of the per-class scores. A class absent from both
        prediction and target scores ``smooth / smooth`` = 1.
    """
    predicted = torch.argmax(pred, dim=1).view(-1)
    actual = target.view(-1)

    def class_score(class_id):
        in_pred = predicted == class_id
        in_target = actual == class_id
        overlap = (in_pred & in_target).sum().float()
        numerator = 2. * overlap + smooth
        denominator = in_pred.sum().float() + in_target.sum().float() + smooth
        return (numerator / denominator).cpu().numpy()

    return np.mean([class_score(class_id) for class_id in range(num_classes)])
|
|
def compute_pixel_accuracy(pred, target):
    """Fraction of positions whose argmax class (over dim 1) equals the label.

    Returns:
        The accuracy as a zero-dimensional NumPy value.
    """
    predicted = pred.argmax(dim=1)
    hits = predicted.eq(target)
    accuracy = hits.float().mean()
    return accuracy.cpu().numpy()
|
|
def evaluate_metrics(model, backbone, data_loader, device, num_classes=10, show_progress=True):
    """Average IoU, Dice and pixel accuracy of ``model`` over a data loader.

    Each batch is passed through ``backbone.forward_features``, the
    ``"x_norm_patchtokens"`` entry is fed to ``model``, and the logits are
    bilinearly resized to the input resolution before scoring.

    Args:
        model: Segmentation head under evaluation.
        backbone: Feature extractor exposing ``forward_features``.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        num_classes: Number of classes passed to the IoU and Dice metrics.
        show_progress: Wrap the loader in a tqdm progress bar when true.

    Returns:
        ``(mean_iou, mean_dice, mean_pixel_accuracy)``, each the mean of the
        per-batch values (NaN batches are skipped). All three are NaN when
        the loader yields no batches.
    """
    iou_scores = []
    dice_scores = []
    pixel_accuracies = []

    # Remember the caller's mode so evaluation does not silently flip a
    # model that was already in eval mode back to train mode.
    was_training = model.training
    model.eval()
    loader = tqdm(data_loader, desc="Evaluating", leave=False, unit="batch") if show_progress else data_loader
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                output = backbone.forward_features(imgs)["x_norm_patchtokens"]
                logits = model(output)
                outputs = F.interpolate(logits, size=imgs.shape[2:], mode="bilinear", align_corners=False)
                labels = labels.squeeze(dim=1).long()
                iou_scores.append(compute_iou(outputs, labels, num_classes=num_classes))
                dice_scores.append(compute_dice(outputs, labels, num_classes=num_classes))
                pixel_accuracies.append(compute_pixel_accuracy(outputs, labels))
    finally:
        model.train(was_training)

    if not iou_scores:
        nan = float('nan')
        return nan, nan, nan
    # nanmean: a batch whose IoU is undefined must not turn the whole
    # epoch average into NaN.
    return np.nanmean(iou_scores), np.nanmean(dice_scores), np.nanmean(pixel_accuracies)
|
|
| |
def save_training_plots(history, output_dir):
    """Render the training curves to PNG files in ``output_dir``.

    Writes ``training_curves.png`` (loss and pixel accuracy, train vs. val)
    and ``iou_curves.png`` (train IoU and val IoU side by side).

    Args:
        history: Mapping from metric name to a sequence of per-epoch values.
        output_dir: Directory for the images; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    def draw_panel(position, title, curves, with_legend):
        # One panel of a 1x2 layout; `curves` is a list of (history key, label).
        plt.subplot(1, 2, position)
        for key, label in curves:
            plt.plot(history[key], label=label)
        plt.title(title)
        if with_legend:
            plt.legend()
        plt.grid()

    def write_figure(filename):
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, filename))
        plt.close()

    plt.figure(figsize=(12, 5))
    draw_panel(1, 'Loss', [('train_loss', 'train'), ('val_loss', 'val')], True)
    draw_panel(2, 'Pixel Acc', [('train_pixel_acc', 'train'), ('val_pixel_acc', 'val')], True)
    write_figure('training_curves.png')

    plt.figure(figsize=(12, 5))
    draw_panel(1, 'Train IoU', [('train_iou', 'Train IoU')], False)
    draw_panel(2, 'Val IoU', [('val_iou', 'Val IoU')], False)
    write_figure('iou_curves.png')

    print("Plots saved.")
|
|
def save_history_to_file(history, output_dir):
    """Write the final and best validation IoU to ``evaluation_metrics.txt``.

    Args:
        history: Mapping with a ``'val_iou'`` sequence of per-epoch values.
        output_dir: Directory for the report; created if missing.

    An empty history is reported as ``n/a`` instead of raising. NaN entries
    are skipped when picking the best value.
    """
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, 'evaluation_metrics.txt')

    val_iou = list(history['val_iou'])
    final = val_iou[-1] if val_iou else None
    # NaN compares false against everything, so max() over a list that
    # contains NaN depends on element order; rank only the real numbers.
    comparable = [value for value in val_iou if not np.isnan(value)]
    best = max(comparable) if comparable else final

    def fmt(value):
        return "n/a" if value is None else f"{value:.4f}"

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(f"Final Val IoU: {fmt(final)}\nBest Val IoU: {fmt(best)}\n")
    print(f"Metrics saved to {filepath}")
|
|
| |
def main():
    """Train the segmentation head on features from a fixed DINOv2 backbone.

    Builds the train/val loaders, loads ``dinov2_vits14`` from torch.hub in
    eval mode, and optimises only the head's parameters with SGD. After
    every epoch it records loss, IoU, Dice and pixel accuracy for both
    splits. At the end it writes the plots, the metrics file and the head's
    weights to the output directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Hyper-parameters. The working resolution is half of 960x540 with each
    # side rounded down to a multiple of 14, the same divisor used for the
    # token grid handed to the head below.
    batch_size = 8
    lr = 1e-4
    n_epochs = 20
    w = int(((960 / 2) // 14) * 14)
    h = int(((540 / 2) // 14) * 14)

    # Dataset locations and the directory receiving plots, metrics, weights.
    base_path = '/content'
    dataset_root = os.path.join(base_path, 'Offroad_Segmentation_Training_Dataset')
    data_dir = os.path.join(dataset_root, 'train')
    val_dir = os.path.join(dataset_root, 'val')
    output_dir = os.path.join(base_path, 'train_stats')
    os.makedirs(output_dir, exist_ok=True)

    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)
    train_transform = A.Compose([
        A.Resize(height=h, width=w),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(height=h, width=w),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])

    loader_options = dict(
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )
    trainset = MaskDataset(data_dir=data_dir, transform=train_transform)
    train_loader = DataLoader(trainset, shuffle=True, **loader_options)
    valset = MaskDataset(data_dir=val_dir, transform=val_transform)
    val_loader = DataLoader(valset, shuffle=False, **loader_options)

    print(f"Train samples: {len(trainset)}, Val samples: {len(valset)}")
    print(f"Batch size: {batch_size}, Workers: 4")

    backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model="dinov2_vits14")
    backbone.eval()
    backbone.to(device)

    # Run one batch through the backbone to read off the token feature size.
    probe_imgs, _ = next(iter(train_loader))
    probe_imgs = probe_imgs.to(device)
    with torch.no_grad():
        probe_tokens = backbone.forward_features(probe_imgs)["x_norm_patchtokens"]
    n_embedding = probe_tokens.shape[2]

    seg_head = SegmentationHeadConvNeXt(
        in_channels=n_embedding,
        out_channels=n_classes,
        tokenW=w // 14,
        tokenH=h // 14,
    ).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(seg_head.parameters(), lr=lr, momentum=0.9)

    metric_names = [
        'train_loss', 'val_loss',
        'train_iou', 'val_iou',
        'train_dice', 'val_dice',
        'train_pixel_acc', 'val_pixel_acc',
    ]
    history = {name: [] for name in metric_names}

    def upsampled_logits(images):
        # Backbone features are computed without gradients; only the head
        # call and the resize take part in autograd.
        with torch.no_grad():
            tokens = backbone.forward_features(images)["x_norm_patchtokens"]
        head_logits = seg_head(tokens)
        return F.interpolate(head_logits, size=images.shape[2:], mode="bilinear", align_corners=False)

    for epoch in range(n_epochs):
        # ---- optimisation pass over the training split ----
        seg_head.train()
        train_losses = []
        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Train"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = upsampled_logits(imgs)
            labels = labels.squeeze(dim=1).long()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_losses.append(loss.item())

        # ---- loss on the validation split ----
        seg_head.eval()
        val_losses = []
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} Val"):
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = upsampled_logits(imgs)
                labels = labels.squeeze(dim=1).long()
                val_losses.append(criterion(outputs, labels).item())

        train_iou, train_dice, train_acc = evaluate_metrics(
            seg_head, backbone, train_loader, device,
            num_classes=n_classes, show_progress=False,
        )
        val_iou, val_dice, val_acc = evaluate_metrics(
            seg_head, backbone, val_loader, device,
            num_classes=n_classes, show_progress=False,
        )

        epoch_stats = {
            'train_loss': np.mean(train_losses),
            'val_loss': np.mean(val_losses),
            'train_iou': train_iou,
            'val_iou': val_iou,
            'train_dice': train_dice,
            'val_dice': val_dice,
            'train_pixel_acc': train_acc,
            'val_pixel_acc': val_acc,
        }
        for name, value in epoch_stats.items():
            history[name].append(value)

        print(f"Epoch {epoch+1}: Train Loss={history['train_loss'][-1]:.4f}, Val Loss={history['val_loss'][-1]:.4f}, Val IoU={val_iou:.4f}")

    save_training_plots(history, output_dir)
    save_history_to_file(history, output_dir)
    torch.save(seg_head.state_dict(), os.path.join(output_dir, 'segmentation_head.pth'))
    print("Training complete. Model saved.")


if __name__ == "__main__":
    main()
|
|