File size: 3,173 Bytes
fba4818
 
 
25bfc41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fba4818
 
 
 
 
 
 
 
 
 
 
25bfc41
fba4818
25bfc41
 
 
 
 
 
fba4818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25bfc41
 
fba4818
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import base64
import io
import numpy as np
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
import torch
import torch.nn as nn
import torchvision.models as models

# ---------------- YOLO ----------------
# Damage-region detector, loaded once at import time.
# NOTE(review): importing this module reads the weights from disk and
# will raise if "artifacts/damage_detector.pt" is missing.
yolo_model = YOLO("artifacts/damage_detector.pt")

# ---------------- CLASSIFIER ----------------
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 classifier with a mostly-frozen backbone.

    Every backbone parameter is frozen except the final residual stage
    (``layer4``) and all BatchNorm2d layers; the original fully-connected
    head is replaced by Dropout(0.4) + Linear sized for ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights="DEFAULT")

        # Freeze the whole network by default...
        for p in backbone.parameters():
            p.requires_grad = False

        # ...then re-enable gradients for the last residual stage...
        for p in backbone.layer4.parameters():
            p.requires_grad = True

        # ...and for every BatchNorm layer throughout the backbone.
        for m in backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                for p in m.parameters():
                    p.requires_grad = True

        # Swap in a fresh classification head for our label set.
        backbone.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(backbone.fc.in_features, num_classes)
        )
        # Kept as ``self.model`` so saved state-dict keys stay compatible.
        self.model = backbone

    def forward(self, x):
        """Run the wrapped ResNet on a batch of images."""
        return self.model(x)

# Label names in the index order the classifier head was trained with —
# ``class_names[argmax]`` must match the training label encoding.
# F_/R_ prefixes presumably mean front/rear — TODO confirm against the
# training dataset.
class_names = [
    "F_Breakage",
    "F_Crushed",
    "F_Normal",
    "R_Breakage",
    "R_Crushed",
    "R_Normal"
]

# Pick GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate the classifier and restore the fine-tuned weights;
# map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
clf_model = Car_Classifier_Resnet(num_classes=6).to(device)
clf_model.load_state_dict(
    torch.load("artifacts/Damage_Classifier_Resnet_18.pth", map_location=device)
)
# Inference mode: disables dropout and uses running BatchNorm statistics.
clf_model.eval()

# Preprocessing for the classifier: resize to the 224x224 input ResNet
# expects, convert to a [0, 1] tensor, then apply the standard ImageNet
# channel statistics (required to match the pretrained backbone).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

def predict_damage(image: Image.Image) -> dict:
    """Classify the damage type and return an annotated detection image.

    Runs the ResNet classifier first to get the damage category, then the
    YOLO detector to localise damage regions and render bounding boxes.

    Args:
        image: Input photo; converted to RGB internally.

    Returns:
        dict with keys:
            damage_detected (bool): True if YOLO found at least one box.
            damage_type (str): Predicted label from ``class_names``.
            confidence (float): Softmax confidence rounded to 4 decimals.
            annotated_image (str): Base64-encoded JPEG with boxes drawn.
    """
    image = image.convert("RGB")

    # -------- 1. CLASSIFICATION (ResNet) --------
    # Run classification first as requested
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        out = clf_model(img_tensor)
        probs = torch.softmax(out, dim=1)
        conf, idx = torch.max(probs, dim=1)

    damage_type = class_names[idx.item()]
    confidence_score = round(conf.item(), 4)

    # -------- 2. YOLO DETECTION --------
    # conf=0.05 is deliberately permissive so faint damage is not missed.
    yolo_results = yolo_model.predict(
        source=image,
        conf=0.05,
        imgsz=640,
        verbose=False
    )
    result = yolo_results[0]

    # Any detected box at all counts as "damage detected".
    damage_detected = result.boxes is not None and len(result.boxes) > 0

    # plot() returns a numpy array in BGR order (OpenCV convention).
    # Flipping the channel axis with [..., ::-1] yields a negative-stride
    # view, which several Pillow versions reject in Image.fromarray
    # ("ndarray is not C-contiguous") — force a contiguous copy first.
    plotted_rgb = np.ascontiguousarray(result.plot()[..., ::-1])
    final_image = Image.fromarray(plotted_rgb)

    # Encode the annotated image as Base64 JPEG for the frontend.
    buffered = io.BytesIO()
    final_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return {
        "damage_detected": damage_detected,
        "damage_type": damage_type,
        "confidence": confidence_score,
        "annotated_image": img_str  # Base64 string of the image
    }