Spaces:
Running
Running
| import base64 | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| from ultralytics import YOLO | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
# ---------------- YOLO ----------------
# Load the trained damage-detection weights once, at import time, so the
# detector is shared by every request.
yolo_model = YOLO("artifacts/damage_detector.pt")
| # ---------------- CLASSIFIER ---------------- | |
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 transfer-learning classifier.

    The ImageNet-pretrained backbone is frozen except for `layer4` and
    every BatchNorm2d layer, and the final fully-connected layer is
    replaced by a Dropout + Linear head sized for `num_classes` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights="DEFAULT")

        # Freeze the whole backbone first ...
        for p in backbone.parameters():
            p.requires_grad = False
        # ... then selectively re-enable training for layer4 ...
        for p in backbone.layer4.parameters():
            p.requires_grad = True
        # ... and for every BatchNorm layer anywhere in the network.
        for m in backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                for p in m.parameters():
                    p.requires_grad = True

        # New classification head: dropout regularization + linear logits.
        backbone.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(backbone.fc.in_features, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        # Delegate straight to the (modified) ResNet.
        return self.model(x)
# Ordered class labels: index i corresponds to classifier output logit i.
# F_* = front of car, R_* = rear of car.
class_names = [
    "F_Breakage", "F_Crushed", "F_Normal",
    "R_Breakage", "R_Crushed", "R_Normal",
]

# Prefer GPU when available, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the classifier, restore the fine-tuned weights, and switch
# to inference mode (disables dropout / freezes BatchNorm statistics).
clf_model = Car_Classifier_Resnet(num_classes=6).to(device)
state_dict = torch.load(
    "artifacts/Damage_Classifier_Resnet_18.pth", map_location=device
)
clf_model.load_state_dict(state_dict)
clf_model.eval()

# Preprocessing pipeline: 224x224 resize + standard ImageNet
# normalization constants, matching the pretrained ResNet backbone.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def predict_damage(image: Image.Image) -> dict:
    """Classify the damage type and localize damage regions in a car photo.

    Runs two independent models over the same image:
      1. The ResNet classifier (``clf_model``) predicts one of the six
         labels in ``class_names`` together with a softmax confidence.
      2. The YOLO detector (``yolo_model``) finds damage bounding boxes;
         an annotated copy of the image is rendered from its output.

    Parameters
    ----------
    image : PIL.Image.Image
        Input photograph; converted to RGB internally.

    Returns
    -------
    dict
        ``damage_detected`` (bool)  - True if YOLO found at least one box.
        ``damage_type`` (str)       - predicted class label.
        ``confidence`` (float)      - softmax confidence, 4 decimals.
        ``annotated_image`` (str)   - base64-encoded JPEG with boxes drawn.
    """
    image = image.convert("RGB")

    # -------- 1. CLASSIFICATION (ResNet) --------
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        out = clf_model(img_tensor)
    probs = torch.softmax(out, dim=1)
    conf, idx = torch.max(probs, dim=1)
    damage_type = class_names[idx.item()]
    confidence_score = round(conf.item(), 4)

    # -------- 2. YOLO DETECTION --------
    # NOTE(review): conf=0.05 is a very permissive threshold — confirm it
    # is intentional and not a debugging leftover.
    yolo_results = yolo_model.predict(
        source=image,
        conf=0.05,
        imgsz=640,
        verbose=False,
    )
    result = yolo_results[0]

    # Damage is "detected" iff at least one bounding box came back.
    damage_detected = result.boxes is not None and len(result.boxes) > 0

    # plot() returns a numpy array in BGR (OpenCV convention); reverse the
    # channel axis to get RGB. The reversed slice is a negative-stride
    # view, which Image.fromarray rejects on some Pillow versions — make
    # it contiguous first (bug fix).
    plotted_bgr = result.plot()
    plotted_rgb = np.ascontiguousarray(plotted_bgr[..., ::-1])
    final_image = Image.fromarray(plotted_rgb)

    # Encode the annotated image as base64 JPEG for the frontend.
    buffered = io.BytesIO()
    final_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return {
        "damage_detected": damage_detected,
        "damage_type": damage_type,
        "confidence": confidence_score,
        "annotated_image": img_str,  # Base64 string of the image
    }