import os import sys import torch import argparse import logging from pathlib import Path from typing import List, Dict, Optional, Tuple from collections import OrderedDict import json from datetime import datetime from tqdm import tqdm import numpy as np import cv2 # OpenCV for video processing import matplotlib.pyplot as plt # Matplotlib for plotting import io # Image processing imports from PIL import Image, ImageDraw, ImageFont import torchvision.transforms as transforms # Import your existing model components # Ensure these files (models/, utils/) are in the same directory or accessible in PYTHONPATH from models.MIQA_base import get_torch_model, get_timm_model from models.RA_MIQA import RegionVisionTransformer from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES from utils.hf_download_utils import ensure_checkpoint_from_hf SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv'} class MIQAInference: """ MODIFIED Inference wrapper for MIQA models. Now includes a method to predict on PIL Image objects directly. """ def __init__(self, task: str, model_name: str = 'ra_miqa', metric_type: str = 'composite', device: Optional[str] = None): self.task = task.lower() self.model_name = model_name self.metric_type = metric_type self.logger = self._setup_logger() if device is None: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: self.device = torch.device(device) self.logger.info(f"🚀 Initializing MIQA Inference System") self.logger.info(f" Task: {self.task.upper()}") self.logger.info(f" Model: {self.model_name}") self.logger.info(f" Metric Type: {self.metric_type}") self.logger.info(f" Device: {self.device}") self._validate_config() self.model = self._load_model() self.transforms1, self.transforms2 = self._get_transforms() self.logger.info("✅ System ready for inference\n") def _setup_logger(self) -> logging.Logger: """Configure logging with both file and console output.""" logger = logging.getLogger('MIQA_Inference') logger.setLevel(logging.INFO) if logger.hasHandlers(): return logger logger.propagate = False # Console handler with clean formatting console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) console_formatter = logging.Formatter('%(message)s') console_handler.setFormatter(console_formatter) logger.addHandler(console_handler) return logger def _validate_config(self) -> None: """Validate that the requested configuration is supported.""" if self.metric_type not in ['composite', 'consistency', 'accuracy']: raise ValueError( f"Invalid metric_type '{self.metric_type}'. " f"Supported: ['composite', 'consistency', 'accuracy']" ) if self.task not in MODEL_FILENAMES[self.metric_type]: raise ValueError( f"Invalid task '{self.task}'. " f"Supported tasks: {list(MODEL_FILENAMES[self.metric_type].keys())}" ) if self.model_name not in MODEL_FILENAMES[self.metric_type][self.task]: available = list(MODEL_FILENAMES[self.metric_type][self.task].keys()) raise ValueError( f"Model '{self.model_name}' not available for task '{self.task}'. " f"Available models: {available}" ) def _get_checkpoint_path(self) -> str: """Generate the path where model checkpoint should be stored.""" base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric' base_dir.mkdir(parents=True, exist_ok=True) filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name] return str(base_dir / filename) def _download_weights(self, checkpoint_path: str) -> bool: """ Download model weights if not present locally. Returns: True if weights are available (already existed or successfully downloaded) """ if os.path.exists(checkpoint_path): self.logger.info(f"✓ Found cached model weights") return True self.logger.info( f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, " f"file={Path(checkpoint_path).name}, rev={HF_REVISION}" ) try: ensure_checkpoint_from_hf( repo_id=HF_REPO_ID, filename=Path(checkpoint_path).name, local_dir=str(Path(checkpoint_path).parent), revision=HF_REVISION, ) self.logger.info("✓ Successfully downloaded model weights") return True except Exception as e: self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}") return False def _create_model(self) -> torch.nn.Module: """Create the model architecture.""" if self.model_name == 'ra_miqa': self.logger.info("Building Region-Aware Vision Transformer...") model = RegionVisionTransformer( base_model_name='vit_small_patch16_224', pretrained=False, # We'll load our trained weights mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py', checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth' ) else: try: self.logger.info(f"Building {self.model_name} from PyTorch...") model = get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1) except Exception: self.logger.info(f"Building {self.model_name} from timm library...") model = get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1) return model def _load_model(self) -> torch.nn.Module: """Load model with pre-trained weights.""" checkpoint_path = self._get_checkpoint_path() # Ensure weights are available if not self._download_weights(checkpoint_path): raise RuntimeError("Cannot proceed without model weights") # Create model architecture self.logger.info("🔧 Loading model...") model = self._create_model() # Load weights checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict = checkpoint.get('state_dict', checkpoint) # Remove 'module.' prefix if present (from DataParallel training) new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k.replace('module.', '') if k.startswith('module.') else k new_state_dict[name] = v model.load_state_dict(new_state_dict, strict=True) model = model.to(self.device) model.eval() # Set to evaluation mode self.logger.info("✓ Model loaded successfully") return model def _get_transforms(self) -> Tuple[transforms.Compose, transforms.Compose | None]: """ Return preprocessing transforms based on model type. """ IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) SIMPLE_MEAN = (0.5, 0.5, 0.5) SIMPLE_STD = (0.5, 0.5, 0.5) # Default (for single-input backbones) transform_imagenet = transforms.Compose([ transforms.Resize(288), transforms.CenterCrop(size=224), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) ]) transform_simple = transforms.Compose([ transforms.Resize(288), transforms.CenterCrop(size=224), transforms.ToTensor(), transforms.Normalize(mean=SIMPLE_MEAN, std=SIMPLE_STD) ]) # 1️⃣ CNNs(ResNet / EfficientNet) if any(k in self.model_name for k in ['resnet', 'efficientnet']): return transform_imagenet, None # 2️⃣ ViT elif 'vit' in self.model_name: return transform_simple, None # 3️⃣ ra_miqa elif 'ra_miqa' in self.model_name: transform_1 = transforms.Compose([ transforms.Resize(288), transforms.CenterCrop(size=224), transforms.ToTensor(), transforms.Normalize(mean=SIMPLE_MEAN, std=SIMPLE_STD) ]) transform_2 = transforms.Compose([ transforms.Resize(288), transforms.CenterCrop((288, 288)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) ]) return transform_1, transform_2 # fallback else: print(f"[Warning] Unknown model type '{self.model_name}', using ImageNet normalization.") return transform_imagenet, None @torch.no_grad() def predict_image_object(self, image: Image.Image) -> float: """ NEW METHOD: Run inference on a PIL Image object. """ # Preprocess the image img1 = self.transforms1(image).unsqueeze(0).to(self.device) img2 = self.transforms2(image).unsqueeze(0).to(self.device) if self.transforms2 else None # Run inference based on model input requirements if img2 is None: output = self.model(img1) else: output = self.model(img1, img2) score = output.item() if torch.is_tensor(output) else float(output) return score class VideoMIQAProcessor: """ A wrapper to process videos using the MIQAInference engine and create a visualized output video with scores and plots. """ # --- Visualization Constants --- PANEL_WIDTH = 480 FONT = cv2.FONT_HERSHEY_SIMPLEX FONT_SCALE_L = 1.0 FONT_SCALE_M = 0.8 FONT_COLOR = (255, 255, 255) # White LINE_THICKNESS = 2 # Plotting style plt.style.use('dark_background') def __init__(self, miqa_engine: MIQAInference): self.miqa_engine = miqa_engine self.logger = miqa_engine.logger def _create_score_plot(self, scores: List[float], width: int, height: int) -> np.ndarray: """ Creates a line chart of scores using Matplotlib and returns it as an OpenCV image. """ fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100) ax.plot(scores, color='#4287f5', linewidth=2) ax.set_xlim(0, max(1, len(scores))) ax.set_ylim(0, 1) ax.set_title("Quality Score Fluctuation", fontsize=10) ax.set_xlabel("Frame", fontsize=8) ax.set_ylabel("Score", fontsize=8) ax.grid(True, alpha=0.3) fig.tight_layout(pad=1.5) # Render plot to an in-memory buffer buf = io.BytesIO() fig.savefig(buf, format='png') buf.seek(0) plt.close(fig) # Convert buffer to a PIL Image and then to an OpenCV image plot_img_pil = Image.open(buf) plot_img_np = np.array(plot_img_pil) plot_img_bgr = cv2.cvtColor(plot_img_np, cv2.COLOR_RGBA2BGR) return plot_img_bgr def process_video(self, input_path: str, output_path: str): """ Reads a video, analyzes each frame for quality, and writes an annotated output video. """ self.logger.info(f"📹 Starting processing for: {Path(input_path).name}") cap = cv2.VideoCapture(input_path) if not cap.isOpened(): self.logger.error(f"❌ Failed to open video: {input_path}") return # Video properties fps = int(cap.get(cv2.CAP_PROP_FPS)) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # New dimensions for output video (with side panel) output_width = orig_width + self.PANEL_WIDTH output_height = orig_height # Setup video writer fourcc = cv2.VideoWriter_fourcc(*'avc1') out = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height)) scores = [] progress_bar = tqdm(range(frame_count), desc="Analyzing frames", ncols=100) for frame_idx in progress_bar: ret, frame = cap.read() if not ret: break # --- MIQA Inference --- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(frame_rgb) score = self.miqa_engine.predict_image_object(pil_image) scores.append(score) # --- Visualization Panel --- panel = np.zeros((orig_height, self.PANEL_WIDTH, 3), dtype=np.uint8) # 1. Task Info task_text = f"Task: {self.miqa_engine.task.upper()}" cv2.putText(panel, task_text, (20, 50), self.FONT, self.FONT_SCALE_M, self.FONT_COLOR, self.LINE_THICKNESS) # 2. Current Score score_text = f"Quality Score: {score:.3f}" # Color coding for score text norm_score = max(0, score) if norm_score < 0.5: color = (0, int(255 * (norm_score * 2)), 255) # Red -> Yellow else: color = (0, 255, int(255 * (2 - norm_score * 2))) # Yellow -> Green cv2.putText(panel, score_text, (20, 110), self.FONT, self.FONT_SCALE_L, color, self.LINE_THICKNESS + 1) # 3. Frame Info frame_text = f"Frame: {frame_idx + 1}/{frame_count}" cv2.putText(panel, frame_text, (20, orig_height - 30), self.FONT, self.FONT_SCALE_M, self.FONT_COLOR, 1) # 4. Score Plot if len(scores) > 1: plot_height = 300 plot_width = self.PANEL_WIDTH - 40 # with margins plot_img = self._create_score_plot(scores, plot_width, plot_height) # Position the plot on the panel y_offset = 160 panel[y_offset:y_offset + plot_img.shape[0], 20:20 + plot_img.shape[1]] = plot_img # --- Combine and Write Frame --- combined_frame = np.concatenate((frame, panel), axis=1) out.write(combined_frame) # Release resources cap.release() out.release() self.logger.info(f"✅ Finished processing. Annotated video saved to: {output_path}\n") def main(): """Command-line interface for Video MIQA inference.""" parser = argparse.ArgumentParser( description='MIQA for Video: Machine-centric Image Quality Assessment on Video Frames', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Analyze a single video and save the annotated output python video_annotator_inference.py --input my_video.mp4 --task cls --model ra_miqa # Analyze all videos in a directory python video_annotator_inference.py --input ./video_folder/ --task det --model resnet50 """ ) parser.add_argument('--input', type=str, required=True, help='Path to input video file or a directory containing videos.') parser.add_argument('--task', type=str, required=True, choices=['cls', 'det', 'ins'], help='Task type: cls (classification), det (detection), ins (instance).') parser.add_argument('--model', type=str, default='ra_miqa', choices=['ra_miqa'], help='Model architecture (default: ra_miqa; Hub weights are RA-MIQA only).') parser.add_argument('--metric-type', type=str, default='composite', choices=['composite', 'consistency', 'accuracy'], help='Training metric type (default: composite).') parser.add_argument('--device', type=str, default=None, choices=['cuda', 'cpu'], help='Device to run on (auto-detect if not specified).') parser.add_argument('--output-dir', type=str, default='inference_results', help='Directory to save the output annotated videos.') args = parser.parse_args() try: # Initialize the core inference engine miqa_engine = MIQAInference( task=args.task, model_name=args.model, metric_type=args.metric_type, device=args.device ) # Initialize the video processor video_processor = VideoMIQAProcessor(miqa_engine) # Find videos to process input_path = Path(args.input) videos_to_process = [] if input_path.is_dir(): for ext in SUPPORTED_VIDEO_EXTENSIONS: videos_to_process.extend(input_path.glob(f"*{ext}")) elif input_path.is_file() and input_path.suffix.lower() in SUPPORTED_VIDEO_EXTENSIONS: videos_to_process.append(input_path) if not videos_to_process: raise FileNotFoundError(f"No supported video files found in '{args.input}'") # Create output directory output_dir = Path(args.output_dir) / 'video' /args.task / args.metric_type output_dir.mkdir(parents=True, exist_ok=True) # Process each video for video_path in videos_to_process: output_filename = f"{video_path.stem}_miqa_{args.model}_{args.task}.mp4" output_filepath = str(output_dir / output_filename) video_processor.process_video(str(video_path), output_filepath) except Exception as e: # Use the logger if it exists, otherwise print try: miqa_engine.logger.error(f"\n❌ An error occurred: {str(e)}") except: print(f"\n❌ An error occurred: {str(e)}", file=sys.stderr) sys.exit(1) if __name__ == '__main__': main()