import os

import torch
import torch.nn as nn
from PIL import Image, UnidentifiedImageError
from torchvision import transforms, models
from transformers import ConvNextModel, ConvNextImageProcessor


# ================================
# Shared inference helpers
# ================================
def _load_rgb_image(image_input):
    """Return *image_input* as an RGB ``PIL.Image``.

    Args:
        image_input: A filesystem path (``str`` or ``os.PathLike``) or an
            already-opened ``PIL.Image.Image``.

    Returns:
        A new RGB image produced by ``convert("RGB")``.

    Raises:
        TypeError: If *image_input* is neither a path nor a PIL image.
        PIL.UnidentifiedImageError: If the file cannot be decoded as an image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # convert() forces the pixel data to be read, so the file handle can
        # be released as soon as the converted copy exists.
        with Image.open(image_input) as opened:
            return opened.convert("RGB")
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    raise TypeError("image_input must be a file path or PIL.Image")


def _ranked_class_probs(logits, class_map):
    """Turn a ``(1, num_classes)`` logit tensor into a ranked probability dict.

    Args:
        logits: Model output for a single image; softmax is applied over dim 1
            and only the first row is used.
        class_map: Mapping (or sequence) from class index ``0..n-1`` to name.

    Returns:
        ``{class_name: probability}`` ordered from most to least probable.
    """
    # One tolist() call moves every probability to Python floats at once
    # instead of issuing a separate .item() per class.
    probs = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    # class_map is keyed by class index, so it is indexed rather than iterated.
    class_probs = {class_map[i]: probs[i] for i in range(len(class_map))}
    return dict(sorted(class_probs.items(), key=lambda kv: kv[1], reverse=True))


# ================================
# ResNet-18 Classifier
# ================================
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 with a custom classification head.

    All backbone parameters are frozen except ``layer3`` and ``layer4``; the
    replacement ``fc`` head is created afterwards and is therefore trainable.
    """

    def __init__(self, num_classes):
        """Build the network.

        Args:
            num_classes: Number of output logits produced by the head.
        """
        super().__init__()
        self.model = models.resnet18(weights="DEFAULT")

        # Freeze everything, then re-enable the last two residual stages.
        for param in self.model.parameters():
            param.requires_grad = False
        for stage in (self.model.layer3, self.model.layer4):
            for param in stage.parameters():
                param.requires_grad = True

        # Replace FC head. The layer order fixes the state-dict keys
        # (fc.1 / fc.4), so it must stay as-is to load existing checkpoints.
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.model.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the image batch *x*."""
        return self.model(x)


class ResnetCarDamagePredictor:
    """Inference wrapper around :class:`Car_Classifier_Resnet`."""

    def __init__(self, checkpoint_path, class_map):
        """Load the ResNet classifier from a checkpoint.

        Args:
            checkpoint_path: Path to a checkpoint whose ``"model_state_dict"``
                entry holds the classifier weights.
            class_map: Mapping from class index (int) to class name (str).

        Raises:
            RuntimeError: If the model cannot be built or the checkpoint
                cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        # 128x128 input with ImageNet mean/std normalization.
        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ResNet model: {e}") from e

    def resnet_predict(self, image_input):
        """Classify a single image.

        Args:
            image_input: Path to an image file (``str`` or ``os.PathLike``)
                or a ``PIL.Image.Image``.

        Returns:
            dict: ``{class_name: probability}`` sorted from highest to lowest.

        Raises:
            ValueError: If the file is not a decodable image.
            RuntimeError: For any other failure, including an unsupported
                *image_input* type.
        """
        try:
            image = _load_rgb_image(image_input)
            batch = self.test_transforms(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch)

            return _ranked_class_probs(outputs, self.class_map)
        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {e}") from e


# ================================
# Fusion Classifier (EfficientNet-V2-S + ConvNeXt)
# ================================
class FusionClassifier(nn.Module):
    """Two-backbone classifier that concatenates pooled features.

    An EfficientNet-V2-S feature extractor and a HuggingFace ConvNeXt model
    each produce one pooled vector per image; the two vectors are concatenated
    and passed through an MLP head.
    """

    def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
        """Build both backbones and the fusion head.

        Args:
            num_classes: Number of output logits.
            convnext_model_name: HuggingFace identifier of the pretrained
                ConvNeXt backbone.
        """
        super().__init__()

        # EfficientNet-V2-S backbone: frozen except feature blocks 5-7.
        eff = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        for param in eff.parameters():
            param.requires_grad = False
        for block_idx in (5, 6, 7):
            for param in eff.features[block_idx].parameters():
                param.requires_grad = True

        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        self.eff_out_dim = eff.classifier[1].in_features  # 1280

        # ConvNeXt backbone: frozen except the last two stages and final norm.
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        for param in cnx.parameters():
            param.requires_grad = False
        for trainable in (cnx.encoder.stages[2], cnx.encoder.stages[3], cnx.layernorm):
            for param in trainable.parameters():
                param.requires_grad = True

        self.cnx_backbone = cnx
        # Width of the pooled output is the last stage's hidden size (768 for
        # convnext-small). Reading it from the config keeps the head correct
        # when a different convnext_model_name is supplied.
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        # Fusion head (input is 1280 + 768 = 2048 for the default backbones).
        fused_dim = self.eff_out_dim + self.cnx_out_dim
        self.fusion_head = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values_eff, pixel_values_cnx):
        """Return class logits for a batch.

        Args:
            pixel_values_eff: Image batch preprocessed for EfficientNet.
            pixel_values_cnx: Image batch preprocessed for ConvNeXt.

        Returns:
            Logits produced by the fusion head, one row per image.
        """
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)

        cnx_out = self.cnx_backbone(pixel_values=pixel_values_cnx, return_dict=True)
        x_cnx = cnx_out.pooler_output

        fused = torch.cat([x_eff, x_cnx], dim=1)
        return self.fusion_head(fused)


# ================================
# Fusion Predictor
# ================================
class FusionCarDamagePredictor:
    """Inference wrapper around :class:`FusionClassifier`."""

    def __init__(self, checkpoint_path, class_map,
                 convnext_model_name="facebook/convnext-small-224"):
        """Load the fusion model and both preprocessing pipelines.

        Args:
            checkpoint_path (str): Path to the .pt/.pth checkpoint whose
                ``"model_state_dict"`` entry holds the model weights.
            class_map (dict): Mapping from class index (int) to class name (str).
            convnext_model_name (str): Pretrained ConvNeXt model name from
                HuggingFace.

        Raises:
            RuntimeError: If the model cannot be built or the checkpoint
                cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        # ---------- Preprocessing pipelines ----------
        # EfficientNet: resize to 260x260, then ImageNet normalization.
        self.eff_normalize = transforms.Compose([
            transforms.Resize((260, 260)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # ConvNeXt: the processor published with the pretrained model handles
        # its own resizing and normalization.
        self.convnext_processor = ConvNextImageProcessor.from_pretrained(
            convnext_model_name
        )

        # ---------- Load model ----------
        try:
            self.model = FusionClassifier(
                num_classes=len(class_map),
                convnext_model_name=convnext_model_name,
            )
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Fusion model: {e}") from e

    def predict(self, image_input):
        """Classify a single image.

        Args:
            image_input (str, os.PathLike or PIL.Image): Path to an image file
                or a PIL Image.

        Returns:
            dict: Sorted dictionary {class_name: probability} from highest to
            lowest.

        Raises:
            ValueError: If the file is not a decodable image.
            RuntimeError: For any other failure, including an unsupported
                *image_input* type.
        """
        try:
            image = _load_rgb_image(image_input)

            # ---- EfficientNet branch ----
            pixel_eff = self.eff_normalize(image)               # (3, 260, 260)
            pixel_eff = pixel_eff.unsqueeze(0).to(self.device)  # (1, 3, 260, 260)

            # ---- ConvNeXt branch ----
            # Output size is decided by the processor config (224x224 expected
            # for the default model).
            inputs_cnx = self.convnext_processor(images=image, return_tensors="pt")
            pixel_cnx = inputs_cnx["pixel_values"].to(self.device)

            # ---- Forward pass ----
            with torch.no_grad():
                logits = self.model(pixel_eff, pixel_cnx)       # (1, num_classes)

            return _ranked_class_probs(logits, self.class_map)
        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {e}") from e