junaid17 committed on
Commit
fba4818
·
verified ·
1 Parent(s): 85b73dd

Update predict_helper.py

Browse files
Files changed (1) hide show
  1. predict_helper.py +40 -33
predict_helper.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  from PIL import Image
2
  from torchvision import transforms
3
  from ultralytics import YOLO
@@ -12,30 +15,23 @@ yolo_model = YOLO("artifacts/damage_detector.pt")
12
  class Car_Classifier_Resnet(nn.Module):
13
  def __init__(self, num_classes):
14
  super().__init__()
15
-
16
  self.model = models.resnet18(weights="DEFAULT")
17
-
18
  for param in self.model.parameters():
19
  param.requires_grad = False
20
-
21
  for param in self.model.layer4.parameters():
22
  param.requires_grad = True
23
-
24
  for module in self.model.modules():
25
  if isinstance(module, nn.BatchNorm2d):
26
  for param in module.parameters():
27
  param.requires_grad = True
28
-
29
  self.model.fc = nn.Sequential(
30
  nn.Dropout(0.4),
31
  nn.Linear(self.model.fc.in_features, num_classes)
32
  )
33
 
34
-
35
  def forward(self, x):
36
  return self.model(x)
37
 
38
-
39
  class_names = [
40
  "F_Breakage",
41
  "F_Crushed",
@@ -46,7 +42,6 @@ class_names = [
46
  ]
47
 
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
-
50
  clf_model = Car_Classifier_Resnet(num_classes=6).to(device)
51
  clf_model.load_state_dict(
52
  torch.load("artifacts/Damage_Classifier_Resnet_18.pth", map_location=device)
@@ -62,39 +57,51 @@ transform = transforms.Compose([
62
  )
63
  ])
64
 
65
- # here
66
-
67
  def predict_damage(image: Image.Image):
68
  image = image.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- # -------- YOLO --------
71
  yolo_results = yolo_model.predict(
72
  source=image,
73
  conf=0.05,
74
  imgsz=640,
75
  verbose=False
76
  )
77
-
78
- bboxes = []
79
- if yolo_results[0].boxes is not None:
80
- for box in yolo_results[0].boxes:
81
- x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
82
- conf = float(box.conf[0])
83
- bboxes.append({
84
- "bbox": [x1, y1, x2, y2],
85
- "confidence": round(conf, 4)
86
- })
87
-
88
- # -------- CLASSIFICATION --------
89
- img_tensor = transform(image).unsqueeze(0).to(device)
90
- with torch.no_grad():
91
- out = clf_model(img_tensor)
92
- probs = torch.softmax(out, dim=1)
93
- conf, idx = torch.max(probs, dim=1)
 
 
 
94
 
95
  return {
96
- "damage_detected": len(bboxes) > 0,
97
- "damage_type": class_names[idx.item()],
98
- "confidence": round(conf.item(), 4),
99
- "bboxes": bboxes
100
- }
 
1
+ import base64
2
+ import io
3
+ import numpy as np
4
  from PIL import Image
5
  from torchvision import transforms
6
  from ultralytics import YOLO
 
15
  class Car_Classifier_Resnet(nn.Module):
16
  def __init__(self, num_classes):
17
  super().__init__()
 
18
  self.model = models.resnet18(weights="DEFAULT")
 
19
  for param in self.model.parameters():
20
  param.requires_grad = False
 
21
  for param in self.model.layer4.parameters():
22
  param.requires_grad = True
 
23
  for module in self.model.modules():
24
  if isinstance(module, nn.BatchNorm2d):
25
  for param in module.parameters():
26
  param.requires_grad = True
 
27
  self.model.fc = nn.Sequential(
28
  nn.Dropout(0.4),
29
  nn.Linear(self.model.fc.in_features, num_classes)
30
  )
31
 
 
32
  def forward(self, x):
33
  return self.model(x)
34
 
 
35
  class_names = [
36
  "F_Breakage",
37
  "F_Crushed",
 
42
  ]
43
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
45
  clf_model = Car_Classifier_Resnet(num_classes=6).to(device)
46
  clf_model.load_state_dict(
47
  torch.load("artifacts/Damage_Classifier_Resnet_18.pth", map_location=device)
 
57
  )
58
  ])
59
 
 
 
60
  def predict_damage(image: Image.Image):
61
  image = image.convert("RGB")
62
+
63
+ # -------- 1. CLASSIFICATION (ResNet) --------
64
+ # Run classification first as requested
65
+ img_tensor = transform(image).unsqueeze(0).to(device)
66
+ with torch.no_grad():
67
+ out = clf_model(img_tensor)
68
+ probs = torch.softmax(out, dim=1)
69
+ conf, idx = torch.max(probs, dim=1)
70
+
71
+ damage_type = class_names[idx.item()]
72
+ confidence_score = round(conf.item(), 4)
73
 
74
+ # -------- 2. YOLO DETECTION --------
75
  yolo_results = yolo_model.predict(
76
  source=image,
77
  conf=0.05,
78
  imgsz=640,
79
  verbose=False
80
  )
81
+
82
+ result = yolo_results[0]
83
+
84
+ # Check if any boxes were detected
85
+ damage_detected = result.boxes is not None and len(result.boxes) > 0
86
+
87
+ # Generate the image with bounding boxes drawn
88
+ # plot() returns a numpy array in BGR format (OpenCV style)
89
+ plotted_image_bgr = result.plot()
90
+
91
+ # Convert BGR to RGB
92
+ plotted_image_rgb = plotted_image_bgr[..., ::-1]
93
+
94
+ # Convert numpy array back to PIL Image
95
+ final_image = Image.fromarray(plotted_image_rgb)
96
+
97
+ # Encode image to Base64 to send to frontend
98
+ buffered = io.BytesIO()
99
+ final_image.save(buffered, format="JPEG")
100
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
101
 
102
  return {
103
+ "damage_detected": damage_detected,
104
+ "damage_type": damage_type,
105
+ "confidence": confidence_score,
106
+ "annotated_image": img_str # Base64 string of the image
107
+ }