import argparse import logging import os from collections import OrderedDict from pathlib import Path from typing import Dict, List, Optional import pandas as pd import torch import torchmetrics from configs.args_base import get_args from data import build_dataloader from models.MIQA_base import get_timm_model, get_torch_model from models.RA_MIQA import RegionVisionTransformer from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES from train import AverageMeter from utils.hf_download_utils import ensure_checkpoint_from_hf def get_checkpoint_path(model_name: str, train_dataset: str, metric_type: str = "composite") -> str: base_dir = Path("models") / "checkpoints" / f"{metric_type}_metric" base_dir.mkdir(parents=True, exist_ok=True) filename = MODEL_FILENAMES[metric_type][train_dataset][model_name] return str(base_dir / filename) def get_available_models(train_dataset: str, metric_type: str) -> List[str]: """ Get list of available models for a specific training dataset and metric type. This helper function is useful for validation and for providing helpful error messages when a user requests a model that isn't available for their chosen configuration. Args: train_dataset: Training dataset type (cls, det, ins) metric_type: Training metric objective (composite, consistency, accuracy) Returns: List of available model names for this configuration """ if metric_type in MODEL_FILENAMES: if train_dataset in MODEL_FILENAMES[metric_type]: return list(MODEL_FILENAMES[metric_type][train_dataset].keys()) return [] def ensure_model_weights(model_name: str, train_dataset: str, metric_type: str, logger: logging.Logger) -> Optional[str]: """ Ensure model weights exist, download if necessary. This function implements a caching strategy: it first checks if the checkpoint already exists locally. If not, it downloads it from Hugging Face Hub. This means the first run will download weights, but subsequent runs will be much faster. Args: model_name: Name of the model architecture train_dataset: Training dataset type (cls, det, or ins) metric_type: Training metric objective (composite, consistency, or accuracy) logger: Logger instance for status messages Returns: Path to checkpoint if successful, None if weights cannot be obtained """ # Generate the expected checkpoint path checkpoint_path = get_checkpoint_path(model_name, train_dataset, metric_type) # First, check if we already have this checkpoint cached locally if os.path.exists(checkpoint_path): logger.info(f"✓ Found existing checkpoint: {checkpoint_path}") return checkpoint_path # Checkpoint not found locally, so we need to download it logger.info(f"Checkpoint not found at {checkpoint_path}") # Verify this model configuration is supported if metric_type not in MODEL_FILENAMES: logger.error(f"✗ Metric type '{metric_type}' not recognized") logger.error(f" Available metric types: {list(MODEL_FILENAMES.keys())}") return None if train_dataset not in MODEL_FILENAMES[metric_type]: logger.error(f"✗ Train dataset '{train_dataset}' not available for metric type '{metric_type}'") return None if model_name not in MODEL_FILENAMES[metric_type][train_dataset]: available_models = get_available_models(train_dataset, metric_type) logger.error(f"✗ Model '{model_name}' not available for {train_dataset}/{metric_type}") logger.error(f" Available models: {available_models}") return None filename = MODEL_FILENAMES[metric_type][train_dataset][model_name] logger.info( f"Attempting to download checkpoint from Hugging Face: " f"repo={HF_REPO_ID}, file={filename}, rev={HF_REVISION}" ) try: local_path = ensure_checkpoint_from_hf( repo_id=HF_REPO_ID, filename=filename, local_dir=str(Path("models") / "checkpoints" / f"{metric_type}_metric"), revision=HF_REVISION, ) logger.info("✓ Successfully downloaded checkpoint from Hugging Face") return local_path except Exception as e: logger.error(f"✗ Failed to download checkpoint from Hugging Face: {e}") return None def load_model_weights(model: torch.nn.Module, weights_path: str, args: argparse.Namespace, logger: logging.Logger) -> bool: """ Load model weights from checkpoint file. This function handles the actual loading of weights into the model, with proper error handling and support for different checkpoint formats (direct state dict or wrapped in a dictionary with metadata). Args: model: The model to load weights into weights_path: Path to the checkpoint file args: Command line arguments logger: Logger instance Returns: True if weights loaded successfully, False otherwise """ if not os.path.isfile(weights_path): logger.error(f"✗ Checkpoint file not found: '{weights_path}'") return False logger.info(f"Loading checkpoint from '{weights_path}'") try: # Load checkpoint to CPU first to avoid GPU memory issues checkpoint = torch.load(weights_path, map_location="cpu") # Extract state dict - handle different checkpoint formats # Some checkpoints store weights directly, others wrap them in a 'state_dict' key state_dict = checkpoint.get('state_dict', checkpoint) # Remove 'module.' prefix if present # This prefix is added when models are trained with DataParallel/DistributedDataParallel new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k.replace('module.', '') if k.startswith('module.') else k new_state_dict[name] = v # Load the processed weights into the model model.load_state_dict(new_state_dict) logger.info(f"✓ Successfully loaded checkpoint") # Log additional useful information from the checkpoint if available if 'epoch' in checkpoint: logger.info(f" Checkpoint epoch: {checkpoint['epoch']}") if 'best_srcc' in checkpoint: logger.info(f" Best SRCC: {checkpoint['best_srcc']:.4f}") if 'metric_type' in checkpoint: logger.info(f" Metric type: {checkpoint['metric_type']}") return True except Exception as e: logger.error(f"✗ Error loading checkpoint: {str(e)}") return False def create_model(model_name: str, args: argparse.Namespace, logger: logging.Logger) -> torch.nn.Module: """ Create model instance based on model name. This function handles the instantiation of different model architectures. It includes special handling for the RegionVisionTransformer (RA_MIQA) which has a different initialization process than standard vision models. Args: model_name: Name of the model architecture args: Command line arguments logger: Logger instance Returns: Initialized model (without loaded weights yet) """ # Special handling for our custom RegionVisionTransformer architecture if model_name == 'ra_miqa': logger.info(f"Creating RA_MIQA Model") model = RegionVisionTransformer( base_model_name='vit_small_patch16_224', pretrained=True, mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py', checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth' ) else: # For standard architectures, try PyTorch hub first, then fall back to timm try: logger.info(f"Creating model from PyTorch: {model_name}") model = get_torch_model(model_name=model_name, pretrained=False, num_classes=1) except Exception as e: logger.info(f"PyTorch model not found, trying timm library: {model_name}") try: model = get_timm_model(model_name=model_name, pretrained=False, num_classes=1) except Exception as e: logger.error(f"✗ Failed to create model: {str(e)}") raise return model @torch.no_grad() def inference(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, args: argparse.Namespace, criterion: torch.nn.Module, logger: logging.Logger) -> Dict: """ Run inference on validation set and compute metrics. This function performs the actual evaluation of the model on the test dataset. It runs in evaluation mode with no gradient computation, processes all batches, and computes standard image quality assessment metrics (SRCC, PLCC, KLCC). Args: val_loader: DataLoader for validation data model: Model to evaluate args: Command line arguments criterion: Loss function (MSE) logger: Logger instance Returns: Dictionary containing predictions, ground truth, and computed metrics """ # Set model to evaluation mode - this disables dropout and uses running stats for batchnorm model.eval() val_dataset_len = len(val_loader.dataset) val_loader_len = len(val_loader) # Initialize tracking variables for performance monitoring batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') # Storage lists for accumulating results across all batches temp_pred_scores = [] temp_gt_scores = [] temp_img_names = [] logger.info(f"Starting inference on {val_dataset_len} images...") for i, batch in enumerate(val_loader): # Move data to GPU if available image_cropped = batch['image_cropped'].cuda(args.gpu, non_blocking=True) image_resized = batch['image_resized'].cuda(args.gpu, non_blocking=True) target = batch['label'].cuda(args.gpu, non_blocking=True).view(-1) # Forward pass - compute predictions output = model(image_cropped, image_resized) loss = criterion(output.view(-1), target.view(-1)) losses.update(loss.item(), target.size(0)) # Accumulate results for later metric computation temp_pred_scores.append(output.view(-1)) temp_gt_scores.append(target.view(-1)) temp_img_names.extend(batch['image_name']) # Log progress periodically if i % args.print_freq == 0: logger.info( f" [{i}/{val_loader_len}] " f"Loss: {losses.val:.4f} (avg: {losses.avg:.4f})" ) # Concatenate all batch results into single tensors final_preds = torch.cat(temp_pred_scores) final_grotruth = torch.cat(temp_gt_scores) # Handle patch-based predictions if the model uses multiple patches per image if hasattr(args, 'patch_num') and args.patch_num > 1: logger.info(f"Averaging predictions over {args.patch_num} patches per image") preds_matrix = final_preds.view(-1, args.patch_num) final_preds = preds_matrix.mean(dim=-1).squeeze() final_grotruth = final_grotruth.view(-1, args.patch_num).mean(dim=-1).squeeze() logger.info( f"Dataset size: {val_dataset_len}, " f"Predictions shape: {final_preds.shape}, " f"Ground truth shape: {final_grotruth.shape}" ) # Sanity check for invalid values that would corrupt metric computation if torch.isnan(final_preds).any() or torch.isinf(final_preds).any(): raise ValueError("Found NaN or inf values in predictions") if torch.isnan(final_grotruth).any() or torch.isinf(final_grotruth).any(): raise ValueError("Found NaN or inf values in ground truth") # Compute standard image quality assessment metrics # SRCC: Spearman's rank correlation coefficient - measures monotonic relationship test_srcc = torchmetrics.functional.spearman_corrcoef(final_preds, final_grotruth).item() # PLCC: Pearson's linear correlation coefficient - measures linear relationship test_plcc = torchmetrics.functional.pearson_corrcoef(final_preds, final_grotruth).item() # KLCC: Kendall's rank correlation coefficient - another rank-based metric test_klcc = torchmetrics.functional.kendall_rank_corrcoef(final_preds, final_grotruth).item() # Package all results into a dictionary for return results = { 'image_names': temp_img_names, 'predictions': final_preds.cpu().numpy().tolist(), 'ground_truth': final_grotruth.cpu().numpy().tolist(), 'metrics': { 'srcc': test_srcc, 'plcc': test_plcc, 'klcc': test_klcc, 'loss': losses.avg } } return results def save_results(results: Dict, model_name: str, train_dataset: str, test_dataset: str, metric_type: str, output_dir: str, logger: logging.Logger) -> None: """ Save inference results to CSV file with detailed metrics. This function saves both detailed per-image results and prints a summary of the overall performance metrics. The filename includes all relevant configuration details for easy identification. Args: results: Results dictionary from inference model_name: Name of the model train_dataset: Training dataset type test_dataset: Test dataset name metric_type: Training metric objective output_dir: Base directory to save results logger: Logger instance """ # Create the evaluations subdirectory eval_dir = os.path.join(output_dir, 'evaluations') os.makedirs(eval_dir, exist_ok=True) # Prepare detailed per-image results with predictions and errors csv_data = [] for img_name, pred, gt in zip(results['image_names'], results['predictions'], results['ground_truth']): csv_data.append({ 'image_name': img_name, 'prediction': pred, 'ground_truth': gt, 'absolute_error': abs(pred - gt) }) # Create descriptive filename that includes all configuration details csv_filename = f"{model_name}_{train_dataset}_{metric_type}_on_{test_dataset}.csv" csv_path = os.path.join(eval_dir, csv_filename) # Save detailed results to CSV df = pd.DataFrame(csv_data) df.to_csv(csv_path, index=False) logger.info(f"Detailed results saved to: {csv_path}") # Print formatted metrics summary to console and log logger.info("\n" + "=" * 70) logger.info("EVALUATION METRICS") logger.info("=" * 70) logger.info(f"Model: {model_name}") logger.info(f"Trained on: {train_dataset}") logger.info(f"Metric type: {metric_type}") logger.info(f"Tested on: {test_dataset}") logger.info("-" * 70) logger.info(f"SRCC (Spearman): {results['metrics']['srcc']:.4f}") logger.info(f"PLCC (Pearson): {results['metrics']['plcc']:.4f}") logger.info(f"KLCC (Kendall): {results['metrics']['klcc']:.4f}") logger.info(f"MSE Loss: {results['metrics']['loss']:.4f}") logger.info("=" * 70 + "\n") def main(args: argparse.Namespace, logger: logging.Logger) -> None: """ Main inference pipeline orchestrating all steps. This function coordinates the entire evaluation process: validating inputs, ensuring model weights are available, loading data, creating and loading the model, running inference, and saving results. Args: args: Command line arguments logger: Logger instance """ # Validate required arguments if not args.model_name: raise ValueError("Please specify --model_name") if not args.train_dataset: raise ValueError("Please specify --train_dataset (cls, det, or ins)") if not args.test_dataset: raise ValueError("Please specify --test_dataset") if not args.metric_type: raise ValueError("Please specify --metric_type (composite, consistency, or accuracy)") logger.info(f"\nStarting MIQA Inference Pipeline") logger.info(f"Model: {args.model_name}") logger.info(f"Trained on: {args.train_dataset}") logger.info(f"Metric type: {args.metric_type}") logger.info(f"Testing on: {args.test_dataset}") # Ensure model weights are available (download if necessary) checkpoint_path = ensure_model_weights(args.model_name, args.train_dataset, args.metric_type, logger) if checkpoint_path is None: logger.error("Cannot proceed without model weights") return # Build dataset and dataloader logger.info(f"\nLoading {args.test_dataset} dataset...") args.dataset = args.test_dataset # Set dataset name for dataloader builder args.eval_only = True # Indicate evaluation mode val_dataset = build_dataloader.build_dataset(args) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True ) logger.info(f"✓ Loaded {len(val_dataset)} images with {args.workers} workers") # Create model architecture logger.info(f"\nCreating model architecture...") args.arch = args.model_name model = create_model(args.model_name, args, logger) # Load pre-trained weights into model if not load_model_weights(model, checkpoint_path, args, logger): logger.error("Failed to load model weights") return # Move model to GPU if available if args.gpu is not None and torch.cuda.is_available(): model = model.cuda(args.gpu) logger.info(f"✓ Model moved to GPU {args.gpu}") else: logger.warning("GPU not available, using CPU (this will be slower)") # Create loss function for evaluation criterion = torch.nn.MSELoss() # Run inference on the test set logger.info(f"\nRunning inference...") results = inference(val_loader, model, args, criterion, logger) # Save results and print summary save_results(results, args.model_name, args.train_dataset, args.test_dataset, args.metric_type, args.output_dir, logger) if __name__ == '__main__': # Parse command line arguments parser = get_args() parser.add_argument('--model_name', type=str, required=True, choices=['ra_miqa'], help='Model architecture (Hub registry currently ships RA-MIQA only)') parser.add_argument('--train_dataset', type=str, required=True, choices=['cls', 'det', 'ins'], help='Dataset type the model was trained on (cls=classification, det=detection, ins=instance)') parser.add_argument('--test_dataset', type=str, required=True, help='Name of the dataset to test on') parser.add_argument('--metric_type', type=str, required=True, choices=['composite', 'consistency', 'accuracy'], help='Training metric objective used (composite=both metrics, consistency=consistency-focused, accuracy=accuracy-focused)') parser.add_argument('--output_dir', type=str, default='outputs', help='Directory to save results (default: outputs)') args = parser.parse_args() # Create output directory structure os.makedirs(args.output_dir, exist_ok=True) # Configure logging to both file and console log_filename = f"inference_{args.model_name}_{args.train_dataset}_{args.metric_type}.log" logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(os.path.join(args.output_dir, log_filename)), logging.StreamHandler() ] ) logger = logging.getLogger('miqa_inference') # Run main inference pipeline with error handling try: main(args, logger) except Exception as e: logger.error(f"Error during inference: {str(e)}", exc_info=True) raise