pranit144 committed on
Commit
ddc8ff7
·
verified ·
1 Parent(s): 5366d2a

Update vit_classifier.py

Browse files
Files changed (1) hide show
  1. vit_classifier.py +97 -100
vit_classifier.py CHANGED
@@ -1,100 +1,97 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torchvision.models as models
5
- import torchvision.transforms as transforms
6
- from PIL import Image
7
-
8
- # Parameters
9
- IMG_HEIGHT = 224
10
- IMG_WIDTH = 224
11
-
12
- # Define classes (must match training - sorted alphabetically)
13
- CLASSES = sorted([
14
- "Healthy",
15
- "Arcing_Contact_Misalignment",
16
- "Arcing_Contact_Wear",
17
- "Main Contact Misalignment",
18
- "main_contact_wear"
19
- ])
20
-
21
- class ViTClassifier:
22
- _instance = None
23
- _model = None
24
- _device = None
25
- _transform = None
26
-
27
- @classmethod
28
- def get_instance(cls, model_path=None):
29
- if model_path is None:
30
- # Default to the model in the same directory as this file
31
- model_path = os.path.join(os.path.dirname(__file__), "vit_model.pth")
32
- if cls._instance is None:
33
- cls._instance = cls()
34
- cls._instance._load_model(model_path)
35
- return cls._instance
36
-
37
- def _load_model(self, model_path):
38
- # Define transforms here
39
- self._transform = transforms.Compose([
40
- transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
41
- transforms.ToTensor(),
42
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
43
- ])
44
-
45
- # Force CPU for deployment to save memory as per instructions
46
- self._device = torch.device("cpu")
47
- print(f"Using device: {self._device}")
48
-
49
- print(f"Loading model from {model_path}...")
50
-
51
- try:
52
- # Load pretrained ViT model structure
53
- weights = models.ViT_B_16_Weights.DEFAULT
54
- self._model = models.vit_b_16(weights=weights)
55
-
56
- # Replace the classifier head to match training
57
- num_features = self._model.heads.head.in_features
58
- self._model.heads.head = nn.Linear(num_features, len(CLASSES))
59
-
60
- # Load trained weights
61
- if os.path.exists(model_path):
62
- # Load with map_location=cpu
63
- self._model.load_state_dict(torch.load(model_path, map_location=self._device))
64
- self._model.to(self._device)
65
- self._model.eval()
66
- print("Model loaded successfully.")
67
- else:
68
- print(f"Error: Model file not found at {model_path}")
69
- self._model = None
70
- except Exception as e:
71
- print(f"Error loading model: {e}")
72
- self._model = None
73
-
74
- def predict(self, image_path_or_file):
75
- """
76
- Predict using ViT model.
77
- Args:
78
- image_path_or_file: Path to image or file-like object
79
- Returns: (predicted_class, confidence_score)
80
- """
81
- if self._model is None:
82
- return None, 0.0
83
-
84
- try:
85
- image = Image.open(image_path_or_file).convert('RGB')
86
- image_tensor = self._transform(image).unsqueeze(0).to(self._device)
87
-
88
- with torch.no_grad():
89
- outputs = self._model(image_tensor)
90
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
91
- confidence, predicted_idx = torch.max(probabilities, 1)
92
-
93
- predicted_class = CLASSES[predicted_idx.item()]
94
- conf_val = confidence.item()
95
-
96
- return predicted_class, conf_val
97
-
98
- except Exception as e:
99
- print(f"Error processing image: {e}")
100
- return None, 0.0
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as models
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+
8
# Input size expected by ViT-B/16 (224x224 RGB).
IMG_HEIGHT = 224
IMG_WIDTH = 224

# Class labels. sorted() keeps these in alphabetical (ASCII, case-sensitive)
# order, which must match the class index order used at training time.
CLASSES = sorted([
    "Arcing_Contact_Misalignment",
    "Arcing_Contact_Wear",
    "Healthy",
    "Main Contact Misalignment",
    "main_contact_wear",
])
20
+
21
class ViTClassifier:
    """Singleton wrapper around a fine-tuned ViT-B/16 image classifier.

    The model is loaded once (lazily, on the first ``get_instance`` call)
    and shared for the lifetime of the process. Inference runs on CPU.
    If loading fails, the instance stays alive with ``_model = None`` and
    ``predict`` degrades to returning empty results instead of raising.
    """

    _instance = None    # shared singleton instance
    _model = None       # loaded torch module, or None when unavailable
    _device = None      # torch.device; always CPU in this deployment
    _transform = None   # torchvision preprocessing pipeline

    @classmethod
    def get_instance(cls, model_path=None):
        """Return the shared classifier, loading the model on first use.

        Args:
            model_path: Path to the ``.pth`` state dict. Defaults to
                ``vit_model.pth`` next to this module.

        Returns:
            The singleton ``ViTClassifier`` instance.

        NOTE: ``model_path`` is only honored on the *first* call; later
        calls with a different path silently reuse the already-loaded model.
        """
        if model_path is None:
            # Default to the model in the same directory as this file.
            model_path = os.path.join(os.path.dirname(__file__), "vit_model.pth")
        if cls._instance is None:
            cls._instance = cls()
            cls._instance._load_model(model_path)
        return cls._instance

    def _load_model(self, model_path):
        """Build the ViT-B/16 architecture and load fine-tuned weights.

        On any failure (missing file, bad state dict) ``self._model`` is
        set to None so callers can detect the degraded state.
        """
        # ImageNet mean/std normalization — the statistics ViT_B_16's
        # pretrained weights expect.
        self._transform = transforms.Compose([
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Force CPU for deployment to save memory.
        self._device = torch.device("cpu")
        print(f"Using device: {self._device}")
        print(f"Loading model from {model_path}...")

        try:
            # Pretrained backbone, then swap the head to match training.
            weights = models.ViT_B_16_Weights.DEFAULT
            self._model = models.vit_b_16(weights=weights)

            num_features = self._model.heads.head.in_features
            self._model.heads.head = nn.Linear(num_features, len(CLASSES))

            if os.path.exists(model_path):
                # weights_only=True: a state dict is plain tensors, and this
                # avoids pickle's arbitrary-code-execution risk when loading
                # a file of unknown provenance. (Requires torch >= 1.13.)
                state = torch.load(
                    model_path, map_location=self._device, weights_only=True
                )
                self._model.load_state_dict(state)
                self._model.to(self._device)
                self._model.eval()
                print("Model loaded successfully.")
            else:
                print(f"Error: Model file not found at {model_path}")
                self._model = None

        except Exception as e:
            # Deliberate best-effort: a load failure disables the classifier
            # rather than crashing the caller; predict() then returns empties.
            print(f"Error loading model: {e}")
            self._model = None

    def predict(self, image_path_or_file):
        """Classify a single image.

        Args:
            image_path_or_file: Filesystem path or file-like object that
                PIL can open.

        Returns:
            predicted_class (str | None): Top-1 label, or None on failure.
            confidence (float): Top-1 softmax probability (0.0 on failure).
            probabilities (dict): Mapping of every class to its probability
                (empty dict on failure).
        """
        if self._model is None:
            return None, 0.0, {}

        try:
            image = Image.open(image_path_or_file).convert('RGB')
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            with torch.no_grad():
                outputs = self._model(image_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()[0]

            # int() so the index (and downstream values) are plain Python
            # types rather than numpy scalars.
            predicted_idx = int(probs.argmax())
            predicted_class = CLASSES[predicted_idx]
            confidence = float(probs[predicted_idx])

            # Full per-class probability breakdown for the caller.
            probability_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}

            return predicted_class, confidence, probability_dict

        except Exception as e:
            # Best-effort: bad/unreadable images produce an empty result
            # instead of propagating the exception.
            print(f"Error processing image: {e}")
            return None, 0.0, {}