File size: 6,543 Bytes
7992750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import AstronomyClassifier, MODEL_CONFIG

class AstronomyInference:
    """Astronomy image classification inference with optional ensemble support.

    On construction, tries to load a ResNet50 and a DenseNet121 checkpoint
    (paths relative to the current working directory) and builds an
    albumentations preprocessing pipeline from ``MODEL_CONFIG``. When
    ``use_ensemble`` is true and more than one model loaded, predictions
    average the models' softmax outputs (soft voting); otherwise the first
    successfully loaded model is used on its own.
    """

    # (model_name passed to AstronomyClassifier, display name used in log
    #  messages, checkpoint path). Order determines the "first" model used
    # when not ensembling.
    _CHECKPOINTS = (
        ("resnet50", "ResNet50", "best_resnet50.pth"),
        ("densenet121", "DenseNet121", "best_densenet121.pth"),
    )

    def __init__(self, use_ensemble=True, device="cpu"):
        """Load the models and set up preprocessing.

        Args:
            use_ensemble: Average all loaded models' probabilities when more
                than one model is available.
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``);
                models and preprocessed input tensors are moved there.
        """
        self.device = torch.device(device)
        self.class_names = MODEL_CONFIG["class_names"]
        self.num_classes = MODEL_CONFIG["num_classes"]
        self.use_ensemble = use_ensemble

        # model_name -> model in eval mode. Loading is best-effort, so this
        # may hold zero, one or both models.
        self.models = {}
        self.load_models()

        # Resize to the configured input size (A.Resize takes height, width),
        # normalize with the configured mean/std, then convert to a tensor.
        self.transform = A.Compose([
            A.Resize(MODEL_CONFIG["input_size"][0], MODEL_CONFIG["input_size"][1]),
            A.Normalize(
                mean=MODEL_CONFIG["mean"],
                std=MODEL_CONFIG["std"]
            ),
            ToTensorV2()
        ])

    def load_models(self):
        """Load the ResNet50 and DenseNet121 checkpoints into ``self.models``.

        Each model is loaded independently. A failure is printed and the
        remaining models are still attempted, so a missing or corrupt
        checkpoint does not prevent the other model from being used.
        """
        for model_name, display_name, checkpoint_path in self._CHECKPOINTS:
            try:
                self.models[model_name] = self._load_model(model_name, checkpoint_path)
            except Exception as e:  # deliberate best-effort: report and continue
                print(f"❌ Failed to load {display_name}: {e}")
            else:
                print(f"✅ {display_name} model loaded successfully")

    def _load_model(self, model_name, checkpoint_path):
        """Build one classifier, load its weights, move it to the device and set eval mode."""
        model = AstronomyClassifier(
            model_name=model_name,
            num_classes=self.num_classes,
            pretrained=False
        )
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True where the torch version allows).
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def preprocess_image(self, image):
        """Turn an image into a normalized batch tensor on ``self.device``.

        Args:
            image: A file path (``str``), a numpy array, or a PIL image. Paths,
                arrays and PIL images are converted to 3-channel RGB.

        Returns:
            The transformed image with a leading batch dimension of size 1.
        """
        if isinstance(image, str):
            # The context manager closes the file; convert() produces an
            # independent, fully loaded image before the file is closed.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')
        elif isinstance(image, Image.Image):
            # RGBA / grayscale / palette images would otherwise reach the
            # 3-channel Normalize with the wrong number of channels.
            image = image.convert('RGB')

        image_np = np.array(image)
        transformed = self.transform(image=image_np)
        image_tensor = transformed['image'].unsqueeze(0)

        return image_tensor.to(self.device)

    def predict_single_model(self, model, image_tensor):
        """Classify ``image_tensor`` with one model.

        Expects a batch of size 1 (as produced by ``preprocess_image``);
        ``.item()`` fails for larger batches.

        Returns:
            Tuple ``(predicted_class, confidence, all_probs)`` where
            ``all_probs`` is the numpy softmax vector for the first sample.
        """
        with torch.no_grad():
            outputs = model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        predicted_class = self.class_names[predicted.item()]
        confidence_score = confidence.item()
        all_probs = probabilities[0].cpu().numpy()

        return predicted_class, confidence_score, all_probs

    def predict_ensemble(self, image_tensor):
        """Classify ``image_tensor`` by averaging every loaded model's softmax output.

        Returns:
            Dict with ``predicted_class``, ``confidence`` (max averaged
            probability), ``probabilities`` (class name -> averaged
            probability) and ``individual_results`` (model name -> that
            model's own class and confidence).

        Raises:
            RuntimeError: If no models are loaded.
        """
        self._require_models()

        all_probabilities = []
        individual_results = {}

        for model_name, model in self.models.items():
            predicted_class, confidence, probs = self.predict_single_model(model, image_tensor)
            all_probabilities.append(probs)
            individual_results[model_name] = {
                "predicted_class": predicted_class,
                "confidence": confidence
            }

        # Soft voting: average the probability vectors, then pick the argmax.
        avg_probabilities = np.mean(all_probabilities, axis=0)
        predicted_class = self.class_names[np.argmax(avg_probabilities)]
        confidence_score = float(np.max(avg_probabilities))

        return {
            "predicted_class": predicted_class,
            "confidence": confidence_score,
            "probabilities": self._probabilities_to_dict(avg_probabilities),
            "individual_results": individual_results
        }

    def predict(self, image, return_probabilities=True):
        """Classify an image (path, numpy array or PIL image).

        Uses the ensemble when ``use_ensemble`` is set and more than one model
        is loaded; otherwise the first loaded model, reported as ``model_used``.

        Args:
            image: Anything accepted by ``preprocess_image``.
            return_probabilities: Include the per-class probability dict (and,
                for the ensemble, the per-model results).

        Raises:
            RuntimeError: If no models are loaded.
        """
        self._require_models()
        image_tensor = self.preprocess_image(image)

        if self.use_ensemble and len(self.models) > 1:
            result = self.predict_ensemble(image_tensor)
            if return_probabilities:
                return result
            return {
                "predicted_class": result["predicted_class"],
                "confidence": result["confidence"]
            }

        # Single-model path: the first model in load order.
        model_name, model = next(iter(self.models.items()))
        predicted_class, confidence, all_probs = self.predict_single_model(model, image_tensor)

        result = {
            "predicted_class": predicted_class,
            "confidence": confidence,
        }
        if return_probabilities:
            result["probabilities"] = self._probabilities_to_dict(all_probs)
        result["model_used"] = model_name
        return result

    def _probabilities_to_dict(self, probs):
        """Map each class name to its probability as a plain Python float."""
        return {name: float(p) for name, p in zip(self.class_names, probs)}

    def _require_models(self):
        """Raise a clear error instead of failing obscurely when nothing loaded."""
        if not self.models:
            raise RuntimeError(
                "No models are loaded; inference is unavailable "
                "(see the load errors printed during initialization)."
            )

# Shared, lazily-created inference instance; populated by get_inference_model().
inference_model = None


def get_inference_model():
    """Return the shared AstronomyInference, constructing it on first call.

    The instance is built with ``use_ensemble=True`` and cached in the
    module-level ``inference_model`` so later calls reuse the loaded models.
    """
    global inference_model
    if inference_model is not None:
        return inference_model
    inference_model = AstronomyInference(use_ensemble=True)
    return inference_model