# new_car/scripts/prediction_helper.py
# junaid17 — Initial commit: DamageLens project (c5377b5)
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
from transformers import ConvNextModel, ConvNextImageProcessor
# ================================ ResNet-18 Classifier ================================
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 image classifier with a dropout/MLP classification head.

    The backbone is frozen except for ``layer3`` and ``layer4``, which stay
    trainable. The original ``fc`` layer is replaced by a two-layer head
    producing ``num_classes`` logits.

    Args:
        num_classes (int): Number of output classes.
        weights: Backbone weights forwarded to ``torchvision.models.resnet18``.
            Defaults to ``"DEFAULT"`` (the previous, hard-coded behavior). Pass
            ``None`` to skip downloading pretrained weights when every
            parameter will be overwritten from a checkpoint anyway.
    """

    def __init__(self, num_classes, weights="DEFAULT"):
        super().__init__()
        self.model = models.resnet18(weights=weights)

        # Freeze the whole backbone, then unfreeze the two deepest residual stages.
        self.model.requires_grad_(False)
        self.model.layer3.requires_grad_(True)
        self.model.layer4.requires_grad_(True)

        # Replace FC head; freshly created layers are trainable by default.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the image batch ``x``."""
        return self.model(x)
class ResnetCarDamagePredictor:
    """Inference wrapper around a fine-tuned ``Car_Classifier_Resnet``.

    Args:
        checkpoint_path (str): Path to a file loadable by ``torch.load``. It may
            hold either a dict with a ``"model_state_dict"`` entry or a bare
            state dict.
        class_map (dict): Mapping from class index (int) to class name (str).
            Its length sets the number of output classes.
    Raises:
        RuntimeError: If the model cannot be built or the checkpoint loaded.
    """

    def __init__(self, checkpoint_path, class_map):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        # Resize to 128x128, then normalize with ImageNet mean/std.
        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            # NOTE(review): torch.load may unpickle arbitrary objects (depending on
            # the installed torch's weights_only default); load trusted files only.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(self._extract_state_dict(checkpoint))
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ResNet model: {str(e)}") from e

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the model weights from a loaded checkpoint.

        A dict with a ``"model_state_dict"`` key yields that entry. Anything else
        is assumed to already be a state dict and is returned unchanged.
        """
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            return checkpoint["model_state_dict"]
        return checkpoint

    @staticmethod
    def _load_rgb(image_input):
        """Return ``image_input`` as an RGB PIL image.

        Accepts a filesystem path (``str`` or ``os.PathLike``) or a PIL image.
        Paths are opened in a ``with`` block so the file handle is released
        deterministically.
        """
        if isinstance(image_input, (str, os.PathLike)):
            with Image.open(os.fspath(image_input)) as img:
                # convert() loads the pixel data, so the result outlives the file.
                return img.convert("RGB")
        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")
        raise TypeError("image_input must be a file path or PIL.Image")

    def resnet_predict(self, image_input):
        """Classify a single image.

        Args:
            image_input (str | os.PathLike | PIL.Image.Image): Image path or PIL image.
        Returns:
            dict: ``{class_name: probability}`` ordered from most to least likely.
        Raises:
            ValueError: If the file is not a recognizable image.
            RuntimeError: For any other failure, including an unsupported
                ``image_input`` type (its TypeError is wrapped).
        """
        try:
            image = self._load_rgb(image_input)
            batch = self.test_transforms(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(batch)
                # One device->host transfer instead of one .item() per class.
                probs = torch.nn.functional.softmax(outputs, dim=1)[0].tolist()
            class_probs = {
                self.class_map[i]: float(probs[i])
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda kv: kv[1], reverse=True))
        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {str(e)}") from e
# ================================ Fusion Classifier (EfficientNet-V2-S + ConvNeXt) ================================
class FusionClassifier(nn.Module):
    """Classifier that fuses EfficientNet-V2-S and ConvNeXt image features.

    Each backbone produces a pooled feature vector. The two vectors are
    concatenated and fed to an MLP head that outputs ``num_classes`` logits.
    Both backbones are frozen except for their last stages.

    Args:
        num_classes (int): Number of output classes.
        convnext_model_name (str): Hugging Face model id for the ConvNeXt
            backbone. Its pooled feature width is read from the model config,
            so variants other than ``convnext-small`` get a correctly sized
            fusion head.
    """

    def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
        super().__init__()

        # ---- EfficientNet-V2-S backbone (ImageNet weights) ----
        eff = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        eff.requires_grad_(False)
        for stage_idx in (5, 6, 7):  # unfreeze the last three feature stages
            eff.features[stage_idx].requires_grad_(True)
        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        self.eff_out_dim = eff.classifier[1].in_features  # 1280 for V2-S

        # ---- ConvNeXt backbone ----
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        cnx.requires_grad_(False)
        for stage_idx in (2, 3):  # unfreeze the last two encoder stages
            cnx.encoder.stages[stage_idx].requires_grad_(True)
        cnx.layernorm.requires_grad_(True)
        self.cnx_backbone = cnx
        # pooler_output width equals the last stage's hidden size (768 for
        # convnext-small); previously hard-coded, which broke other variants.
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        # ---- Fusion head ----
        fused_dim = self.eff_out_dim + self.cnx_out_dim  # 2048 with the defaults
        self.fusion_head = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values_eff, pixel_values_cnx):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            pixel_values_eff: Image batch preprocessed for the EfficientNet branch.
            pixel_values_cnx: The same images preprocessed for the ConvNeXt branch
                (passed as ``pixel_values`` to the Hugging Face model).
        """
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)

        cnx_out = self.cnx_backbone(pixel_values=pixel_values_cnx, return_dict=True)
        x_cnx = cnx_out.pooler_output

        fused = torch.cat([x_eff, x_cnx], dim=1)
        return self.fusion_head(fused)
# ================================ Fusion Predictor ================================
class FusionCarDamagePredictor:
    """Inference wrapper around a trained ``FusionClassifier``.

    Each image is preprocessed twice. Torchvision transforms prepare it for the
    EfficientNet branch, and the Hugging Face ConvNeXt image processor prepares
    it for the ConvNeXt branch.
    """

    def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
        """
        Args:
            checkpoint_path (str): Path to the .pt/.pth file. It may hold either a
                dict with a ``"model_state_dict"`` entry or a bare state dict.
            class_map (dict): Mapping from class index (int) to class name (str).
            convnext_model_name (str): Pretrained ConvNeXt model name from HuggingFace.
        Raises:
            RuntimeError: If the model cannot be built or the checkpoint loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        # ---------- Preprocessing pipelines ----------
        # EfficientNet: resize to 260x260, then ImageNet mean/std normalization.
        self.eff_normalize = transforms.Compose([
            transforms.Resize((260, 260)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        # ConvNeXt: the processor applies the resize/crop and mean/std defined by
        # the pretrained model's preprocessing config.
        self.convnext_processor = ConvNextImageProcessor.from_pretrained(convnext_model_name)

        # ---------- Load model ----------
        try:
            self.model = FusionClassifier(
                num_classes=len(class_map),
                convnext_model_name=convnext_model_name,
            )
            # NOTE(review): torch.load may unpickle arbitrary objects (depending on
            # the installed torch's weights_only default); load trusted files only.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(self._extract_state_dict(checkpoint))
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Fusion model: {str(e)}") from e

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the model weights from a loaded checkpoint.

        A dict with a ``"model_state_dict"`` key yields that entry. Anything else
        is assumed to already be a state dict and is returned unchanged.
        """
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            return checkpoint["model_state_dict"]
        return checkpoint

    @staticmethod
    def _load_rgb(image_input):
        """Return ``image_input`` as an RGB PIL image.

        Accepts a filesystem path (``str`` or ``os.PathLike``) or a PIL image.
        Paths are opened in a ``with`` block so the file handle is released
        deterministically.
        """
        if isinstance(image_input, (str, os.PathLike)):
            with Image.open(os.fspath(image_input)) as img:
                # convert() loads the pixel data, so the result outlives the file.
                return img.convert("RGB")
        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")
        raise TypeError("image_input must be a file path or PIL.Image")

    def predict(self, image_input):
        """
        Args:
            image_input (str | os.PathLike | PIL.Image.Image): Path to image file or PIL Image.
        Returns:
            dict: Sorted dictionary {class_name: probability} from highest to lowest.
        Raises:
            ValueError: If the file is not a recognizable image.
            RuntimeError: For any other failure, including an unsupported
                ``image_input`` type (its TypeError is wrapped).
        """
        try:
            image = self._load_rgb(image_input)

            # ---- EfficientNet branch ----
            pixel_eff = self.eff_normalize(image).unsqueeze(0).to(self.device)  # (1, 3, 260, 260)

            # ---- ConvNeXt branch ----
            inputs_cnx = self.convnext_processor(images=image, return_tensors="pt")
            # presumably (1, 3, 224, 224) with the default convnext-224 processor
            pixel_cnx = inputs_cnx["pixel_values"].to(self.device)

            # ---- Forward pass ----
            with torch.no_grad():
                logits = self.model(pixel_eff, pixel_cnx)  # (1, num_classes)
                # One device->host transfer instead of one .item() per class.
                probs = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

            class_probs = {
                self.class_map[i]: float(probs[i])
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda kv: kv[1], reverse=True))
        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {str(e)}") from e