Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import argparse | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| import numpy as np | |
| from PIL import Image | |
| from typing import List, Dict, Any | |
| import timm | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) | |
| from src.utils import get_device, get_model, get_transforms | |
| # ---------------------------------------------------------------------- | |
| # --- Global Variables --- | |
| # ---------------------------------------------------------------------- | |
# Device used by ModelEnsembleAgent for `.to(...)` and as `map_location` when
# loading checkpoints. Selection is delegated to src.utils.get_device
# (presumably CUDA when available, otherwise CPU — confirm in src.utils).
DEVICE = get_device()
# Image size passed to get_transforms('val', IMG_SIZE) when the ensemble
# builds its preprocessing pipeline.
IMG_SIZE = 224
| # ---------------------------------------------------------------------- | |
| # --- Model Ensemble Agent Core (with all fixes) --- | |
| # ---------------------------------------------------------------------- | |
class ModelEnsembleAgent:
    """Confidence-weighted ensemble of image classifiers restored from checkpoints.

    One checkpoint named ``best_<model_name>.pth`` is loaded per requested
    model. Models whose checkpoint is missing or unusable are skipped; the
    constructor raises ``RuntimeError`` only if *no* model could be loaded.
    """

    def __init__(self, model_names: List[str], checkpoints_dir: str, num_classes: int, class_names: List[str]):
        """Store the configuration and eagerly load every requested model.

        Args:
            model_names: Names passed to ``get_model`` and used to build the
                checkpoint file name.
            checkpoints_dir: Directory containing the ``best_<name>.pth`` files.
            num_classes: Number of output classes passed to ``get_model``.
            class_names: Labels indexed by the predicted class index.

        Raises:
            RuntimeError: If none of the models could be loaded.
        """
        # name -> model in eval mode; only successfully loaded models appear here.
        self.models = {}
        self.model_names = model_names
        self.num_classes = num_classes
        self.class_names = class_names
        self.transforms = get_transforms('val', IMG_SIZE)
        self.device = DEVICE
        self._load_all_models(checkpoints_dir)

    def _load_all_models(self, checkpoints_dir: str):
        """Load each requested checkpoint, keeping only shape-compatible tensors.

        Tensors whose key is unknown to the model are ignored silently; tensors
        whose shape differs from the model's are reported and skipped. A
        checkpoint that contributes no tensor at all is treated as a failure,
        because the resulting model would still hold only its initial weights.

        Raises:
            RuntimeError: If no model ends up loaded.
        """
        print(f"Loading {len(self.model_names)} models from {checkpoints_dir} on {self.device}...")
        for name in self.model_names:
            # Checkpoint naming convention: best_<modelname>.pth
            checkpoint_path = os.path.join(checkpoints_dir, f"best_{name}.pth")
            print(f" Attempting to load {name} from expected path: {checkpoint_path}...")
            try:
                # Read the checkpoint first so a missing file is detected
                # before the model is constructed.
                # NOTE(security): torch.load unpickles the file; only point
                # checkpoints_dir at trusted checkpoints.
                checkpoint = torch.load(checkpoint_path, map_location=self.device)
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                model = get_model(name, self.num_classes, pretrained=False).to(self.device)

                model_state = model.state_dict()
                compatible_state = {}
                for key, value in state_dict.items():
                    expected = model_state.get(key)
                    if expected is None:
                        # Key doesn't exist in the current model, skip it.
                        continue
                    if expected.shape == value.shape:
                        compatible_state[key] = value
                    else:
                        # Shape mismatch, skip this layer (usually head layers).
                        print(f" (Skipping layer '{key}' due to shape mismatch: {value.shape} vs {expected.shape})")

                if not compatible_state:
                    # Nothing matched (e.g. mismatched key prefixes): the model
                    # would keep its initial weights and pollute the ensemble.
                    raise RuntimeError("checkpoint contains no parameters compatible with the model")

                # Load only compatible layers.
                model.load_state_dict(compatible_state, strict=False)
                model.eval()
                self.models[name] = model
                print(f" ✅ Successfully loaded {name}.")
            except FileNotFoundError:
                print(f" ❌ Checkpoint not found at: {checkpoint_path}. Skipping.")
            except Exception as e:
                # Best-effort loading: report the full error and move on.
                print(f" ❌ Failed to load {name}. Error: {e.__class__.__name__}. Details: {e}. Skipping.")
        if not self.models:
            raise RuntimeError("No models were successfully loaded. Cannot run ensemble.")

    def run_ensemble(self, image_path: str) -> Dict[str, Any]:
        """Run every loaded model on one image and combine their predictions.

        Each model's softmax vector is weighted by that model's own top
        probability, and the weighted average decides the ensemble class.

        Args:
            image_path: Path of the image to classify.

        Returns:
            ``{"error": <message>}`` if the image cannot be opened or
            transformed; otherwise a dict with ``image_path``,
            ``ensemble_prediction``, ``ensemble_confidence``,
            ``individual_predictions`` (per-model class and confidence) and
            ``fracture_detected`` (True unless the ensemble class is
            ``"Healthy"``).

        Raises:
            IndexError: If a predicted index has no entry in ``class_names``.
        """
        try:
            # Context manager releases the file handle once the pixels are converted.
            with Image.open(image_path) as handle:
                image = handle.convert('RGB')
            input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        except Exception as e:
            return {"error": f"Failed to load or process image: {e}"}

        all_probs = []
        individual_predictions = {}
        # Inference only: without no_grad the outputs require grad and
        # Tensor.numpy() raises RuntimeError.
        with torch.no_grad():
            for name, model in self.models.items():
                outputs = model(input_tensor)
                probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
                all_probs.append(probs)
                pred_idx = int(np.argmax(probs))
                individual_predictions[name] = {
                    "class": self.class_names[pred_idx],
                    "confidence": float(probs[pred_idx])
                }

        # Ensemble decision (weighted voting): each model's max confidence is its weight.
        weights = np.array([np.max(probs) for probs in all_probs])
        weights = weights / np.sum(weights)
        weighted_avg_probs = np.average(all_probs, axis=0, weights=weights)

        ensemble_idx = int(np.argmax(weighted_avg_probs))
        ensemble_class = self.class_names[ensemble_idx]
        return {
            "image_path": image_path,
            "ensemble_prediction": ensemble_class,
            "ensemble_confidence": float(weighted_avg_probs[ensemble_idx]),
            "individual_predictions": individual_predictions,
            "fracture_detected": ensemble_class != "Healthy"
        }
| # ---------------------------------------------------------------------- | |
| # --- Execution Block --- | |
| # ---------------------------------------------------------------------- | |
def main(argv=None):
    """Command-line entry point: load the ensemble and classify one image.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        Process exit code: 0 on success, 1 if initialization or inference failed.
        Invalid arguments terminate through ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser(description='Multi-Model Ensemble (Cross-Validation) Agent.')
    parser.add_argument('--image-path', required=True, help='Path to the image for inference.')
    parser.add_argument('--checkpoints-dir', required=True,  # Required: the old default path was confusing
                        help='Absolute path to the directory containing the model checkpoints (e.g., best_swin.pth).')
    parser.add_argument('--models', type=str, default='swin,mobilenetv2,efficientnetv2,maxvit,densenet169',
                        help='Comma-separated names of the models to load.')
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--class-names', required=True,
                        help='Comma-separated list of class names.')
    args = parser.parse_args(argv)

    # Drop empty entries so a stray comma does not create a nameless model or class.
    models_list = [m.strip() for m in args.models.split(',') if m.strip()]
    class_names_list = [c.strip() for c in args.class_names.split(',') if c.strip()]

    # Predicted indices are looked up in class_names, so the counts must agree;
    # otherwise inference ends in an IndexError or a mislabelled class.
    if len(class_names_list) != args.num_classes:
        parser.error(
            f"--class-names lists {len(class_names_list)} names but --num-classes is {args.num_classes}."
        )

    try:
        ensemble_agent = ModelEnsembleAgent(
            model_names=models_list,
            checkpoints_dir=args.checkpoints_dir,
            num_classes=args.num_classes,
            class_names=class_names_list
        )
    except RuntimeError as e:
        print(f"\nFATAL ERROR during initialization: {e}")
        return 1

    result = ensemble_agent.run_ensemble(args.image_path)
    failed = "error" in result

    print("\n--- ENSEMBLE AGENT RESULT ---")
    if failed:
        print(f"Error: {result['error']}")
    else:
        print(f"Image: {os.path.basename(result['image_path'])}")
        print(f"FINAL ENSEMBLE PREDICTION: **{result['ensemble_prediction']}** (Confidence: {result['ensemble_confidence']:.4f})")
        print("\nIndividual Model Predictions:")
        for name in models_list:
            if name in ensemble_agent.models:
                pred = result['individual_predictions'][name]
                print(f" {name.upper():<15}: {pred['class']:<20} (Conf: {pred['confidence']:.4f})")
            else:
                print(f" {name.upper():<15}: (Skipped/Failed to Load)")
    print("-----------------------------\n")
    # Report failure to the calling shell instead of always exiting 0.
    return 1 if failed else 0


if __name__ == '__main__':
    # sys.exit rather than the site-provided exit(), which is not guaranteed to exist.
    sys.exit(main())