# MedAI-ACM — src/agents/cross_validation_agent.py
# Uploaded by Tirath5504 ("deploy", commit bf07f10)
import os
import sys
import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
import numpy as np
from PIL import Image
from typing import List, Dict, Any
import timm
# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms
# ----------------------------------------------------------------------
# --- Global Variables ---
# ----------------------------------------------------------------------
# Shared compute device for the whole ensemble, resolved once at import time
# by the project helper (presumably CUDA when available — confirm in
# src.utils.get_device).
DEVICE = get_device()
# Side length in pixels of the square input image fed to every model.
IMG_SIZE = 224
# ----------------------------------------------------------------------
# --- Model Ensemble Agent Core (with all fixes) ---
# ----------------------------------------------------------------------
class ModelEnsembleAgent:
    """Confidence-weighted ensemble over several image-classification models.

    Loads one checkpoint per model name from a directory (expected file name:
    ``best_<name>.pth``), runs each model on an image, and fuses the per-model
    softmax distributions using each model's own peak confidence as its
    voting weight.
    """

    def __init__(self, model_names: List[str], checkpoints_dir: str, num_classes: int, class_names: List[str]):
        """Build the ensemble and eagerly load all checkpoints.

        Args:
            model_names: Architecture identifiers understood by ``get_model``.
            checkpoints_dir: Directory holding ``best_<name>.pth`` files.
            num_classes: Number of output classes each model predicts.
            class_names: Human-readable label for each class index.

        Raises:
            ValueError: If ``class_names`` does not contain exactly
                ``num_classes`` entries (would otherwise surface later as an
                IndexError inside ``run_ensemble``).
            RuntimeError: If no checkpoint could be loaded at all.
        """
        # Fail fast on a label/logit mismatch instead of crashing at inference.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries but num_classes is {num_classes}"
            )
        self.models: Dict[str, nn.Module] = {}
        self.model_names = model_names
        self.num_classes = num_classes
        self.class_names = class_names
        self.transforms = get_transforms('val', IMG_SIZE)
        self.device = DEVICE
        self._load_all_models(checkpoints_dir)

    def _load_all_models(self, checkpoints_dir: str):
        """Load every checkpoint, tolerating missing files and head mismatches.

        Weights whose name or shape disagrees with the freshly built model are
        dropped (handles checkpoints trained with a different classifier head).
        Models that fail entirely are skipped with a diagnostic; at least one
        model must load or a RuntimeError is raised.
        """
        print(f"Loading {len(self.model_names)} models from {checkpoints_dir} on {self.device}...")
        for name in self.model_names:
            # Checkpoint naming convention: best_<modelname>.pth
            checkpoint_path = os.path.join(checkpoints_dir, f"best_{name}.pth")
            print(f" Attempting to load {name} from expected path: {checkpoint_path}...")
            try:
                model = get_model(name, self.num_classes, pretrained=False).to(self.device)
                # NOTE(review): torch.load unpickles arbitrary objects — only
                # load checkpoints from trusted sources.
                checkpoint = torch.load(checkpoint_path, map_location=self.device)
                # Checkpoints may store weights directly or under 'model_state_dict'.
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                # Keep only tensors whose name AND shape match the current model.
                model_state = model.state_dict()
                filtered_state_dict = {}
                for key, value in state_dict.items():
                    if key in model_state and model_state[key].shape == value.shape:
                        filtered_state_dict[key] = value
                    elif key in model_state:
                        # Same name, different shape (usually the head layers).
                        print(f" (Skipping layer '{key}' due to shape mismatch: {value.shape} vs {model_state[key].shape})")
                    # Keys absent from the current model are dropped silently.
                # strict=False: layers filtered out above stay at their init values.
                model.load_state_dict(filtered_state_dict, strict=False)
                model.eval()
                self.models[name] = model
                print(f" ✅ Successfully loaded {name}.")
            except FileNotFoundError:
                print(f" ❌ Checkpoint not found at: {checkpoint_path}. Skipping.")
            except Exception as e:
                # Show the full error so checkpoint/architecture mismatches are debuggable.
                print(f" ❌ Failed to load {name}. Error: {e.__class__.__name__}. Details: {e}. Skipping.")
        if not self.models:
            raise RuntimeError("No models were successfully loaded. Cannot run ensemble.")

    @torch.no_grad()
    def run_ensemble(self, image_path: str) -> Dict[str, Any]:
        """Classify one image with every loaded model and fuse the results.

        Args:
            image_path: Path to an image readable by PIL.

        Returns:
            On success: dict with the ensemble label/confidence, per-model
            predictions, and ``fracture_detected`` (any non-"Healthy" label).
            On image-loading failure: ``{"error": <message>}``.
        """
        try:
            image = Image.open(image_path).convert('RGB')
            input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        except Exception as e:
            return {"error": f"Failed to load or process image: {e}"}
        all_probs = []
        individual_predictions = {}
        for name, model in self.models.items():
            outputs = model(input_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
            all_probs.append(probs)
            pred_idx = int(np.argmax(probs))  # plain int for safe list indexing
            individual_predictions[name] = {
                "class": self.class_names[pred_idx],
                "confidence": float(probs[pred_idx])
            }
        # Weighted voting: each model's vote is scaled by its own peak
        # confidence, so decisive models count for more. Softmax maxima are
        # always > 0, so the normalization below cannot divide by zero.
        weights = np.array([np.max(p) for p in all_probs])
        weights = weights / np.sum(weights)
        weighted_avg_probs = np.average(all_probs, axis=0, weights=weights)
        ensemble_idx = int(np.argmax(weighted_avg_probs))
        ensemble_class = self.class_names[ensemble_idx]
        return {
            "image_path": image_path,
            "ensemble_prediction": ensemble_class,
            "ensemble_confidence": float(weighted_avg_probs[ensemble_idx]),
            "individual_predictions": individual_predictions,
            "fracture_detected": ensemble_class != "Healthy"
        }
# ----------------------------------------------------------------------
# --- Execution Block ---
# ----------------------------------------------------------------------
if __name__ == '__main__':
    # CLI entry point: build the ensemble from checkpoints and classify one image.
    parser = argparse.ArgumentParser(description='Multi-Model Ensemble (Cross-Validation) Agent.')
    parser.add_argument('--image-path', required=True, help='Path to the image for inference.')
    parser.add_argument('--checkpoints-dir', required=True,  # Made required since default path was confusing
                        help='Absolute path to the directory containing the model checkpoints (e.g., best_swin.pth).')
    parser.add_argument('--models', type=str, default='swin,mobilenetv2,efficientnetv2,maxvit,densenet169',
                        help='Comma-separated names of the models to load.')
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--class-names', required=True,
                        help='Comma-separated list of class names.')
    args = parser.parse_args()

    models_list = [m.strip() for m in args.models.split(',')]
    class_names_list = [c.strip() for c in args.class_names.split(',')]

    try:
        ensemble_agent = ModelEnsembleAgent(
            model_names=models_list,
            checkpoints_dir=args.checkpoints_dir,
            num_classes=args.num_classes,
            class_names=class_names_list
        )
    except RuntimeError as e:
        print(f"\nFATAL ERROR during initialization: {e}")
        # FIX: sys.exit instead of the site-injected exit(), which is not
        # guaranteed to exist in every interpreter mode.
        sys.exit(1)

    result = ensemble_agent.run_ensemble(args.image_path)

    print("\n--- ENSEMBLE AGENT RESULT ---")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Image: {os.path.basename(result['image_path'])}")
        print(f"FINAL ENSEMBLE PREDICTION: **{result['ensemble_prediction']}** (Confidence: {result['ensemble_confidence']:.4f})")
        print("\nIndividual Model Predictions:")
        # Snapshot of which requested models actually loaded, so skipped ones
        # can still be reported below.
        loaded_model_names = set(ensemble_agent.models.keys())
        for name in models_list:
            if name in loaded_model_names:
                pred = result['individual_predictions'][name]
                print(f" {name.upper():<15}: {pred['class']:<20} (Conf: {pred['confidence']:.4f})")
            else:
                print(f" {name.upper():<15}: (Skipped/Failed to Load)")
    print("-----------------------------\n")