# AIOmarRehan's picture
# Upload 21 files
# a4da623 verified
import os
from typing import Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import timm
class XceptionModel:
    """Wrapper around a fine-tuned timm Xception image classifier.

    Loads a checkpoint (either a full pickled model or a plain state dict),
    resolves the matching timm preprocessing transform, and exposes
    ``predict``, which returns the top class name, its confidence, and the
    full per-class probability distribution.
    """

    # Class names must match the label order used during training.
    CLASS_NAMES = ["Auto Rickshaws", "Bikes", "Cars", "Motorcycles", "Planes", "Ships", "Trains"]

    def __init__(self, model_dir: str, model_file: str = "best_model_finetuned_full.pt"):
        """Eagerly load the checkpoint so the first prediction is fast.

        Args:
            model_dir: Directory containing the checkpoint file.
            model_file: Checkpoint filename inside ``model_dir``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        self.model_dir = model_dir
        self.model_file = model_file
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.inference_transform = None
        self.class_names = self.CLASS_NAMES
        print(f"[Xception] Using device: {self.device}")
        print(f"[Xception] Classes: {self.class_names}")
        self._load_model()

    def _load_model(self):
        """Load the checkpoint, rebuild the architecture if needed, set eval mode."""
        try:
            model_path = os.path.join(self.model_dir, self.model_file)
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")
            # Disable TorchDynamo (avoids CatchErrorsWrapper issues with
            # checkpoints saved from torch.compile'd models).
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.reset()
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code — only load checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
            num_classes = len(self.CLASS_NAMES)
            # A plain dict can never define forward(), so the isinstance check
            # alone distinguishes a state dict from a pickled nn.Module.
            if isinstance(checkpoint, dict):
                # State dict: rebuild the architecture used during training.
                # NOTE(review): assumes training replaced the classifier with
                # this exact Linear-ReLU-Dropout-Linear head — confirm against
                # the training script if load_state_dict fails.
                model = timm.create_model("xception", pretrained=False, num_classes=num_classes)
                in_features = model.get_classifier().in_features
                model.fc = nn.Sequential(
                    nn.Linear(in_features, 512),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Linear(512, num_classes),
                )
                state_dict = checkpoint
                # Strip the torch.compile wrapper prefix; replace only the
                # leading occurrence so mid-key matches are never touched.
                if any(k.startswith("_orig_mod.") for k in state_dict):
                    state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}
                model.load_state_dict(state_dict)
            else:
                # Full pickled model; unwrap the torch.compile wrapper if present.
                model = checkpoint
                if hasattr(model, "_orig_mod"):
                    model = model._orig_mod
            # Move model to device and freeze into evaluation mode.
            self.model = model.to(self.device).eval()
            # Derive the preprocessing pipeline matching the model's data config.
            data_config = timm.data.resolve_model_data_config(self.model)
            self.inference_transform = timm.data.create_transform(**data_config, is_training=False)
            print(f"[Xception] Model loaded successfully from {model_path}")
        except Exception as e:
            print(f"[Xception] Error loading model: {e}")
            raise

    def _preprocess_image(self, img: "Image.Image") -> torch.Tensor:
        """Convert a PIL image to a batched, device-resident input tensor."""
        img = img.convert("RGB")
        return self.inference_transform(img).unsqueeze(0).to(self.device)

    def predict(self, image: "Image.Image") -> Tuple[str, float, Dict[str, float]]:
        """Classify one image.

        Args:
            image: PIL image (or an array convertible via ``Image.fromarray``).
                ``None`` is tolerated and yields a placeholder result.

        Returns:
            Tuple of (predicted class name, confidence in [0, 1],
            mapping of every class name to its probability).
        """
        if image is None:
            return "No image provided", 0.0, {}
        try:
            # Accept numpy arrays (e.g. from a Gradio input) as well as PIL.
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            inputs = self._preprocess_image(image)
            # Forward pass without building the autograd graph.
            with torch.no_grad():
                outputs = self.model(inputs)
                probs = F.softmax(outputs, dim=-1).cpu().numpy()[0]
            class_idx = int(np.argmax(probs))
            confidence = float(probs[class_idx])
            prob_dict = {name: float(p) for name, p in zip(self.class_names, probs)}
            return self.class_names[class_idx], confidence, prob_dict
        except Exception as e:
            print(f"[Xception] Error during prediction: {e}")
            raise