rpudathu's picture
Upload folder using huggingface_hub
57440be verified
Raw
History Blame Contribute Delete
2.45 kB
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from tqdm import tqdm
# Import model definition from your training script
# (Make sure the file train_segmentation_optimized.py is still in /content)
from train_segmentation_optimized import SegmentationHeadConvNeXt, value_map, n_classes
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")
# ---------------------------------------------------------------------------
# Configuration: where inputs live, where outputs go, and the network size.
# ---------------------------------------------------------------------------
model_path = '/content/train_stats/segmentation_head.pth'  # trained head weights
test_dir = '/content/Offroad_Segmentation_Training_Dataset/testImages'  # unzipped test images
output_dir = '/content/predictions'  # predicted masks are written here

# Create the output folder up front; an already-existing folder is fine.
os.makedirs(output_dir, exist_ok=True)

# Network input size. Must be identical to the size used during training.
# Both sides are multiples of 14 (19 and 34 patches respectively), presumably
# to match the /14 patch grid of the ViT-B/14 backbone — keep it that way.
h = 266
w = 476
# ---------------------------------------------------------------------------
# Model: trained segmentation head on top of a pretrained DINOv2 backbone.
# ---------------------------------------------------------------------------
# Channel width of the backbone's patch tokens: 768 for the base (ViT-B/14)
# model loaded below. Must match what the head was trained with.
embed_dim = 768

# tokenW / tokenH are the patch-token grid size: the input size divided by 14.
classifier = SegmentationHeadConvNeXt(
    in_channels=embed_dim,
    out_channels=n_classes,
    tokenW=w // 14,
    tokenH=h // 14,
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. A state_dict needs nothing
# more. It is also the default from PyTorch 2.6 onwards, so being explicit
# keeps loading behaviour identical across versions.
state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
classifier.load_state_dict(state_dict)
classifier = classifier.to(device)
classifier.eval()

# DINOv2 backbone from torch.hub (the base variant, as used during training).
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
backbone.eval().to(device)

# Fail early with a clear message if the backbone's embedding width disagrees
# with the head. Otherwise the mismatch only shows up as a shape error deep
# inside the classifier's forward pass. getattr() keeps this check harmless if
# the hub model does not expose an `embed_dim` attribute.
backbone_dim = getattr(backbone, 'embed_dim', embed_dim)
if backbone_dim != embed_dim:
    raise ValueError(
        f"Backbone embed_dim={backbone_dim} does not match the segmentation "
        f"head's in_channels={embed_dim}; update embed_dim to match."
    )
# Preprocessing, identical to the validation pipeline. It:
#   1. resizes to the network input size (h, w),
#   2. normalises with the standard ImageNet channel mean/std,
#   3. converts the HWC numpy array into a CHW torch tensor.
_imagenet_mean = (0.485, 0.456, 0.406)
_imagenet_std = (0.229, 0.224, 0.225)
transform = A.Compose(
    [
        A.Resize(h, w),
        A.Normalize(mean=_imagenet_mean, std=_imagenet_std),
        ToTensorV2(),
    ]
)
# ---------------------------------------------------------------------------
# Inference: predict a label mask for every test image and save it as PNG.
# ---------------------------------------------------------------------------
# Sort for a deterministic processing order, and match the extension
# case-insensitively so files such as "IMG_01.PNG" are not silently skipped.
test_images = sorted(f for f in os.listdir(test_dir) if f.lower().endswith('.png'))
print(f"Found {len(test_images)} test images.")

# value_map maps original label ID -> class index. Build the inverse mapping
# once, as a lookup table indexed by class index, instead of rebuilding a dict
# for every image. Any class the model can predict (0 .. n_classes-1) must have
# an original ID; otherwise fail loudly now rather than writing bad masks.
inv_map = {class_idx: original_id for original_id, class_idx in value_map.items()}
missing = [c for c in range(n_classes) if c not in inv_map]
if missing:
    raise ValueError(f"value_map has no original ID for class indices {missing}")
# uint16 is the dtype the saved masks use.
id_lut = np.array([inv_map[c] for c in range(n_classes)], dtype=np.uint16)

for img_name in tqdm(test_images):
    img_path = os.path.join(test_dir, img_name)
    # Context manager closes the file handle promptly instead of waiting for GC.
    with Image.open(img_path) as img:
        image = np.array(img.convert("RGB"))
    orig_h, orig_w = image.shape[:2]

    input_tensor = transform(image=image)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        features = backbone.forward_features(input_tensor)["x_norm_patchtokens"]
        logits = classifier(features)
        # Upsample the coarse logits straight to the source image's resolution,
        # so the saved mask aligns pixel-for-pixel with the input image.
        # Upsampling only to the (h, w) network size would leave it misaligned.
        logits = F.interpolate(
            logits, size=(orig_h, orig_w), mode='bilinear', align_corners=False
        )
        pred = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Integer-array indexing maps every pixel's class index to its original
    # label ID (0, 100, 200, ...) in one vectorised step.
    original_pred = id_lut[pred]
    Image.fromarray(original_pred).save(os.path.join(output_dir, img_name))

print(f"Predictions saved to {output_dir}")