| import os |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms, models |
| from PIL import Image, UnidentifiedImageError |
| from transformers import ConvNextModel, ConvNextImageProcessor |
|
|
| |
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 classifier with a partially unfrozen backbone and a custom head.

    The backbone is built with torchvision's ``"DEFAULT"`` weights. Every
    parameter is frozen first, after which ``layer3`` and ``layer4`` are made
    trainable again. The stock ``fc`` layer is replaced by a
    Dropout -> Linear(256) -> ReLU -> Dropout -> Linear(num_classes) stack.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights="DEFAULT")

        # Start from a fully frozen network ...
        for weight in backbone.parameters():
            weight.requires_grad = False
        # ... then re-enable gradients for the two deepest residual stages.
        for stage in (backbone.layer3, backbone.layer4):
            for weight in stage.parameters():
                weight.requires_grad = True

        # Read the input width of the stock head before it is replaced.
        head_width = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(head_width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

        # The attribute name ``model`` is part of every state-dict key, so it
        # must stay as-is for existing checkpoints to load.
        self.model = backbone

    def forward(self, x):
        """Return the logits the wrapped network produces for input ``x``."""
        logits = self.model(x)
        return logits
|
|
|
|
class ResnetCarDamagePredictor:
    """Inference wrapper that loads a ``Car_Classifier_Resnet`` checkpoint.

    The model is placed on CUDA when available (CPU otherwise) and switched
    to evaluation mode. Images are resized to 128x128, converted to tensors
    and normalized channel-wise before being fed to the network.
    """

    def __init__(self, checkpoint_path, class_map):
        """Build the model and restore its weights.

        Args:
            checkpoint_path: Path to a checkpoint readable by ``torch.load``
                that contains a ``"model_state_dict"`` entry.
            class_map: Indexable mapping from class index (``0..n-1``) to
                class name; its length sets the number of output classes.

        Raises:
            RuntimeError: If the model cannot be constructed or the
                checkpoint cannot be loaded. The original error is chained
                as ``__cause__``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from a trusted source, or
            # consider passing ``weights_only=True`` if the checkpoint
            # contents and the installed PyTorch version allow it.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load ResNet model: {e}") from e

    def resnet_predict(self, image_input):
        """Classify one image and return per-class probabilities.

        Args:
            image_input (str or PIL.Image): Path to an image file or an
                already opened PIL image.

        Returns:
            dict: ``{class_name: probability}`` ordered from the highest to
            the lowest probability.

        Raises:
            ValueError: If PIL cannot identify the image file.
            RuntimeError: For any other failure. This includes an
                unsupported ``image_input`` type: the ``TypeError`` raised
                internally is wrapped, and is available as ``__cause__``.
        """
        try:
            if isinstance(image_input, str):
                # The context manager releases the file handle as soon as
                # the converted copy has been made.
                with Image.open(image_input) as opened:
                    image = opened.convert("RGB")
            elif isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")

            # Add a leading batch dimension of size one.
            batch = self.test_transforms(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]

            # One device-to-host transfer instead of one ``.item()`` per class.
            scores = probs.tolist()
            class_probs = {
                self.class_map[i]: scores[i]
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))

        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {e}") from e

    def predict(self, image_input):
        """Alias of ``resnet_predict`` so both predictors expose ``predict``."""
        return self.resnet_predict(image_input)
|
|
| |
class FusionClassifier(nn.Module):
    """Two-branch classifier fusing EfficientNetV2-S and ConvNeXt features.

    Each backbone is loaded with pretrained weights and frozen, except for
    EfficientNet feature blocks 5-7 and ConvNeXt encoder stages 2-3 plus its
    final layernorm. The pooled feature vectors of both branches are
    concatenated and passed through an MLP head that outputs ``num_classes``
    logits.

    Args:
        num_classes: Number of logits produced by the fusion head.
        convnext_model_name: HuggingFace identifier of the pretrained
            ConvNeXt checkpoint used for the second branch.
    """

    def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
        super().__init__()

        # --- EfficientNetV2-S branch ---
        eff = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        for param in eff.parameters():
            param.requires_grad = False
        # Only the last three feature blocks stay trainable.
        for block_idx in (5, 6, 7):
            for param in eff.features[block_idx].parameters():
                param.requires_grad = True

        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        # Feature width is whatever the stock classifier's linear layer consumed.
        self.eff_out_dim = eff.classifier[1].in_features

        # --- ConvNeXt branch ---
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        for param in cnx.parameters():
            param.requires_grad = False
        for trainable in (cnx.encoder.stages[2], cnx.encoder.stages[3], cnx.layernorm):
            for param in trainable.parameters():
                param.requires_grad = True
        self.cnx_backbone = cnx
        # Take the pooled-feature width from the model config rather than
        # hard-coding it, so checkpoints other than the default keep the
        # fusion head consistent (the default checkpoint still yields 768).
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        # --- Fusion head ---
        fused_dim = self.eff_out_dim + self.cnx_out_dim
        self.fusion_head = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values_eff, pixel_values_cnx):
        """Compute class logits from the two preprocessed views of a batch.

        Args:
            pixel_values_eff: Input batch for the EfficientNet branch.
            pixel_values_cnx: Input batch for the ConvNeXt branch; must have
                the same batch size, since features are concatenated on dim 1.

        Returns:
            Tensor of logits produced by the fusion head.
        """
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        # Collapse everything after the batch dimension into one feature axis.
        x_eff = torch.flatten(x_eff, 1)

        cnx_out = self.cnx_backbone(pixel_values=pixel_values_cnx, return_dict=True)
        x_cnx = cnx_out.pooler_output

        fused = torch.cat([x_eff, x_cnx], dim=1)
        return self.fusion_head(fused)
|
|
|
|
| |
class FusionCarDamagePredictor:
    """Inference wrapper that loads a ``FusionClassifier`` checkpoint.

    The model is placed on CUDA when available (CPU otherwise) and switched
    to evaluation mode. Each image is preprocessed twice: a torchvision
    pipeline (resize to 260x260, tensor conversion, channel normalization)
    feeds the EfficientNet branch, and the HuggingFace ConvNeXt image
    processor feeds the ConvNeXt branch.
    """

    def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
        """
        Args:
            checkpoint_path (str): Path to the .pt/.pth file containing the model state dict
                under the ``"model_state_dict"`` key.
            class_map (dict): Mapping from class index (int) to class name (str).
            convnext_model_name (str): Pretrained ConvNeXt model name from HuggingFace.

        Raises:
            RuntimeError: If the model cannot be constructed or the
                checkpoint cannot be loaded. The original error is chained
                as ``__cause__``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        # Preprocessing for the EfficientNet branch.
        self.eff_normalize = transforms.Compose([
            transforms.Resize((260, 260)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

        # Preprocessing for the ConvNeXt branch.
        self.convnext_processor = ConvNextImageProcessor.from_pretrained(convnext_model_name)

        try:
            self.model = FusionClassifier(
                num_classes=len(class_map),
                convnext_model_name=convnext_model_name,
            )
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from a trusted source, or
            # consider passing ``weights_only=True`` if the checkpoint
            # contents and the installed PyTorch version allow it.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load Fusion model: {e}") from e

    def predict(self, image_input):
        """
        Args:
            image_input (str or PIL.Image): Path to image file or PIL Image.
        Returns:
            dict: Sorted dictionary {class_name: probability} from highest to lowest.
        Raises:
            ValueError: If PIL cannot identify the image file.
            RuntimeError: For any other failure. This includes an
                unsupported ``image_input`` type: the ``TypeError`` raised
                internally is wrapped, and is available as ``__cause__``.
        """
        try:
            if isinstance(image_input, str):
                # The context manager releases the file handle as soon as
                # the converted copy has been made.
                with Image.open(image_input) as opened:
                    image = opened.convert("RGB")
            elif isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")

            # EfficientNet branch input, with a leading batch dimension of one.
            pixel_eff = self.eff_normalize(image).unsqueeze(0).to(self.device)

            # ConvNeXt branch input.
            inputs_cnx = self.convnext_processor(images=image, return_tensors="pt")
            pixel_cnx = inputs_cnx["pixel_values"].to(self.device)

            with torch.no_grad():
                logits = self.model(pixel_eff, pixel_cnx)
                probs = torch.nn.functional.softmax(logits, dim=1)[0]

            # One device-to-host transfer instead of one ``.item()`` per class.
            scores = probs.tolist()
            class_probs = {
                self.class_map[i]: scores[i]
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))

        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {e}") from e
| |