File size: 7,933 Bytes
bf07f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import sys
import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
import numpy as np
from PIL import Image
from typing import List, Dict, Any
import timm 

# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from src.utils import get_device, get_model, get_transforms

# ----------------------------------------------------------------------
# --- Global Variables ---
# ----------------------------------------------------------------------

# Compute device shared by every model in the ensemble; selected by
# src.utils.get_device() (imported above), so CPU/GPU choice lives in one place.
DEVICE = get_device()
# Square edge length (pixels) that the 'val' transform pipeline resizes
# images to before inference.
IMG_SIZE = 224

# ----------------------------------------------------------------------
# --- Model Ensemble Agent Core (with all fixes) ---
# ----------------------------------------------------------------------

class ModelEnsembleAgent:
    """Confidence-weighted soft-voting ensemble over several image classifiers.

    Each named architecture is rebuilt via ``get_model`` and populated from a
    ``best_<name>.pth`` checkpoint found in ``checkpoints_dir``.  Layers whose
    shapes do not match the freshly built model (typically classifier heads
    trained with a different architecture) are filtered out before loading.
    """

    def __init__(self, model_names: List[str], checkpoints_dir: str, num_classes: int, class_names: List[str]):
        """
        Args:
            model_names: Architecture identifiers understood by ``get_model``.
            checkpoints_dir: Directory containing ``best_<name>.pth`` files.
            num_classes: Output dimension each model is built with.
            class_names: Human-readable labels indexed by class id; assumed to
                have ``num_classes`` entries (not validated here).

        Raises:
            RuntimeError: If no checkpoint could be loaded at all.
        """
        self.models: Dict[str, nn.Module] = {}
        self.model_names = model_names
        self.num_classes = num_classes
        self.class_names = class_names
        self.transforms = get_transforms('val', IMG_SIZE)

        self.device = DEVICE
        self._load_all_models(checkpoints_dir)

    def _load_all_models(self, checkpoints_dir: str):
        """Loads all specified model checkpoints with strict=False fallback.

        Models whose checkpoint is missing, fails to load, or contributes no
        compatible tensors are skipped (with a console message) rather than
        aborting the whole ensemble.
        """
        print(f"Loading {len(self.model_names)} models from {checkpoints_dir} on {self.device}...")

        for name in self.model_names:

            # FIX: Corrected file naming convention (best_modelname.pth)
            checkpoint_path = os.path.join(checkpoints_dir, f"best_{name}.pth")

            print(f"  Attempting to load {name} from expected path: {checkpoint_path}...")

            try:
                model = get_model(name, self.num_classes, pretrained=False).to(self.device)

                # SECURITY NOTE: torch.load unpickles arbitrary objects —
                # only point this at trusted checkpoint files.
                checkpoint = torch.load(checkpoint_path, map_location=self.device)
                # Checkpoints may be either a raw state_dict or a training
                # dict wrapping one under 'model_state_dict'.
                state_dict = checkpoint.get('model_state_dict', checkpoint)

                # FIX: Filter out incompatible head layers that have size mismatches.
                # This handles cases where the checkpoint was trained with a
                # different head architecture.
                model_state = model.state_dict()
                filtered_state_dict = {}
                for key, value in state_dict.items():
                    target = model_state.get(key)
                    if target is None:
                        # Key doesn't exist in the current model, skip it.
                        continue
                    if target.shape != value.shape:
                        # Shape mismatch, skip this layer (usually head layers).
                        print(f"    (Skipping layer '{key}' due to shape mismatch: {value.shape} vs {target.shape})")
                        continue
                    filtered_state_dict[key] = value

                # FIX: A checkpoint contributing zero compatible tensors would
                # leave the model fully randomly initialized — exclude it from
                # the ensemble instead of silently polluting predictions.
                if not filtered_state_dict:
                    print(f"  ❌ No compatible layers found in checkpoint for {name}. Skipping.")
                    continue

                # Load only compatible layers; strict=False tolerates the
                # (intentionally) missing filtered keys.
                model.load_state_dict(filtered_state_dict, strict=False)

                model.eval()
                self.models[name] = model
                print(f"  ✅ Successfully loaded {name}.")

            except FileNotFoundError:
                print(f"  ❌ Checkpoint not found at: {checkpoint_path}. Skipping.")
            except Exception as e:
                # FIX: Detailed error reporting to show the full RuntimeError message
                print(f"  ❌ Failed to load {name}. Error: {e.__class__.__name__}. Details: {e}. Skipping.")

        if not self.models:
            raise RuntimeError("No models were successfully loaded. Cannot run ensemble.")

    @torch.no_grad()
    def run_ensemble(self, image_path: str) -> Dict[str, Any]:
        """Runs inference across all loaded models and computes the ensemble prediction.

        Args:
            image_path: Path to the image to classify.

        Returns:
            A dict with the ensemble class/confidence, per-model predictions,
            and a ``fracture_detected`` flag (any class other than "Healthy"),
            or ``{"error": ...}`` if the image could not be processed.
        """
        try:
            image = Image.open(image_path).convert('RGB')
            input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        except Exception as e:
            return {"error": f"Failed to load or process image: {e}"}

        all_probs = []
        individual_predictions = {}

        for name, model in self.models.items():
            outputs = model(input_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]

            all_probs.append(probs)

            pred_idx = int(np.argmax(probs))
            individual_predictions[name] = {
                "class": self.class_names[pred_idx],
                "confidence": float(probs[pred_idx])
            }

        # Ensemble Decision (Weighted Voting): each model's weight is its own
        # top-1 confidence, normalized across models.
        weights = np.array([p.max() for p in all_probs])
        weights = weights / np.sum(weights)

        # Weighted average of the per-model probability vectors.
        weighted_avg_probs = np.average(all_probs, axis=0, weights=weights)
        ensemble_idx = int(np.argmax(weighted_avg_probs))
        ensemble_confidence = weighted_avg_probs[ensemble_idx]
        ensemble_class = self.class_names[ensemble_idx]

        return {
            "image_path": image_path,
            "ensemble_prediction": ensemble_class,
            "ensemble_confidence": float(ensemble_confidence),
            "individual_predictions": individual_predictions,
            # NOTE(review): assumes "Healthy" is the negative class — confirm
            # against the training label set.
            "fracture_detected": ensemble_class != "Healthy"
        }

# ----------------------------------------------------------------------
# --- Execution Block ---
# ----------------------------------------------------------------------

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Multi-Model Ensemble (Cross-Validation) Agent.')
    parser.add_argument('--image-path', required=True, help='Path to the image for inference.')
    parser.add_argument('--checkpoints-dir', required=True, # Made required since default path was confusing
                        help='Absolute path to the directory containing the model checkpoints (e.g., best_swin.pth).')
    parser.add_argument('--models', type=str, default='swin,mobilenetv2,efficientnetv2,maxvit,densenet169',
                        help='Comma-separated names of the models to load.')
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--class-names', required=True,
                        help='Comma-separated list of class names.')

    args = parser.parse_args()

    models_list = [m.strip() for m in args.models.split(',')]
    class_names_list = [c.strip() for c in args.class_names.split(',')]

    try:
        ensemble_agent = ModelEnsembleAgent(
            model_names=models_list,
            checkpoints_dir=args.checkpoints_dir,
            num_classes=args.num_classes,
            class_names=class_names_list
        )
    except RuntimeError as e:
        print(f"\nFATAL ERROR during initialization: {e}")
        # FIX: use sys.exit() — the bare exit() builtin is a site-module
        # convenience for interactive sessions and may be absent (e.g. under
        # python -S or frozen builds).
        sys.exit(1)

    result = ensemble_agent.run_ensemble(args.image_path)

    print("\n--- ENSEMBLE AGENT RESULT ---")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Image: {os.path.basename(result['image_path'])}")
        print(f"FINAL ENSEMBLE PREDICTION: **{result['ensemble_prediction']}** (Confidence: {result['ensemble_confidence']:.4f})")

        print("\nIndividual Model Predictions:")
        # Snapshot as a set: explicit membership container, independent of the
        # live dict view.
        loaded_model_names = set(ensemble_agent.models)

        # Report every requested model, including the ones that failed to load.
        for name in models_list:
            if name in loaded_model_names:
                 pred = result['individual_predictions'][name]
                 print(f"  {name.upper():<15}: {pred['class']:<20} (Conf: {pred['confidence']:.4f})")
            else:
                 print(f"  {name.upper():<15}: (Skipped/Failed to Load)")

    print("-----------------------------\n")