# Captured from the Hugging Face Spaces file viewer (Space status at capture: Runtime error).
# File size: 7,933 bytes; revision bf07f10.
import os
import sys
import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
import numpy as np
from PIL import Image
from typing import List, Dict, Any
import timm
# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from src.utils import get_device, get_model, get_transforms
# ----------------------------------------------------------------------
# --- Global Variables ---
# ----------------------------------------------------------------------
# Image size handed to get_transforms() when building the preprocessing pipeline.
IMG_SIZE = 224
# Device chosen once at import time by the project's get_device() helper;
# models and input tensors are moved onto it.
DEVICE = get_device()
# ----------------------------------------------------------------------
# --- Model Ensemble Agent Core (with all fixes) ---
# ----------------------------------------------------------------------
class ModelEnsembleAgent:
    """Confidence-weighted ensemble of image classifiers.

    Every name in ``model_names`` is built with ``get_model`` and initialised
    from ``<checkpoints_dir>/best_<name>.pth``. Models whose checkpoint is
    missing or unusable are skipped; construction only fails when none of
    them could be loaded.

    Args:
        model_names: Architecture names passed to ``get_model``; also used
            to derive the checkpoint file names.
        checkpoints_dir: Directory containing the ``best_<name>.pth`` files.
        num_classes: Number of output classes each model is built with.
        class_names: Label for each class index, in index order.

    Raises:
        RuntimeError: If no model could be loaded.
    """

    def __init__(self, model_names: List[str], checkpoints_dir: str, num_classes: int, class_names: List[str]):
        self.models = {}
        self.model_names = model_names
        self.num_classes = num_classes
        self.class_names = class_names
        self.transforms = get_transforms('val', IMG_SIZE)
        self.device = DEVICE
        self._load_all_models(checkpoints_dir)

    def _load_all_models(self, checkpoints_dir: str):
        """Load every requested checkpoint, skipping the ones that fail.

        Successfully loaded models are stored in ``self.models`` keyed by
        name. Failures are reported on stdout and do not abort the loop.

        Raises:
            RuntimeError: If not a single model was loaded.
        """
        print(f"Loading {len(self.model_names)} models from {checkpoints_dir} on {self.device}...")
        for name in self.model_names:
            # Checkpoints follow the naming convention best_<modelname>.pth.
            checkpoint_path = os.path.join(checkpoints_dir, f"best_{name}.pth")
            print(f" Attempting to load {name} from expected path: {checkpoint_path}...")
            try:
                self.models[name] = self._load_single_model(name, checkpoint_path)
                print(f" ✅ Successfully loaded {name}.")
            except FileNotFoundError:
                print(f" ❌ Checkpoint not found at: {checkpoint_path}. Skipping.")
            except Exception as e:
                # Best-effort loading: report the full error and move on to the next model.
                print(f" ❌ Failed to load {name}. Error: {e.__class__.__name__}. Details: {e}. Skipping.")
        if not self.models:
            raise RuntimeError("No models were successfully loaded. Cannot run ensemble.")

    def _load_single_model(self, name: str, checkpoint_path: str):
        """Build model ``name`` and initialise it from ``checkpoint_path``.

        Only tensors whose key and shape match the freshly built model are
        loaded; everything else is skipped.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            RuntimeError: If the checkpoint has no tensor compatible with
                the model (it would otherwise run with untrained weights).
        """
        # Read the checkpoint before building the model so that a missing
        # file is detected without paying for model construction.
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints, or pass
        # weights_only=True once all checkpoints are known to be plain
        # state dicts.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        state_dict = checkpoint.get('model_state_dict', checkpoint)

        model = get_model(name, self.num_classes, pretrained=False).to(self.device)
        model_state = model.state_dict()

        compatible = {}
        for key, value in state_dict.items():
            if key not in model_state:
                # Key does not exist in the current model; nothing to load.
                continue
            if model_state[key].shape == value.shape:
                compatible[key] = value
            else:
                # Shape mismatch (usually a head trained with a different layout).
                print(f" (Skipping layer '{key}' due to shape mismatch: {value.shape} vs {model_state[key].shape})")

        if not compatible:
            # Without this check a completely unrelated checkpoint would be
            # reported as a success while the model keeps its initial weights.
            raise RuntimeError(
                f"checkpoint '{checkpoint_path}' has no tensors compatible with model '{name}'"
            )

        load_result = model.load_state_dict(compatible, strict=False)
        if load_result.missing_keys:
            print(
                f" (Warning: {len(load_result.missing_keys)} tensor(s) of '{name}' were not "
                f"restored from the checkpoint and keep their initial values)"
            )
        model.eval()
        return model

    @torch.no_grad()
    def run_ensemble(self, image_path: str) -> Dict[str, Any]:
        """Classify one image with every loaded model and combine the results.

        Each model's softmax output is weighted by that model's own top
        probability, and the weighted average decides the ensemble class.

        Args:
            image_path: Path of the image to classify.

        Returns:
            On success, a dict with ``image_path``, ``ensemble_prediction``,
            ``ensemble_confidence``, ``individual_predictions`` and
            ``fracture_detected`` (true when the ensemble class is not
            ``"Healthy"``). On failure, a dict with a single ``error`` key.
        """
        try:
            image = Image.open(image_path).convert('RGB')
            input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        except Exception as e:
            return {"error": f"Failed to load or process image: {e}"}

        label_count = len(self.class_names)
        all_probs = []
        individual_predictions = {}
        for name, model in self.models.items():
            outputs = model(input_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
            pred_idx = int(np.argmax(probs))
            if pred_idx >= label_count:
                # Report the mismatch through the error contract instead of
                # letting an IndexError escape.
                return {
                    "error": f"Model '{name}' predicted class index {pred_idx}, "
                             f"but only {label_count} class names were provided."
                }
            all_probs.append(probs)
            individual_predictions[name] = {
                "class": self.class_names[pred_idx],
                "confidence": float(probs[pred_idx])
            }

        # Ensemble decision (weighted voting): each model's top probability
        # is its weight. np.average normalises the weights itself.
        weights = np.array([np.max(probs) for probs in all_probs])
        weighted_avg_probs = np.average(all_probs, axis=0, weights=weights)

        ensemble_idx = int(np.argmax(weighted_avg_probs))
        if ensemble_idx >= label_count:
            return {
                "error": f"Ensemble predicted class index {ensemble_idx}, "
                         f"but only {label_count} class names were provided."
            }
        ensemble_class = self.class_names[ensemble_idx]
        return {
            "image_path": image_path,
            "ensemble_prediction": ensemble_class,
            "ensemble_confidence": float(weighted_avg_probs[ensemble_idx]),
            "individual_predictions": individual_predictions,
            "fracture_detected": ensemble_class != "Healthy"
        }
# ----------------------------------------------------------------------
# --- Execution Block ---
# ----------------------------------------------------------------------
# Command-line entry point: load the ensemble, classify one image and print
# the combined and per-model predictions. Exits with status 1 when the
# ensemble cannot be initialised or the image cannot be classified.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Multi-Model Ensemble (Cross-Validation) Agent.')
    parser.add_argument('--image-path', required=True, help='Path to the image for inference.')
    parser.add_argument('--checkpoints-dir', required=True,  # Required: a default path was confusing
                        help='Absolute path to the directory containing the model checkpoints (e.g., best_swin.pth).')
    parser.add_argument('--models', type=str, default='swin,mobilenetv2,efficientnetv2,maxvit,densenet169',
                        help='Comma-separated names of the models to load.')
    parser.add_argument('--num-classes', type=int, default=8)
    parser.add_argument('--class-names', required=True,
                        help='Comma-separated list of class names.')
    args = parser.parse_args()

    # Drop blank model names so a trailing or doubled comma does not trigger
    # a load attempt for "best_.pth".
    models_list = [m.strip() for m in args.models.split(',') if m.strip()]
    # Class names are positional (index -> label), so no entry is dropped here.
    class_names_list = [c.strip() for c in args.class_names.split(',')]

    try:
        ensemble_agent = ModelEnsembleAgent(
            model_names=models_list,
            checkpoints_dir=args.checkpoints_dir,
            num_classes=args.num_classes,
            class_names=class_names_list
        )
    except RuntimeError as e:
        print(f"\nFATAL ERROR during initialization: {e}")
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to be defined.
        sys.exit(1)

    result = ensemble_agent.run_ensemble(args.image_path)
    print("\n--- ENSEMBLE AGENT RESULT ---")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Image: {os.path.basename(result['image_path'])}")
        print(f"FINAL ENSEMBLE PREDICTION: **{result['ensemble_prediction']}** (Confidence: {result['ensemble_confidence']:.4f})")
        print("\nIndividual Model Predictions:")
        for name in models_list:
            if name in ensemble_agent.models:
                pred = result['individual_predictions'][name]
                print(f" {name.upper():<15}: {pred['class']:<20} (Conf: {pred['confidence']:.4f})")
            else:
                print(f" {name.upper():<15}: (Skipped/Failed to Load)")
    print("-----------------------------\n")

    if "error" in result:
        # Signal the failure to the calling shell as well as on stdout.
        sys.exit(1)