File size: 3,129 Bytes
ddc8ff7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Input resolution (height, width in pixels) that every image is resized to
# before being fed to the ViT-B/16 backbone.
IMG_HEIGHT, IMG_WIDTH = 224, 224

# Class labels, indexed in the same order as the model's output logits.
# The order must match training; it comes from sorting the raw label names
# with Python's code-point ordering (uppercase sorts before lowercase, so
# "main_contact_wear" ends up last). The names are kept verbatim, including
# their inconsistent spacing and casing — presumably they are the training
# dataset's folder names; confirm before renaming any of them.
CLASSES = sorted(
    (
        "Arcing_Contact_Misalignment",
        "Arcing_Contact_Wear",
        "Healthy",
        "Main Contact Misalignment",
        "main_contact_wear",
    )
)

class ViTClassifier:
    """Process-wide singleton wrapping a fine-tuned torchvision ViT-B/16 classifier.

    Obtain the shared instance with :meth:`get_instance`, then call
    :meth:`predict` with an image path or file object. Inference runs on CPU.
    """

    # Shared singleton instance, created lazily by get_instance().
    _instance = None
    # Loaded torch module, or None when loading failed / has not happened yet.
    _model = None
    # torch.device used for inference (always CPU in _load_model).
    _device = None
    # Preprocessing pipeline applied to every input image.
    _transform = None

    @classmethod
    def get_instance(cls, model_path=None):
        """Return the shared classifier, loading the model on first use.

        Args:
            model_path: Path to a saved ``state_dict``. Defaults to
                ``vit_model.pth`` in the same directory as this module. Only
                honoured on the first call; later calls return the existing
                instance and ignore this argument.

        Returns:
            The singleton ``ViTClassifier``. If loading failed, the instance is
            still returned (and stays cached) without a model, so
            :meth:`predict` returns ``(None, 0.0, {})``.
        """
        if model_path is None:
            model_path = os.path.join(os.path.dirname(__file__), "vit_model.pth")
        if cls._instance is None:
            cls._instance = cls()
            cls._instance._load_model(model_path)
        return cls._instance

    def _load_model(self, model_path):
        """Build the ViT-B/16 architecture and load fine-tuned weights into it.

        On failure (missing file, incompatible checkpoint, ...) an error is
        printed and ``self._model`` is left as ``None`` instead of raising.
        """
        self._transform = transforms.Compose([
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.ToTensor(),
            # ImageNet channel statistics; presumably the same normalisation
            # used during fine-tuning — confirm against the training script.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self._device = torch.device("cpu")
        print(f"Using device: {self._device}")
        print(f"Loading model from {model_path}...")

        # Check for the checkpoint before paying for building the network.
        if not os.path.exists(model_path):
            print(f"Error: Model file not found at {model_path}")
            self._model = None
            return

        try:
            # weights=None: the strict load_state_dict below overwrites every
            # parameter, so fetching the ImageNet weights first was wasted
            # work and made loading fail whenever the machine was offline.
            model = models.vit_b_16(weights=None)

            # Replace the classification head so it emits one logit per label.
            num_features = model.heads.head.in_features
            model.heads.head = nn.Linear(num_features, len(CLASSES))

            # NOTE(review): torch.load unpickles the file — only load trusted
            # checkpoints (newer torch releases default to weights_only=True).
            state_dict = torch.load(model_path, map_location=self._device)
            model.load_state_dict(state_dict)
            model.to(self._device)
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            self._model = None
            return

        self._model = model
        print("Model loaded successfully.")

    def predict(self, image_path_or_file):
        """Classify a single image.

        Args:
            image_path_or_file: Anything ``PIL.Image.open`` accepts, e.g. a
                filesystem path or a binary file object.

        Returns:
            ``(predicted_class, confidence, probabilities)`` where
            ``predicted_class`` is the most likely label (str), ``confidence``
            is its softmax probability (float) and ``probabilities`` maps each
            label in ``CLASSES`` to its probability (dict). Returns
            ``(None, 0.0, {})`` if no model is loaded or the image could not
            be processed.
        """
        if self._model is None:
            return None, 0.0, {}

        try:
            # The context manager releases any file handle PIL opened itself;
            # convert() loads the pixels, so the copy outlives the block.
            with Image.open(image_path_or_file) as img:
                image = img.convert("RGB")
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            with torch.no_grad():
                outputs = self._model(image_tensor)
                # Softmax over the class dimension; [0] drops the batch of one.
                probs = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()[0]

            # Highest-confidence prediction.
            predicted_idx = int(probs.argmax())
            predicted_class = CLASSES[predicted_idx]
            confidence = float(probs[predicted_idx])

            # Probability for every class, keyed by label.
            probability_dict = {name: float(p) for name, p in zip(CLASSES, probs)}

            return predicted_class, confidence, probability_dict

        except Exception as e:
            print(f"Error processing image: {e}")
            return None, 0.0, {}