miqa
File size: 20,273 Bytes
e29b006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
import argparse
import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
import torch
import torchmetrics

from configs.args_base import get_args
from data import build_dataloader
from models.MIQA_base import get_timm_model, get_torch_model
from models.RA_MIQA import RegionVisionTransformer
from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
from train import AverageMeter
from utils.hf_download_utils import ensure_checkpoint_from_hf


def get_checkpoint_path(model_name: str, train_dataset: str, metric_type: str = "composite") -> str:
    """
    Return the local cache path for a registered checkpoint.

    The cache directory ``models/checkpoints/<metric_type>_metric`` is created as a
    side effect (before the registry lookup), so it exists even when the lookup fails.

    Args:
        model_name: Name of the model architecture
        train_dataset: Training dataset type used as a registry key
        metric_type: Training metric objective used as a registry key

    Returns:
        The checkpoint path as a string (the file itself may not exist yet)

    Raises:
        KeyError: If the combination is not present in ``MODEL_FILENAMES``
            (assuming the registry is a nested dict - it is indexed as one here).
    """
    cache_dir = Path("models", "checkpoints", f"{metric_type}_metric")
    cache_dir.mkdir(parents=True, exist_ok=True)
    registered_for_dataset = MODEL_FILENAMES[metric_type][train_dataset]
    return str(cache_dir.joinpath(registered_for_dataset[model_name]))


def get_available_models(train_dataset: str, metric_type: str) -> List[str]:
    """
    List the model names registered for a training dataset / metric type pair.

    Intended for validation and for building helpful error messages when a
    requested model is not available for the chosen configuration.

    Args:
        train_dataset: Training dataset type (cls, det, ins)
        metric_type: Training metric objective (composite, consistency, accuracy)

    Returns:
        Model names found in ``MODEL_FILENAMES`` for this configuration, or an
        empty list when either key is unknown
    """
    # Guard clauses: an unknown key at either level means "nothing available".
    if metric_type not in MODEL_FILENAMES:
        return []
    per_dataset = MODEL_FILENAMES[metric_type]
    if train_dataset not in per_dataset:
        return []
    return list(per_dataset[train_dataset])


def ensure_model_weights(model_name: str, train_dataset: str, metric_type: str,
                         logger: logging.Logger) -> Optional[str]:
    """
    Ensure model weights exist locally, downloading them if necessary.

    The requested configuration is first validated against ``MODEL_FILENAMES``.
    If the checkpoint file is already present at the expected local path it is
    reused; otherwise it is fetched from the Hugging Face Hub via
    ``ensure_checkpoint_from_hf``.

    Args:
        model_name: Name of the model architecture
        train_dataset: Training dataset type (cls, det, or ins)
        metric_type: Training metric objective (composite, consistency, or accuracy)
        logger: Logger instance for status messages

    Returns:
        Path to the checkpoint if successful, None if the configuration is not
        registered or the download fails
    """
    # Validate BEFORE resolving the path: get_checkpoint_path() indexes
    # MODEL_FILENAMES directly, so an unsupported combination would raise
    # KeyError there and the explanatory messages below would never be shown.
    if metric_type not in MODEL_FILENAMES:
        logger.error(f"✗ Metric type '{metric_type}' not recognized")
        logger.error(f"   Available metric types: {list(MODEL_FILENAMES.keys())}")
        return None

    if train_dataset not in MODEL_FILENAMES[metric_type]:
        logger.error(f"✗ Train dataset '{train_dataset}' not available for metric type '{metric_type}'")
        return None

    if model_name not in MODEL_FILENAMES[metric_type][train_dataset]:
        available_models = get_available_models(train_dataset, metric_type)
        logger.error(f"✗ Model '{model_name}' not available for {train_dataset}/{metric_type}")
        logger.error(f"   Available models: {available_models}")
        return None

    # The configuration is known, so the registry lookup inside the helper is safe.
    checkpoint_path = get_checkpoint_path(model_name, train_dataset, metric_type)

    # Reuse a previously downloaded checkpoint when one is cached locally.
    if os.path.exists(checkpoint_path):
        logger.info(f"✓ Found existing checkpoint: {checkpoint_path}")
        return checkpoint_path

    logger.info(f"Checkpoint not found at {checkpoint_path}")

    filename = MODEL_FILENAMES[metric_type][train_dataset][model_name]
    logger.info(
        f"Attempting to download checkpoint from Hugging Face: "
        f"repo={HF_REPO_ID}, file={filename}, rev={HF_REVISION}"
    )
    try:
        local_path = ensure_checkpoint_from_hf(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=str(Path("models") / "checkpoints" / f"{metric_type}_metric"),
            revision=HF_REVISION,
        )
    except Exception as e:
        # Deliberate best-effort: any download failure is reported and turned
        # into a None result so the caller can stop cleanly.
        logger.error(f"✗ Failed to download checkpoint from Hugging Face: {e}")
        return None

    logger.info("✓ Successfully downloaded checkpoint from Hugging Face")
    return local_path

def load_model_weights(model: torch.nn.Module, weights_path: str, args: argparse.Namespace,
                       logger: logging.Logger) -> bool:
    """
    Load model weights from a checkpoint file into ``model``.

    Supports checkpoints that are either a bare state dict or a dictionary that
    wraps the weights under a ``'state_dict'`` key. A leading ``'module.'`` on a
    parameter name is stripped before loading.

    Args:
        model: The model to load weights into
        weights_path: Path to the checkpoint file
        args: Command line arguments (not used by this function)
        logger: Logger instance

    Returns:
        True if weights loaded successfully, False if the file is missing or any
        error occurs while reading or applying the checkpoint
    """
    if not os.path.isfile(weights_path):
        logger.error(f"✗ Checkpoint file not found: '{weights_path}'")
        return False

    logger.info(f"Loading checkpoint from '{weights_path}'")

    try:
        # Load checkpoint to CPU first to avoid GPU memory issues.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(weights_path, map_location="cpu")

        # Some checkpoints store weights directly, others wrap them in 'state_dict'.
        state_dict = checkpoint.get('state_dict', checkpoint)

        # Strip ONLY a leading 'module.' (the prefix this script expects from
        # DataParallel-style wrappers). str.replace would also delete inner
        # occurrences, e.g. turning 'module.sub_module.weight' into 'sub_weight'.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

        # Load the processed weights into the model
        model.load_state_dict(new_state_dict)
        logger.info("✓ Successfully loaded checkpoint")

        # Log additional useful information from the checkpoint if available
        if 'epoch' in checkpoint:
            logger.info(f"  Checkpoint epoch: {checkpoint['epoch']}")
        if 'best_srcc' in checkpoint:
            logger.info(f"  Best SRCC: {checkpoint['best_srcc']:.4f}")
        if 'metric_type' in checkpoint:
            logger.info(f"  Metric type: {checkpoint['metric_type']}")

        return True

    except Exception as e:
        logger.error(f"✗ Error loading checkpoint: {str(e)}")
        return False


def create_model(model_name: str, args: argparse.Namespace, logger: logging.Logger) -> torch.nn.Module:
    """
    Instantiate the requested architecture (weights are loaded separately).

    ``'ra_miqa'`` builds a ``RegionVisionTransformer`` with fixed config and
    checkpoint paths. Any other name is first passed to ``get_torch_model`` and,
    if that raises, to ``get_timm_model``.

    Args:
        model_name: Name of the model architecture
        args: Command line arguments (not used by this function)
        logger: Logger instance

    Returns:
        The constructed model

    Raises:
        Exception: Whatever ``get_timm_model`` raised when both factories fail
    """
    # Custom architecture: handled up front so the generic path stays flat.
    if model_name == 'ra_miqa':
        logger.info("Creating RA_MIQA Model")
        return RegionVisionTransformer(
            base_model_name='vit_small_patch16_224',
            pretrained=True,
            mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
            checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
        )

    # Standard architectures: both factories receive the same keyword arguments.
    factory_kwargs = dict(model_name=model_name, pretrained=False, num_classes=1)
    try:
        logger.info(f"Creating model from PyTorch: {model_name}")
        return get_torch_model(**factory_kwargs)
    except Exception:
        logger.info(f"PyTorch model not found, trying timm library: {model_name}")
        # The fallback stays inside the handler so a timm failure is raised
        # while the first failure is still being handled.
        try:
            return get_timm_model(**factory_kwargs)
        except Exception as timm_error:
            logger.error(f"βœ— Failed to create model: {str(timm_error)}")
            raise


@torch.no_grad()
def inference(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module,
              args: argparse.Namespace, criterion: torch.nn.Module,
              logger: logging.Logger) -> Dict:
    """
    Run inference on the validation set and compute metrics.

    The model is put in eval mode and called as ``model(image_cropped, image_resized)``
    for every batch. Each batch is expected to be a mapping with the keys
    ``'image_cropped'``, ``'image_resized'``, ``'label'`` and ``'image_name'``.
    Predictions and labels are flattened, accumulated, optionally averaged over
    ``args.patch_num`` patches, and scored with SRCC, PLCC and KLCC.

    Args:
        val_loader: DataLoader for validation data
        model: Model to evaluate; its current device decides where batches are sent
        args: Command line arguments (``print_freq`` and optional ``patch_num`` are read)
        criterion: Loss function applied to flattened outputs and targets
        logger: Logger instance

    Returns:
        Dictionary with ``'image_names'``, ``'predictions'``, ``'ground_truth'``
        and a ``'metrics'`` dict (``srcc``, ``plcc``, ``klcc``, ``loss``)

    Raises:
        ValueError: If predictions or ground truth contain NaN or inf values
    """
    model.eval()
    val_dataset_len = len(val_loader.dataset)
    val_loader_len = len(val_loader)

    # Send batches to wherever the model already lives. The caller may have
    # left the model on the CPU (no GPU available / none requested), in which
    # case an unconditional .cuda() call would crash.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    losses = AverageMeter('Loss', ':.4e')

    # Storage lists for accumulating results across all batches
    temp_pred_scores = []
    temp_gt_scores = []
    temp_img_names = []

    logger.info(f"Starting inference on {val_dataset_len} images...")

    for i, batch in enumerate(val_loader):
        image_cropped = batch['image_cropped'].to(device, non_blocking=True)
        image_resized = batch['image_resized'].to(device, non_blocking=True)
        target = batch['label'].to(device, non_blocking=True).view(-1)

        # Forward pass - compute predictions
        output = model(image_cropped, image_resized)
        loss = criterion(output.view(-1), target.view(-1))
        losses.update(loss.item(), target.size(0))

        # Accumulate results for later metric computation
        temp_pred_scores.append(output.view(-1))
        temp_gt_scores.append(target.view(-1))
        temp_img_names.extend(batch['image_name'])

        # Log progress periodically
        if i % args.print_freq == 0:
            logger.info(
                f"  [{i}/{val_loader_len}] "
                f"Loss: {losses.val:.4f} (avg: {losses.avg:.4f})"
            )

    # Concatenate all batch results into single tensors
    final_preds = torch.cat(temp_pred_scores)
    final_targets = torch.cat(temp_gt_scores)

    # Handle patch-based predictions if the model uses multiple patches per image
    if hasattr(args, 'patch_num') and args.patch_num > 1:
        logger.info(f"Averaging predictions over {args.patch_num} patches per image")
        num_patch_scores = final_preds.numel()
        # mean(dim=-1) already yields a 1-D tensor; no squeeze, so a single
        # image stays a length-1 vector and .tolist() below stays a list.
        final_preds = final_preds.view(-1, args.patch_num).mean(dim=-1)
        final_targets = final_targets.view(-1, args.patch_num).mean(dim=-1)
        # Keep names aligned with the averaged scores. Only done when there is
        # one name per patch score; relies on the same "patches of an image are
        # contiguous" layout that the view() above assumes.
        if len(temp_img_names) == num_patch_scores:
            temp_img_names = temp_img_names[::args.patch_num]

    logger.info(
        f"Dataset size: {val_dataset_len}, "
        f"Predictions shape: {final_preds.shape}, "
        f"Ground truth shape: {final_targets.shape}"
    )

    # Sanity check for invalid values that would corrupt metric computation
    if torch.isnan(final_preds).any() or torch.isinf(final_preds).any():
        raise ValueError("Found NaN or inf values in predictions")
    if torch.isnan(final_targets).any() or torch.isinf(final_targets).any():
        raise ValueError("Found NaN or inf values in ground truth")

    # SRCC: Spearman's rank correlation coefficient
    test_srcc = torchmetrics.functional.spearman_corrcoef(final_preds, final_targets).item()
    # PLCC: Pearson's linear correlation coefficient
    test_plcc = torchmetrics.functional.pearson_corrcoef(final_preds, final_targets).item()
    # KLCC: Kendall's rank correlation coefficient
    test_klcc = torchmetrics.functional.kendall_rank_corrcoef(final_preds, final_targets).item()

    # Package all results into a dictionary for return
    results = {
        'image_names': temp_img_names,
        'predictions': final_preds.cpu().numpy().tolist(),
        'ground_truth': final_targets.cpu().numpy().tolist(),
        'metrics': {
            'srcc': test_srcc,
            'plcc': test_plcc,
            'klcc': test_klcc,
            'loss': losses.avg
        }
    }

    return results


def save_results(results: Dict, model_name: str, train_dataset: str,
                 test_dataset: str, metric_type: str, output_dir: str,
                 logger: logging.Logger) -> None:
    """
    Write per-image results to a CSV file and log a metrics summary.

    The CSV is written to ``<output_dir>/evaluations/`` and named
    ``<model_name>_<train_dataset>_<metric_type>_on_<test_dataset>.csv``. Each row
    holds the image name, prediction, ground truth and their absolute difference.

    Args:
        results: Results dictionary with ``'image_names'``, ``'predictions'``,
            ``'ground_truth'`` and ``'metrics'`` entries
        model_name: Name of the model
        train_dataset: Training dataset type
        test_dataset: Test dataset name
        metric_type: Training metric objective
        output_dir: Base directory to save results
        logger: Logger instance
    """
    # Make sure the evaluations subdirectory exists.
    eval_dir = os.path.join(output_dir, 'evaluations')
    os.makedirs(eval_dir, exist_ok=True)

    # One record per image; zip stops at the shortest of the three sequences.
    rows = [
        {
            'image_name': name,
            'prediction': predicted,
            'ground_truth': actual,
            'absolute_error': abs(predicted - actual)
        }
        for name, predicted, actual in zip(
            results['image_names'], results['predictions'], results['ground_truth']
        )
    ]

    # The file name encodes the full configuration.
    csv_path = os.path.join(
        eval_dir,
        f"{model_name}_{train_dataset}_{metric_type}_on_{test_dataset}.csv",
    )
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    logger.info(f"Detailed results saved to: {csv_path}")

    # Formatted metrics summary.
    heavy_rule = "=" * 70
    logger.info("\n" + heavy_rule)
    logger.info("EVALUATION METRICS")
    logger.info(heavy_rule)
    logger.info(f"Model: {model_name}")
    logger.info(f"Trained on: {train_dataset}")
    logger.info(f"Metric type: {metric_type}")
    logger.info(f"Tested on: {test_dataset}")
    logger.info("-" * 70)
    metrics = results['metrics']
    logger.info(f"SRCC (Spearman): {metrics['srcc']:.4f}")
    logger.info(f"PLCC (Pearson):  {metrics['plcc']:.4f}")
    logger.info(f"KLCC (Kendall):  {metrics['klcc']:.4f}")
    logger.info(f"MSE Loss:        {metrics['loss']:.4f}")
    logger.info(heavy_rule + "\n")

def main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """
    Run the full evaluation pipeline for one model / dataset configuration.

    Steps, in order: validate the required arguments, obtain the checkpoint,
    build the test dataset and loader, create the model, load its weights, move
    it to the GPU when possible, run inference, and save the results. The
    function returns early (after logging an error) when the checkpoint cannot
    be obtained or loaded.

    Args:
        args: Command line arguments; ``dataset``, ``eval_only`` and ``arch`` are
            set on this namespace as a side effect
        logger: Logger instance

    Raises:
        ValueError: If a required argument is empty
    """
    # Required arguments, checked in this order; the first empty one aborts.
    required_args = (
        ('model_name', "Please specify --model_name"),
        ('train_dataset', "Please specify --train_dataset (cls, det, or ins)"),
        ('test_dataset', "Please specify --test_dataset"),
        ('metric_type', "Please specify --metric_type (composite, consistency, or accuracy)"),
    )
    for attr_name, complaint in required_args:
        if not getattr(args, attr_name):
            raise ValueError(complaint)

    logger.info("\nStarting MIQA Inference Pipeline")
    logger.info(f"Model: {args.model_name}")
    logger.info(f"Trained on: {args.train_dataset}")
    logger.info(f"Metric type: {args.metric_type}")
    logger.info(f"Testing on: {args.test_dataset}")

    # Locate the checkpoint (downloaded on demand).
    weights_file = ensure_model_weights(
        args.model_name, args.train_dataset, args.metric_type, logger
    )
    if weights_file is None:
        logger.error("Cannot proceed without model weights")
        return

    # Dataset and loader; the dataset builder reads its settings from args.
    logger.info(f"\nLoading {args.test_dataset} dataset...")
    args.dataset = args.test_dataset
    args.eval_only = True
    test_set = build_dataloader.build_dataset(args)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    logger.info(f"βœ“ Loaded {len(test_set)} images with {args.workers} workers")

    # Model architecture followed by its weights.
    logger.info("\nCreating model architecture...")
    args.arch = args.model_name
    network = create_model(args.model_name, args, logger)
    weights_loaded = load_model_weights(network, weights_file, args, logger)
    if not weights_loaded:
        logger.error("Failed to load model weights")
        return

    # Device placement: stay on the CPU unless a GPU index was given and CUDA works.
    if args.gpu is None or not torch.cuda.is_available():
        logger.warning("GPU not available, using CPU (this will be slower)")
    else:
        network = network.cuda(args.gpu)
        logger.info(f"βœ“ Model moved to GPU {args.gpu}")

    # Evaluate with an MSE loss, then persist and summarize the outcome.
    mse = torch.nn.MSELoss()
    logger.info("\nRunning inference...")
    outcome = inference(test_loader, network, args, mse, logger)
    save_results(outcome, args.model_name, args.train_dataset,
                 args.test_dataset, args.metric_type, args.output_dir, logger)

if __name__ == '__main__':
    # Extend the project's base argument parser with the inference-specific options.
    parser = get_args()
    parser.add_argument('--model_name', type=str, required=True,
                        choices=['ra_miqa'],
                        help='Model architecture (Hub registry currently ships RA-MIQA only)')
    parser.add_argument('--train_dataset', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Dataset type the model was trained on (cls=classification, det=detection, ins=instance)')
    parser.add_argument('--test_dataset', type=str, required=True,
                        help='Name of the dataset to test on')
    parser.add_argument('--metric_type', type=str, required=True,
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric objective used (composite=both metrics, consistency=consistency-focused, accuracy=accuracy-focused)')
    parser.add_argument('--output_dir', type=str, default='outputs',
                        help='Directory to save results (default: outputs)')

    args = parser.parse_args()

    # Create output directory structure
    os.makedirs(args.output_dir, exist_ok=True)

    # Configure logging to both file and console
    log_filename = f"inference_{args.model_name}_{args.train_dataset}_{args.metric_type}.log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8: log messages in this script contain non-ASCII
            # status glyphs, and without an encoding the file handler falls
            # back to the locale's default, which may not be able to encode them.
            logging.FileHandler(os.path.join(args.output_dir, log_filename),
                                encoding='utf-8'),
            logging.StreamHandler()
        ]
    )

    logger = logging.getLogger('miqa_inference')

    # Run main inference pipeline; log the full traceback, then re-raise so the
    # process still exits with an error.
    try:
        main(args, logger)
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}", exc_info=True)
        raise