import os

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Input resolution fed to the network (see the Resize transform below).
IMG_HEIGHT = 224
IMG_WIDTH = 224

# Define classes (must match training - sorted alphabetically).
# NOTE: sorted() is case-sensitive, so "main_contact_wear" sorts after every
# capitalised name; the index order here is what the classifier head emits.
CLASSES = sorted([
    "Healthy",
    "Arcing_Contact_Misalignment",
    "Arcing_Contact_Wear",
    "Main Contact Misalignment",
    "main_contact_wear"
])


class ViTClassifier:
    """Process-wide singleton wrapping a fine-tuned ViT-B/16 image classifier.

    Obtain the shared instance through :meth:`get_instance`; the checkpoint
    is loaded once, on the first call. If loading fails the instance is
    still returned, but :meth:`predict` then yields ``(None, 0.0, {})``.
    """

    _instance = None   # cached singleton instance
    _model = None      # loaded network, or None when loading failed
    _device = None     # torch.device used for inference
    _transform = None  # preprocessing pipeline applied to every image

    @classmethod
    def get_instance(cls, model_path=None):
        """Return the shared classifier, creating and loading it on first use.

        Args:
            model_path: Path to the checkpoint (a state dict for the model
                built in ``_load_model``). Defaults to ``vit_model.pth``
                located next to this module.

        Returns:
            The singleton ``ViTClassifier`` instance.

        Note:
            ``model_path`` is only honoured on the first call; later calls
            return the cached instance unchanged, even if the first load
            failed.
        """
        if model_path is None:
            model_path = os.path.join(os.path.dirname(__file__), "vit_model.pth")

        if cls._instance is None:
            cls._instance = cls()
            cls._instance._load_model(model_path)
        return cls._instance

    def _load_model(self, model_path):
        """Build the preprocessing pipeline and load weights from ``model_path``.

        On success ``self._model`` holds the network in eval mode on the CPU.
        On any failure (missing file, incompatible checkpoint, ...) an error
        is printed and ``self._model`` is left as ``None``.
        """
        self._transform = transforms.Compose([
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self._device = torch.device("cpu")
        print(f"Using device: {self._device}")
        print(f"Loading model from {model_path}...")

        # Only publish the model once it has been fully loaded.
        self._model = None
        try:
            # Check for the checkpoint before constructing the (large) network.
            if not os.path.exists(model_path):
                print(f"Error: Model file not found at {model_path}")
                return

            # weights=None: the strict load_state_dict() below replaces every
            # parameter, so fetching the pretrained ImageNet weights first
            # would only add a large download and break offline use.
            model = models.vit_b_16(weights=None)
            num_features = model.heads.head.in_features
            model.heads.head = nn.Linear(num_features, len(CLASSES))

            # NOTE: torch.load unpickles the file; only load checkpoints from
            # a trusted source.
            state_dict = torch.load(model_path, map_location=self._device)
            model.load_state_dict(state_dict)
            model.to(self._device)
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            return

        self._model = model
        print("Model loaded successfully.")

    def predict(self, image_path_or_file):
        """Classify a single image.

        Args:
            image_path_or_file: Anything ``PIL.Image.open`` accepts (a
                filesystem path or a binary file-like object).

        Returns:
            A tuple ``(predicted_class, confidence, probabilities)``:
                predicted_class (str): name of the highest-scoring class
                confidence (float): softmax probability of that class
                probabilities (dict): class name -> probability
            If no model is loaded, or the image cannot be processed,
            ``(None, 0.0, {})`` is returned instead.
        """
        if self._model is None:
            return None, 0.0, {}

        try:
            # The context manager releases the file handle PIL opened;
            # convert() returns an independent, fully loaded copy.
            with Image.open(image_path_or_file) as img:
                image = img.convert('RGB')

            # Add the batch dimension the network expects.
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            with torch.no_grad():
                outputs = self._model(image_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0].cpu()

            # Highest confidence prediction
            predicted_idx = int(probs.argmax())
            prob_values = probs.tolist()
            predicted_class = CLASSES[predicted_idx]
            confidence = float(prob_values[predicted_idx])

            # All class probabilities
            probability_dict = {
                name: float(p) for name, p in zip(CLASSES, prob_values)
            }

            return predicted_class, confidence, probability_dict

        except Exception as e:
            print(f"Error processing image: {e}")
            return None, 0.0, {}