|
|
import os

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

from model import AstronomyClassifier, MODEL_CONFIG
|
|
|
|
|
|
class AstronomyInference:
    """Astronomy image classification inference with optional ensembling.

    On construction the ResNet50 and DenseNet121 checkpoints are loaded.
    Each is optional: a load failure is printed and that model is skipped.
    An albumentations preprocessing pipeline is also built from
    ``MODEL_CONFIG`` (resize to ``input_size``, normalise with
    ``mean``/``std``, convert to a tensor).

    Args:
        use_ensemble: When True and more than one model loaded, ``predict``
            averages the softmax probabilities of all loaded models.
        device: Torch device string that models and inputs are placed on.
    """

    # (architecture name for AstronomyClassifier, checkpoint path, display name),
    # loaded in this order; the first successfully loaded model is the one used
    # by the single-model prediction path.
    _CHECKPOINTS = (
        ("resnet50", "best_resnet50.pth", "ResNet50"),
        ("densenet121", "best_densenet121.pth", "DenseNet121"),
    )

    def __init__(self, use_ensemble=True, device="cpu"):
        self.device = torch.device(device)
        self.class_names = MODEL_CONFIG["class_names"]
        self.num_classes = MODEL_CONFIG["num_classes"]
        self.use_ensemble = use_ensemble

        # Model name -> loaded model in eval mode; filled by load_models().
        self.models = {}
        self.load_models()

        # Inference-time preprocessing: resize, normalise, ndarray -> tensor.
        self.transform = A.Compose([
            A.Resize(MODEL_CONFIG["input_size"][0], MODEL_CONFIG["input_size"][1]),
            A.Normalize(mean=MODEL_CONFIG["mean"], std=MODEL_CONFIG["std"]),
            ToTensorV2(),
        ])

    def load_models(self):
        """Load the ResNet50 and DenseNet121 checkpoints into ``self.models``.

        Each checkpoint is loaded independently. If one fails, the error is
        printed and the remaining model is still loaded, so inference can
        continue with whatever is available.
        """
        for model_name, checkpoint_path, display_name in self._CHECKPOINTS:
            try:
                model = self._load_model(model_name, checkpoint_path)
            except Exception as e:
                # Deliberately best-effort: one missing or incompatible
                # checkpoint must not prevent using the other model.
                print(f"❌ Failed to load {display_name}: {e}")
                continue
            self.models[model_name] = model
            print(f"✅ {display_name} model loaded successfully")

    def _load_model(self, model_name, checkpoint_path):
        """Build an AstronomyClassifier, load its weights, move it to the device, set eval mode."""
        model = AstronomyClassifier(
            model_name=model_name,
            num_classes=self.num_classes,
            pretrained=False,
        )
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _require_models(self):
        """Raise RuntimeError if no model could be loaded."""
        if not self.models:
            raise RuntimeError(
                "No models are loaded for AstronomyInference; "
                "see the load errors printed at construction."
            )

    def _probabilities_to_dict(self, probabilities):
        """Map each class name to its probability as a plain Python float."""
        return {
            name: float(probabilities[i])
            for i, name in enumerate(self.class_names)
        }

    def preprocess_image(self, image):
        """Convert an image into a preprocessed batch of one on ``self.device``.

        Args:
            image: A file path (``str`` or path-like), a NumPy array accepted
                by ``PIL.Image.fromarray``, or a PIL image. All of these are
                converted to RGB before the transform is applied.

        Returns:
            The output of ``self.transform`` with a leading batch dimension
            of size 1, moved to ``self.device``.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() returns a new, fully loaded image, so the file handle
            # can be closed right away.
            with Image.open(image) as img:
                image = img.convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        elif isinstance(image, Image.Image) and image.mode != "RGB":
            # Grayscale/RGBA/palette inputs would break the 3-channel Normalize.
            image = image.convert("RGB")

        image_np = np.array(image)
        image_tensor = self.transform(image=image_np)["image"].unsqueeze(0)
        return image_tensor.to(self.device)

    def predict_single_model(self, model, image_tensor):
        """Classify a preprocessed batch with a single model.

        Args:
            model: Callable returning logits with classes along dim 1.
            image_tensor: Preprocessed input batch. It must contain exactly
                one image, because the arg-max is read with ``.item()``.

        Returns:
            Tuple ``(predicted_class, confidence, probabilities)``:
            ``predicted_class`` is the class name, ``confidence`` is its
            softmax probability as a Python float, and ``probabilities`` is
            the NumPy softmax vector for the first batch item.
        """
        with torch.no_grad():
            logits = model(image_tensor)
            probabilities = F.softmax(logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        predicted_class = self.class_names[predicted.item()]
        return predicted_class, confidence.item(), probabilities[0].cpu().numpy()

    def predict_ensemble(self, image_tensor):
        """Classify with every loaded model and average their probabilities.

        Returns:
            Dict with ``predicted_class`` and ``confidence`` (from the averaged
            probabilities), ``probabilities`` (class name -> averaged
            probability) and ``individual_results`` (model name ->
            ``{"predicted_class", "confidence"}``).

        Raises:
            RuntimeError: If no models are loaded.
        """
        self._require_models()

        all_probabilities = []
        individual_results = {}
        for model_name, model in self.models.items():
            predicted_class, confidence, probs = self.predict_single_model(model, image_tensor)
            all_probabilities.append(probs)
            individual_results[model_name] = {
                "predicted_class": predicted_class,
                "confidence": confidence,
            }

        # Soft voting: the ensemble prediction is the arg-max of the mean
        # probability vector across models.
        avg_probabilities = np.mean(all_probabilities, axis=0)
        best = int(np.argmax(avg_probabilities))

        return {
            "predicted_class": self.class_names[best],
            "confidence": float(avg_probabilities[best]),
            "probabilities": self._probabilities_to_dict(avg_probabilities),
            "individual_results": individual_results,
        }

    def predict(self, image, return_probabilities=True):
        """Predict the class of an image.

        Uses the ensemble when ``use_ensemble`` is set and more than one model
        is loaded; otherwise uses the first loaded model.

        Args:
            image: Anything accepted by ``preprocess_image``.
            return_probabilities: When False, per-class probabilities (and,
                for the ensemble, per-model results) are omitted.

        Returns:
            Dict with ``predicted_class`` and ``confidence``. Ensemble
            results may add ``probabilities`` and ``individual_results``.
            Single-model results may add ``probabilities`` and always
            include ``model_used``.

        Raises:
            RuntimeError: If no models are loaded.
        """
        # Fail fast with a clear message instead of an IndexError further down.
        self._require_models()
        image_tensor = self.preprocess_image(image)

        if self.use_ensemble and len(self.models) > 1:
            result = self.predict_ensemble(image_tensor)
            if return_probabilities:
                return result
            return {
                "predicted_class": result["predicted_class"],
                "confidence": result["confidence"],
            }

        # Single-model path: the first model that loaded successfully.
        model_name, model = next(iter(self.models.items()))
        predicted_class, confidence, all_probs = self.predict_single_model(model, image_tensor)

        result = {"predicted_class": predicted_class, "confidence": confidence}
        if return_probabilities:
            result["probabilities"] = self._probabilities_to_dict(all_probs)
        result["model_used"] = model_name
        return result
|
|
|
|
|
|
|
|
|
# Lazily created, process-wide AstronomyInference instance shared by callers.
inference_model = None


def get_inference_model():
    """Return the shared AstronomyInference instance, creating it on first use.

    The instance is built with ``use_ensemble=True`` and cached in the
    module-level ``inference_model``, so later calls reuse the models that
    were already loaded.
    """
    global inference_model
    if inference_model is not None:
        return inference_model
    inference_model = AstronomyInference(use_ensemble=True)
    return inference_model