# Saqib772's picture
# image classification
# 7992750 verified
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import AstronomyClassifier, MODEL_CONFIG
class AstronomyInference:
    """Astronomy image classification inference with ensemble support.

    On construction the instance tries to load every checkpoint listed in
    ``_CHECKPOINTS``. A checkpoint that fails to load is reported on stdout
    and skipped, so the instance may end up holding two, one, or zero models.
    When ``use_ensemble`` is true and more than one model loaded, predictions
    average the per-model softmax probabilities (soft voting); otherwise the
    first loaded model is used on its own.
    """

    # (architecture name / key in ``self.models``, name used in log lines,
    #  checkpoint path relative to the current working directory)
    _CHECKPOINTS = (
        ("resnet50", "ResNet50", "best_resnet50.pth"),
        ("densenet121", "DenseNet121", "best_densenet121.pth"),
    )

    def __init__(self, use_ensemble=True, device="cpu"):
        """Load the models and build the preprocessing pipeline.

        Args:
            use_ensemble: Average all loaded models when more than one is
                available; otherwise only the first loaded model is used.
            device: Anything accepted by ``torch.device``.
        """
        self.device = torch.device(device)
        self.class_names = MODEL_CONFIG["class_names"]
        self.num_classes = MODEL_CONFIG["num_classes"]
        self.use_ensemble = use_ensemble

        # Maps architecture name -> loaded model, in load order.
        self.models = {}
        self.load_models()

        # Resize -> normalize -> convert to a tensor.
        self.transform = A.Compose([
            A.Resize(MODEL_CONFIG["input_size"][0], MODEL_CONFIG["input_size"][1]),
            A.Normalize(
                mean=MODEL_CONFIG["mean"],
                std=MODEL_CONFIG["std"],
            ),
            ToTensorV2(),
        ])

    def load_models(self):
        """Try to load both the ResNet50 and DenseNet121 checkpoints.

        Loading is best-effort: each failure is printed and skipped so that
        one missing checkpoint does not prevent the other from being used.
        """
        for model_name, display_name, weights_path in self._CHECKPOINTS:
            self._load_model(model_name, display_name, weights_path)

    def _load_model(self, model_name, display_name, weights_path):
        """Load one checkpoint into ``self.models``; report failure on stdout."""
        try:
            model = AstronomyClassifier(
                model_name=model_name,
                num_classes=self.num_classes,
                pretrained=False,
            )
            # NOTE(security): torch.load unpickles the file, so it must only
            # be pointed at checkpoints from a trusted source.
            state_dict = torch.load(weights_path, map_location=self.device)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Deliberately broad: any failure just disables this one model.
            print(f"❌ Failed to load {display_name}: {e}")
            return
        self.models[model_name] = model
        print(f"✅ {display_name} model loaded successfully")

    def _require_models(self):
        """Raise a descriptive error when no checkpoint could be loaded."""
        if not self.models:
            raise RuntimeError(
                "No models are loaded; check the messages printed by "
                "load_models() for the reason each checkpoint failed."
            )

    def _probabilities_by_class(self, probabilities):
        """Pair each class name with its probability as a plain float."""
        return {
            class_name: float(probability)
            for class_name, probability in zip(self.class_names, probabilities)
        }

    def preprocess_image(self, image):
        """Turn an image into a batched tensor on ``self.device``.

        Args:
            image: A file path (``str``), a ``numpy.ndarray``, or a PIL image.
                All three are converted to RGB before the transforms run, so
                RGBA / grayscale / palette inputs yield three channels.
                Any other object is handed to ``np.array`` unchanged.

        Returns:
            The transformed image with a leading batch dimension of 1.
        """
        if isinstance(image, str):
            # The context manager closes the file; convert() returns a new,
            # fully loaded image that outlives it.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")

        image_np = np.array(image)
        transformed = self.transform(image=image_np)
        image_tensor = transformed["image"].unsqueeze(0)
        return image_tensor.to(self.device)

    def predict_single_model(self, model, image_tensor):
        """Run one model on a single-image batch.

        Returns:
            ``(predicted_class, confidence, all_probs)`` where ``all_probs``
            is a NumPy array of the softmax probabilities of the first (only)
            batch entry.
        """
        with torch.no_grad():
            outputs = model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        # .item() requires exactly one element, i.e. a batch of one image.
        predicted_class = self.class_names[predicted.item()]
        confidence_score = confidence.item()
        all_probs = probabilities[0].cpu().numpy()
        return predicted_class, confidence_score, all_probs

    def predict_ensemble(self, image_tensor):
        """Predict by averaging the probabilities of every loaded model.

        Returns:
            A dict with ``predicted_class``, ``confidence``, ``probabilities``
            (class name -> averaged probability) and ``individual_results``
            (model name -> that model's own class and confidence).

        Raises:
            RuntimeError: If no model is loaded.
        """
        self._require_models()

        all_probabilities = []
        individual_results = {}
        for model_name, model in self.models.items():
            predicted_class, confidence, probs = self.predict_single_model(
                model, image_tensor
            )
            all_probabilities.append(probs)
            individual_results[model_name] = {
                "predicted_class": predicted_class,
                "confidence": confidence,
            }

        # Average probabilities (soft voting).
        avg_probabilities = np.mean(all_probabilities, axis=0)
        best_index = int(np.argmax(avg_probabilities))

        return {
            "predicted_class": self.class_names[best_index],
            "confidence": float(avg_probabilities[best_index]),
            "probabilities": self._probabilities_by_class(avg_probabilities),
            "individual_results": individual_results,
        }

    def predict(self, image, return_probabilities=True):
        """Predict the class of an image.

        Args:
            image: Anything accepted by ``preprocess_image``.
            return_probabilities: Include the per-class probabilities (and,
                in ensemble mode, the per-model results) in the output.

        Returns:
            A dict that always has ``predicted_class`` and ``confidence``.
            Single-model results also carry ``model_used``; ensemble results
            carry ``individual_results`` when ``return_probabilities`` is true.

        Raises:
            RuntimeError: If no model is loaded (previously this surfaced as
                a bare ``IndexError``).
        """
        # Fail early and clearly, before any preprocessing work is done.
        self._require_models()

        image_tensor = self.preprocess_image(image)

        if self.use_ensemble and len(self.models) > 1:
            result = self.predict_ensemble(image_tensor)
            if return_probabilities:
                return result
            return {
                "predicted_class": result["predicted_class"],
                "confidence": result["confidence"],
            }

        # Single-model path: use the first model that loaded.
        model_name = next(iter(self.models))
        predicted_class, confidence, all_probs = self.predict_single_model(
            self.models[model_name], image_tensor
        )

        if return_probabilities:
            return {
                "predicted_class": predicted_class,
                "confidence": confidence,
                "probabilities": self._probabilities_by_class(all_probs),
                "model_used": model_name,
            }
        return {
            "predicted_class": predicted_class,
            "confidence": confidence,
            "model_used": model_name,
        }
# Module-level slot for the shared AstronomyInference instance. It stays None
# until get_inference_model() creates the instance on first use.
inference_model = None
def get_inference_model():
    """Return the shared AstronomyInference instance, building it on first call.

    The instance is stored in the module-level ``inference_model`` variable and
    is created with ``use_ensemble=True``; later calls return the same object.
    """
    global inference_model
    if inference_model is not None:
        return inference_model
    inference_model = AstronomyInference(use_ensemble=True)
    return inference_model