junaid17 commited on
Commit
25bfc41
·
verified ·
1 Parent(s): 4e9961d

Update predict_helper.py

Browse files
Files changed (1) hide show
  1. predict_helper.py +100 -94
predict_helper.py CHANGED
@@ -1,94 +1,100 @@
1
- from torchvision import transforms
2
- from PIL import Image
3
- import torch
4
- import torch.nn as nn
5
- import torchvision.models as models
6
-
7
- # Image transformation for inference
8
- transform = transforms.Compose([
9
- transforms.Resize((224, 224)), # Resize to model input size
10
- transforms.ToTensor(), # Convert PIL image to tensor
11
- transforms.Normalize(
12
- mean=[0.485, 0.456, 0.406], # ImageNet mean
13
- std=[0.229, 0.224, 0.225] # ImageNet std
14
- )
15
- ])
16
-
17
- def preprocess_image(image_path):
18
- image = Image.open(image_path).convert("RGB")
19
- image = transform(image)
20
- image = image.unsqueeze(0) # Add batch dimension
21
- return image
22
-
23
- # Initializing the model
24
-
25
- class Car_Classifier_Resnet(nn.Module):
26
- def __init__(self, num_classes):
27
- super().__init__()
28
-
29
- self.model = models.resnet18(weights="DEFAULT")
30
-
31
- for param in self.model.parameters():
32
- param.requires_grad = False
33
-
34
- for param in self.model.layer4.parameters():
35
- param.requires_grad = True
36
-
37
- for module in self.model.modules():
38
- if isinstance(module, nn.BatchNorm2d):
39
- for param in module.parameters():
40
- param.requires_grad = True
41
-
42
- self.model.fc = nn.Sequential(
43
- nn.Dropout(0.4),
44
- nn.Linear(self.model.fc.in_features, num_classes)
45
- )
46
-
47
-
48
- def forward(self, x):
49
- return self.model(x)
50
-
51
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
-
53
- num_classes = 6
54
- model = Car_Classifier_Resnet(num_classes).to(device)
55
-
56
- model.load_state_dict(torch.load("Damage_Classifier_Resnet_18.pth", map_location=device))
57
-
58
- # Prediction Function
59
-
60
- class_names = [
61
- "F_Breakage",
62
- "F_Crushed",
63
- "F_Normal",
64
- "R_Breakage",
65
- "R_Crushed",
66
- "R_Normal"
67
- ]
68
-
69
- # Prediction Function
70
- def predict_image(image: Image.Image):
71
- model.eval()
72
-
73
- image = image.convert("RGB")
74
- image = transform(image).unsqueeze(0).to(device)
75
-
76
- with torch.no_grad():
77
- outputs = model(image)
78
- probs = torch.softmax(outputs, dim=1)
79
- conf, pred = torch.max(probs, dim=1)
80
-
81
- pred_idx = pred.item()
82
- confidence = conf.item()
83
-
84
- if class_names:
85
- return {
86
- "class_index": pred_idx,
87
- "class_name": class_names[pred_idx],
88
- "confidence": round(confidence, 4)
89
- }
90
- else:
91
- return {
92
- "class_index": pred_idx,
93
- "confidence": round(confidence, 4)
94
- }
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+ from ultralytics import YOLO
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.models as models
7
+
8
+ # ---------------- YOLO ----------------
9
+ yolo_model = YOLO("artifacts/damage_detector.pt")
10
+
11
+ # ---------------- CLASSIFIER ----------------
12
+ class Car_Classifier_Resnet(nn.Module):
13
+ def __init__(self, num_classes):
14
+ super().__init__()
15
+
16
+ self.model = models.resnet18(weights="DEFAULT")
17
+
18
+ for param in self.model.parameters():
19
+ param.requires_grad = False
20
+
21
+ for param in self.model.layer4.parameters():
22
+ param.requires_grad = True
23
+
24
+ for module in self.model.modules():
25
+ if isinstance(module, nn.BatchNorm2d):
26
+ for param in module.parameters():
27
+ param.requires_grad = True
28
+
29
+ self.model.fc = nn.Sequential(
30
+ nn.Dropout(0.4),
31
+ nn.Linear(self.model.fc.in_features, num_classes)
32
+ )
33
+
34
+
35
+ def forward(self, x):
36
+ return self.model(x)
37
+
38
+
39
+ class_names = [
40
+ "F_Breakage",
41
+ "F_Crushed",
42
+ "F_Normal",
43
+ "R_Breakage",
44
+ "R_Crushed",
45
+ "R_Normal"
46
+ ]
47
+
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+
50
+ clf_model = Car_Classifier_Resnet(num_classes=6).to(device)
51
+ clf_model.load_state_dict(
52
+ torch.load("artifacts/Damage_Classifier_Resnet_18.pth", map_location=device)
53
+ )
54
+ clf_model.eval()
55
+
56
+ transform = transforms.Compose([
57
+ transforms.Resize((224, 224)),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(
60
+ mean=[0.485, 0.456, 0.406],
61
+ std=[0.229, 0.224, 0.225]
62
+ )
63
+ ])
64
+
65
+ # here
66
+
67
+ def predict_damage(image: Image.Image):
68
+ image = image.convert("RGB")
69
+
70
+ # -------- YOLO --------
71
+ yolo_results = yolo_model.predict(
72
+ source=image,
73
+ conf=0.05,
74
+ imgsz=640,
75
+ verbose=False
76
+ )
77
+
78
+ bboxes = []
79
+ if yolo_results[0].boxes is not None:
80
+ for box in yolo_results[0].boxes:
81
+ x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
82
+ conf = float(box.conf[0])
83
+ bboxes.append({
84
+ "bbox": [x1, y1, x2, y2],
85
+ "confidence": round(conf, 4)
86
+ })
87
+
88
+ # -------- CLASSIFICATION --------
89
+ img_tensor = transform(image).unsqueeze(0).to(device)
90
+ with torch.no_grad():
91
+ out = clf_model(img_tensor)
92
+ probs = torch.softmax(out, dim=1)
93
+ conf, idx = torch.max(probs, dim=1)
94
+
95
+ return {
96
+ "damage_detected": len(bboxes) > 0,
97
+ "damage_type": class_names[idx.item()],
98
+ "confidence": round(conf.item(), 4),
99
+ "bboxes": bboxes
100
+ }