import os
import sys
import argparse

import torch
import torch.nn as nn
import torchvision.transforms as T
import numpy as np
from PIL import Image
from typing import List, Dict, Any
import timm

# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms

# ----------------------------------------------------------------------
# --- Global Variables ---
# ----------------------------------------------------------------------
DEVICE = get_device()
IMG_SIZE = 224


# ----------------------------------------------------------------------
# --- Model Ensemble Agent Core (with all fixes) ---
# ----------------------------------------------------------------------
class ModelEnsembleAgent:
    """Loads several trained image classifiers and combines their outputs.

    The ensemble decision is a confidence-weighted soft vote: each model's
    softmax distribution is averaged, weighted by that model's own peak
    confidence on the input image.
    """

    def __init__(self, model_names: List[str], checkpoints_dir: str,
                 num_classes: int, class_names: List[str]):
        """
        Args:
            model_names: Architecture names understood by ``get_model``.
            checkpoints_dir: Directory containing ``best_<name>.pth`` files.
            num_classes: Number of output classes for every model head.
            class_names: Human-readable class labels, indexed by class id.
        """
        self.models: Dict[str, nn.Module] = {}
        self.model_names = model_names
        self.num_classes = num_classes
        self.class_names = class_names
        # Validation-time transforms only — no augmentation at inference.
        self.transforms = get_transforms('val', IMG_SIZE)
        self.device = DEVICE
        self._load_all_models(checkpoints_dir)

    def _load_all_models(self, checkpoints_dir: str):
        """Loads all specified model checkpoints with strict=False fallback.

        Missing or incompatible checkpoints are reported and skipped; raises
        ``RuntimeError`` only if *no* model loads successfully.
        """
        print(f"Loading {len(self.model_names)} models from {checkpoints_dir} on {self.device}...")
        for name in self.model_names:
            # FIX: Corrected file naming convention (best_modelname.pth)
            checkpoint_path = os.path.join(checkpoints_dir, f"best_{name}.pth")
            print(f"  Attempting to load {name} from expected path: {checkpoint_path}...")
            try:
                model = get_model(name, self.num_classes, pretrained=False).to(self.device)
                checkpoint = torch.load(checkpoint_path, map_location=self.device)
                # Checkpoints may store the weights directly or under 'model_state_dict'.
                state_dict = checkpoint.get('model_state_dict', checkpoint)

                # FIX: Filter out incompatible head layers that have size mismatches.
                # This handles cases where the checkpoint was trained with a
                # different head architecture than the freshly built model.
                model_state = model.state_dict()
                filtered_state_dict = {}
                for key, value in state_dict.items():
                    if key in model_state:
                        if model_state[key].shape == value.shape:
                            filtered_state_dict[key] = value
                        else:
                            # Shape mismatch, skip this layer (usually head layers)
                            print(f"  (Skipping layer '{key}' due to shape mismatch: "
                                  f"{value.shape} vs {model_state[key].shape})")
                    # Keys absent from the current model are silently dropped.

                # Load only compatible layers; strict=False tolerates the gaps.
                model.load_state_dict(filtered_state_dict, strict=False)
                model.eval()
                self.models[name] = model
                print(f"  ✅ Successfully loaded {name}.")
            except FileNotFoundError:
                print(f"  ❌ Checkpoint not found at: {checkpoint_path}. Skipping.")
            except Exception as e:
                # FIX: Detailed error reporting to show the full RuntimeError message
                print(f"  ❌ Failed to load {name}. Error: {e.__class__.__name__}. "
                      f"Details: {e}. Skipping.")

        if not self.models:
            raise RuntimeError("No models were successfully loaded. Cannot run ensemble.")

    @torch.no_grad()
    def run_ensemble(self, image_path: str) -> Dict[str, Any]:
        """Runs inference across all loaded models and computes the ensemble prediction.

        Returns a result dict with the ensemble class/confidence and each
        model's individual prediction, or ``{"error": ...}`` if the image
        cannot be read.
        """
        try:
            image = Image.open(image_path).convert('RGB')
            input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        except Exception as e:
            return {"error": f"Failed to load or process image: {e}"}

        all_probs = []
        individual_predictions = {}
        for name, model in self.models.items():
            outputs = model(input_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
            all_probs.append(probs)
            pred_idx = int(np.argmax(probs))
            individual_predictions[name] = {
                "class": self.class_names[pred_idx],
                "confidence": float(probs[pred_idx]),
            }

        # Ensemble Decision (Weighted Voting):
        # use each model's peak confidence as its vote weight.
        weights = np.array([np.max(probs) for probs in all_probs])
        weights = weights / np.sum(weights)  # normalize to sum to 1

        # Weighted average of the per-model probability distributions.
        weighted_avg_probs = np.average(all_probs, axis=0, weights=weights)
        ensemble_idx = int(np.argmax(weighted_avg_probs))
        ensemble_confidence = weighted_avg_probs[ensemble_idx]
        ensemble_class = self.class_names[ensemble_idx]

        return {
            "image_path": image_path,
            "ensemble_prediction": ensemble_class,
            "ensemble_confidence": float(ensemble_confidence),
            "individual_predictions": individual_predictions,
            # NOTE(review): assumes "Healthy" is the negative class label —
            # confirm it appears in --class-names for fracture datasets.
            "fracture_detected": ensemble_class != "Healthy",
        }


# ----------------------------------------------------------------------
# --- Execution Block ---
# ----------------------------------------------------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Multi-Model Ensemble (Cross-Validation) Agent.')
    parser.add_argument('--image-path', required=True,
                        help='Path to the image for inference.')
    parser.add_argument('--checkpoints-dir', required=True,  # Made required since default path was confusing
                        help='Absolute path to the directory containing the model checkpoints (e.g., best_swin.pth).')
    parser.add_argument('--models', type=str,
                        default='swin,mobilenetv2,efficientnetv2,maxvit,densenet169',
                        help='Comma-separated names of the models to load.')
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--class-names', required=True,
                        help='Comma-separated list of class names.')
    args = parser.parse_args()

    models_list = [m.strip() for m in args.models.split(',')]
    class_names_list = [c.strip() for c in args.class_names.split(',')]

    try:
        ensemble_agent = ModelEnsembleAgent(
            model_names=models_list,
            checkpoints_dir=args.checkpoints_dir,
            num_classes=args.num_classes,
            class_names=class_names_list,
        )
    except RuntimeError as e:
        print(f"\nFATAL ERROR during initialization: {e}")
        sys.exit(1)  # FIX: sys.exit is guaranteed; builtin exit() is a site-module convenience

    result = ensemble_agent.run_ensemble(args.image_path)

    print("\n--- ENSEMBLE AGENT RESULT ---")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Image: {os.path.basename(result['image_path'])}")
        print(f"FINAL ENSEMBLE PREDICTION: **{result['ensemble_prediction']}** (Confidence: {result['ensemble_confidence']:.4f})")
        print("\nIndividual Model Predictions:")
        loaded_model_names = ensemble_agent.models.keys()
        for name in models_list:
            if name in loaded_model_names:
                pred = result['individual_predictions'][name]
                print(f"  {name.upper():<15}: {pred['class']:<20} (Conf: {pred['confidence']:.4f})")
            else:
                print(f"  {name.upper():<15}: (Skipped/Failed to Load)")
        print("-----------------------------\n")