# car_damage_detector/predict_helper.py
# (Provenance from the hosting page: uploaded by junaid17, verified commit fba4818,
#  "Update predict_helper.py".)
import base64
import io
import numpy as np
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
import torch
import torch.nn as nn
import torchvision.models as models
# ---------------- YOLO ----------------
# Load the damage-detection YOLO weights once at import time.
# NOTE(review): the artifact path is relative to the current working directory —
# confirm the app is always launched from the project root.
yolo_model = YOLO("artifacts/damage_detector.pt")
# ---------------- CLASSIFIER ----------------
class Car_Classifier_Resnet(nn.Module):
    """Transfer-learning classifier built on a pretrained ResNet-18.

    The backbone is frozen except for the final residual stage (``layer4``)
    and every ``BatchNorm2d`` layer; the stock fully-connected head is
    replaced by ``Dropout(0.4) + Linear(in_features, num_classes)`` (the new
    head is trainable by default).
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights="DEFAULT")

        # Freeze the entire backbone by default.
        backbone.requires_grad_(False)

        # Fine-tune only the deepest residual stage.
        backbone.layer4.requires_grad_(True)

        # Keep all BatchNorm affine parameters trainable as well.
        for module in backbone.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.requires_grad_(True)

        # Swap the ImageNet head for a lightly regularized task head.
        backbone.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(backbone.fc.in_features, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)
# Class labels in the index order the classifier was trained with
# (F_* = front of car, R_* = rear — TODO confirm with training notebook).
class_names = [
"F_Breakage",
"F_Crushed",
"F_Normal",
"R_Breakage",
"R_Crushed",
"R_Normal"
]
# Prefer GPU when available; map_location below keeps CPU-only hosts working.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the classifier and restore the fine-tuned weights from the artifact.
clf_model = Car_Classifier_Resnet(num_classes=6).to(device)
clf_model.load_state_dict(
torch.load("artifacts/Damage_Classifier_Resnet_18.pth", map_location=device)
)
# Inference mode: disables dropout and freezes BatchNorm running stats.
clf_model.eval()
# Preprocessing pipeline for the classifier: resize to the ResNet input size,
# convert to a float tensor, and normalize with ImageNet statistics (these must
# match what resnet18's pretrained weights expect).
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
def predict_damage(image: Image.Image):
    """Classify the damage type and localize damage regions in a car image.

    Args:
        image: Input PIL image in any mode; converted to RGB internally.

    Returns:
        dict with keys:
            damage_detected (bool): True if YOLO found at least one box.
            damage_type (str): Predicted label from ``class_names``.
            confidence (float): Softmax confidence, rounded to 4 decimals.
            annotated_image (str): Base64-encoded JPEG with boxes drawn.
    """
    image = image.convert("RGB")

    # -------- 1. CLASSIFICATION (ResNet) --------
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        out = clf_model(img_tensor)
        probs = torch.softmax(out, dim=1)
        conf, idx = torch.max(probs, dim=1)
    damage_type = class_names[idx.item()]
    confidence_score = round(conf.item(), 4)

    # -------- 2. YOLO DETECTION --------
    # NOTE(review): conf=0.05 is a very permissive threshold — confirm intentional.
    yolo_results = yolo_model.predict(
        source=image,
        conf=0.05,
        imgsz=640,
        verbose=False
    )
    result = yolo_results[0]

    # Any detected box counts as damage.
    damage_detected = result.boxes is not None and len(result.boxes) > 0

    # plot() returns a numpy array in BGR order (OpenCV convention); reverse the
    # channel axis to get RGB. The reversed view has negative strides, which
    # Image.fromarray rejects on some Pillow/numpy versions — force a contiguous
    # copy before handing it to PIL.
    plotted_image_rgb = np.ascontiguousarray(result.plot()[..., ::-1])
    final_image = Image.fromarray(plotted_image_rgb)

    # Encode the annotated image as Base64 JPEG for the frontend.
    buffered = io.BytesIO()
    final_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return {
        "damage_detected": damage_detected,
        "damage_type": damage_type,
        "confidence": confidence_score,
        "annotated_image": img_str  # Base64 string of the image
    }