|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Every input image is resized to this many pixels (height x width) before
# it is fed to the network.
IMG_HEIGHT, IMG_WIDTH = 224, 224

# Class labels; position i is the label for output i of the classifier head.
# sorted() orders them by code point, so capitalised names come first and
# "main_contact_wear" ends up last.
# NOTE(review): the mixed naming styles (spaces vs underscores, letter case)
# look like they mirror the training dataset's folder names; keep them
# byte-identical and in this sorted order unless the model is retrained.
# TODO confirm against the training pipeline.
CLASSES = sorted((
    "Arcing_Contact_Misalignment",
    "Arcing_Contact_Wear",
    "Healthy",
    "Main Contact Misalignment",
    "main_contact_wear",
))
|
|
|
|
|
class ViTClassifier:
    """Process-wide ViT-B/16 image classifier over the labels in ``CLASSES``.

    Obtain the shared instance with :meth:`get_instance`. It builds the
    network, loads the fine-tuned checkpoint and keeps everything on the CPU.
    :meth:`predict` never raises: on any failure it prints a message and
    returns ``(None, 0.0, {})``.
    """

    # Shared instance handed out by get_instance().
    _instance = None
    # The torchvision ViT model in eval mode, or None if loading failed.
    _model = None
    # torch.device used for inference (always CPU, see _load_model).
    _device = None
    # Preprocessing pipeline: resize -> tensor -> per-channel normalization.
    _transform = None

    @classmethod
    def get_instance(cls, model_path=None):
        """Return the shared classifier, loading the model if necessary.

        Args:
            model_path: Path to the fine-tuned ``state_dict`` checkpoint.
                Defaults to ``vit_model.pth`` next to this source file.

        Returns:
            The singleton ``ViTClassifier``. Its model may still be ``None``
            if loading failed; ``predict`` then returns ``(None, 0.0, {})``.

        Note:
            Once a model is loaded, later ``model_path`` values are ignored.
            If the previous load failed (e.g. the checkpoint was missing),
            the load is retried on the next call. Previously the failure
            stayed cached for the lifetime of the process.
        """
        if model_path is None:
            model_path = os.path.join(os.path.dirname(__file__), "vit_model.pth")

        if cls._instance is None:
            cls._instance = cls()

        # Retry after a failed load rather than caching the failure forever.
        if cls._instance._model is None:
            cls._instance._load_model(model_path)

        return cls._instance

    def _load_model(self, model_path):
        """Build the ViT-B/16 network and load weights from ``model_path``.

        On success ``self._model`` holds the model in eval mode on the CPU.
        On any failure an error is printed and ``self._model`` is ``None``.
        """
        self._transform = transforms.Compose([
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.ToTensor(),
            # Standard ImageNet channel statistics; presumably the same values
            # used at training time. TODO confirm against the training script.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self._device = torch.device("cpu")
        print(f"Using device: {self._device}")
        print(f"Loading model from {model_path}...")

        # Check for the checkpoint before doing any expensive model construction.
        if not os.path.exists(model_path):
            print(f"Error: Model file not found at {model_path}")
            self._model = None
            return

        try:
            # weights=None: the checkpoint below is loaded with strict key
            # matching (load_state_dict's default), so every parameter is
            # overwritten anyway. Requesting pretrained weights here only
            # forced a network download (failing outright when offline) for
            # values that were immediately discarded.
            model = models.vit_b_16(weights=None)
            # Replace the default classification head with one output per
            # entry in CLASSES, matching the fine-tuned checkpoint.
            model.heads.head = nn.Linear(model.heads.head.in_features, len(CLASSES))
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code; only load trusted checkpoints. Newer torch
            # versions accept weights_only=True to restrict this.
            state_dict = torch.load(model_path, map_location=self._device)
            model.load_state_dict(state_dict)
            model.to(self._device)
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            self._model = None
            return

        # Publish the model only once it is fully loaded, so predict() never
        # sees a half-initialized network.
        self._model = model
        print("Model loaded successfully.")

    def predict(self, image_path_or_file):
        """Classify a single image.

        Args:
            image_path_or_file: Anything ``PIL.Image.open`` accepts, i.e. a
                path or a binary file-like object. Closing a caller-supplied
                file object remains the caller's responsibility.

        Returns:
            predicted_class (str)
            confidence (float): softmax probability of ``predicted_class``.
            probabilities (dict): class name -> probability, for every class.
            Returns ``(None, 0.0, {})`` if no model is loaded or the image
            cannot be processed.
        """
        if self._model is None:
            return None, 0.0, {}

        try:
            # The context manager releases the file handle PIL opened itself
            # (from a path). convert() reads the pixel data before the close.
            with Image.open(image_path_or_file) as img:
                image = img.convert('RGB')
            image_tensor = self._transform(image).unsqueeze(0).to(self._device)

            with torch.no_grad():
                outputs = self._model(image_tensor)

            # Batch of one: keep row 0 of the softmax over the class dimension.
            probs = torch.softmax(outputs, dim=1)[0].cpu()

            # argmax returns the first maximal index on ties.
            predicted_idx = int(probs.argmax())
            predicted_class = CLASSES[predicted_idx]
            confidence = float(probs[predicted_idx])

            probability_dict = {name: float(p) for name, p in zip(CLASSES, probs.tolist())}

            return predicted_class, confidence, probability_dict
        except Exception as e:
            print(f"Error processing image: {e}")
            return None, 0.0, {}