Spaces:
Running
Running
| import os | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| from PIL import Image, UnidentifiedImageError | |
| from transformers import ConvNextModel, ConvNextImageProcessor | |
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 classifier with a custom two-layer head.

    The backbone is built with ``models.resnet18(weights="DEFAULT")``. All of
    its parameters are frozen except those in ``layer3`` and ``layer4``. The
    original ``fc`` layer is replaced by
    ``Dropout(0.5) -> Linear(in, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_classes)``.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights="DEFAULT")

        # Freeze the whole network first, then re-enable gradients for the
        # two deepest residual stages only.
        backbone.requires_grad_(False)
        for stage in (backbone.layer3, backbone.layer4):
            stage.requires_grad_(True)

        # Read the width of the stock head before it is swapped out.
        head_in = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(head_in, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        logits = self.model(x)
        return logits
class ResnetCarDamagePredictor:
    """Single-image inference wrapper around ``Car_Classifier_Resnet``.

    On construction the model is built with ``len(class_map)`` outputs, the
    checkpoint's weights are loaded, and the model is moved to CUDA when
    available (CPU otherwise) and put in eval mode.

    Args:
        checkpoint_path: Path passed to ``torch.load``. The loaded object may
            be a state dict itself or a dict holding one under
            ``"model_state_dict"``.
        class_map: Mapping (or sequence) indexed by ``0 .. len(class_map) - 1``
            that yields the label for each output index.

    Raises:
        RuntimeError: If building the model or loading the checkpoint fails.
            The original exception is attached as ``__cause__``.
    """

    def __init__(self, checkpoint_path, class_map):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        # Preprocessing applied to every image before inference.
        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ResNet model: {str(e)}") from e

    def resnet_predict(self, image_input):
        """Classify one image and return per-class probabilities.

        Args:
            image_input: A file path (``str`` or ``os.PathLike``) or an
                already-opened ``PIL.Image.Image``.

        Returns:
            dict: ``{class_map[i]: probability}`` for every index in
            ``class_map``, ordered from most to least probable.

        Raises:
            ValueError: If the file cannot be identified as an image.
            RuntimeError: For any other failure, including an unsupported
                ``image_input`` type or a missing file. The original
                exception is attached as ``__cause__``.
        """
        try:
            if isinstance(image_input, (str, os.PathLike)):
                # The context manager closes the file handle as soon as the
                # RGB copy has been made.
                with Image.open(image_input) as source:
                    image = source.convert("RGB")
            elif isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")

            batch = self.test_transforms(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(batch)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]

            # One host transfer for the whole vector instead of one per class.
            scores = probs.tolist()
            # class_map may be a dict keyed by index, so index it explicitly.
            class_probs = {
                self.class_map[i]: scores[i]
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))
        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {str(e)}") from e

    def predict(self, image_input):
        """Alias for :meth:`resnet_predict`, matching ``FusionCarDamagePredictor.predict``."""
        return self.resnet_predict(image_input)
class FusionClassifier(nn.Module):
    """Two-backbone classifier that fuses EfficientNetV2-S and ConvNeXt features.

    Each backbone receives its own preprocessed copy of the image. The pooled
    EfficientNet features and the ConvNeXt ``pooler_output`` are concatenated
    and passed through an MLP head
    (``Linear -> LayerNorm -> GELU`` twice, then a final ``Linear``).

    Trainable parts: EfficientNet ``features[5]``, ``features[6]`` and
    ``features[7]``; ConvNeXt ``encoder.stages[2]``, ``encoder.stages[3]`` and
    the final ``layernorm``; the fusion head. Everything else is frozen.

    Args:
        num_classes: Number of output logits.
        convnext_model_name: Identifier passed to
            ``ConvNextModel.from_pretrained``.
    """

    def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
        super().__init__()

        # --- EfficientNetV2-S branch ---
        eff = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        for param in eff.parameters():
            param.requires_grad = False
        for block_idx in (5, 6, 7):
            for param in eff.features[block_idx].parameters():
                param.requires_grad = True
        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        # Input width of the stock classifier equals the pooled feature width.
        self.eff_out_dim = eff.classifier[1].in_features

        # --- ConvNeXt branch ---
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        for param in cnx.parameters():
            param.requires_grad = False
        for stage_idx in (2, 3):
            for param in cnx.encoder.stages[stage_idx].parameters():
                param.requires_grad = True
        for param in cnx.layernorm.parameters():
            param.requires_grad = True
        self.cnx_backbone = cnx
        # Take the pooled width from the loaded config rather than hard-coding
        # 768, so checkpoints with a different final stage width still
        # produce a correctly sized fusion head.
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        # --- Fusion head ---
        fused_dim = self.eff_out_dim + self.cnx_out_dim
        self.fusion_head = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values_eff, pixel_values_cnx):
        """Return class logits for a batch.

        Args:
            pixel_values_eff: Image batch preprocessed for the EfficientNet
                branch.
            pixel_values_cnx: Image batch preprocessed for the ConvNeXt
                branch.

        Returns:
            Output of the fusion head applied to the concatenated features.
        """
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)

        cnx_out = self.cnx_backbone(pixel_values=pixel_values_cnx, return_dict=True)
        x_cnx = cnx_out.pooler_output

        fused = torch.cat([x_eff, x_cnx], dim=1)
        return self.fusion_head(fused)
class FusionCarDamagePredictor:
    """Single-image inference wrapper around ``FusionClassifier``.

    On construction the ConvNeXt image processor is loaded, the model is built
    with ``len(class_map)`` outputs, the checkpoint's weights are loaded, and
    the model is moved to CUDA when available (CPU otherwise) and put in eval
    mode.

    Args:
        checkpoint_path: Path passed to ``torch.load``. The loaded object may
            be a state dict itself or a dict holding one under
            ``"model_state_dict"``.
        class_map: Mapping (or sequence) indexed by ``0 .. len(class_map) - 1``
            that yields the label for each output index.
        convnext_model_name: Identifier used for both the ConvNeXt backbone
            and its image processor.

    Raises:
        RuntimeError: If building the model or loading the checkpoint fails.
            The original exception is attached as ``__cause__``.
    """

    def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        # Preprocessing for the EfficientNet branch; the ConvNeXt branch uses
        # its own processor below.
        self.eff_normalize = transforms.Compose([
            transforms.Resize((260, 260)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        self.convnext_processor = ConvNextImageProcessor.from_pretrained(convnext_model_name)
        try:
            self.model = FusionClassifier(
                num_classes=len(class_map),
                convnext_model_name=convnext_model_name,
            )
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            first_tensor = next(iter(state_dict.values()))
            # Run in half precision only on CUDA. On CPU many float16 kernels
            # are missing or slow, so the model stays float32 there and
            # load_state_dict casts the float16 weights while copying them in.
            if first_tensor.dtype == torch.float16 and self.device.type == "cuda":
                self.model = self.model.half()
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Fusion model: {str(e)}") from e

    def predict(self, image_input):
        """Classify one image and return per-class probabilities.

        Args:
            image_input: A file path (``str`` or ``os.PathLike``) or an
                already-opened ``PIL.Image.Image``.

        Returns:
            dict: ``{class_map[i]: probability}`` for every index in
            ``class_map``, ordered from most to least probable.

        Raises:
            ValueError: If the file cannot be identified as an image.
            RuntimeError: For any other failure, including an unsupported
                ``image_input`` type or a missing file. The original
                exception is attached as ``__cause__``.
        """
        try:
            if isinstance(image_input, (str, os.PathLike)):
                # The context manager closes the file handle as soon as the
                # RGB copy has been made.
                with Image.open(image_input) as source:
                    image = source.convert("RGB")
            elif isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")

            pixel_eff = self.eff_normalize(image).unsqueeze(0).to(self.device)
            inputs_cnx = self.convnext_processor(images=image, return_tensors="pt")
            pixel_cnx = inputs_cnx["pixel_values"].to(self.device)

            # Match the input dtype to the model when it runs in half precision.
            if next(self.model.parameters()).dtype == torch.float16:
                pixel_eff = pixel_eff.half()
                pixel_cnx = pixel_cnx.half()

            with torch.no_grad():
                logits = self.model(pixel_eff, pixel_cnx)
                # Softmax in float32 so half-precision logits do not lose
                # precision in the probabilities.
                probs = torch.nn.functional.softmax(logits.float(), dim=1)[0]

            # One host transfer for the whole vector instead of one per class.
            scores = probs.tolist()
            # class_map may be a dict keyed by index, so index it explicitly.
            class_probs = {
                self.class_map[i]: scores[i]
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))
        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {str(e)}") from e