""" Relative Pose Evaluation Hook for Harmon Training. Performs relative pose-conditioned image generation at regular intervals during training to visualize and evaluate model progress. """ import os import json import torch import numpy as np from PIL import Image from typing import List, Optional, Dict from mmengine.hooks import Hook from mmengine.registry import HOOKS # Import dataset classes from src.datasets.relative_pose.co3d_relpose import CO3DRelativePoseDataset from src.datasets.relative_pose.objaverse_relpose import ObjaverseRelativePoseDataset @HOOKS.register_module() class RelativePoseEvaluationHook(Hook): """ Hook to evaluate relative pose-conditioned image generation during training. Samples pairs from dataset and generates target images from source + relpose. Saves source, generated, and ground truth images for comparison. Args: interval (int): Evaluate every N training iterations. Default: 1000 dataset_type (str): 'co3d' or 'objaverse'. Default: 'co3d' base_dir (str): Path to dataset directory categories (list): For CO3D, list of categories to load. None = all. Default: None use_set_lists (bool): For CO3D, use filtered train frames. Default: True metadata_file (str): For Objaverse, metadata JSON filename. Default: 'metadata.json' num_samples (int): Number of test pairs to evaluate. Default: 4 num_iter (int): Number of sampling iterations for generation. Default: 64 cfg (float): Classifier-free guidance scale. Default: 3.0 temperature (float): Sampling temperature. Default: 1.0 save_individual (bool): Whether to save individual images. Default: True seed (int): Random seed for reproducible sampling. Default: 42 max_dataset_samples (int): Limit dataset size for faster loading. Default: 100 """ priority = 'NORMAL' def __init__(self, interval: int = 1000, dataset_type: str = 'co3d', base_dir: str = None, categories: Optional[List[str]] = None, use_set_lists: bool = True, metadata_file: str = 'metadata.json', num_samples: int = 4, num_iter: int = 64, cfg: float = 3.0, temperature: float = 1.0, save_individual: bool = True, seed: int = 42, max_dataset_samples: int = 100): super().__init__() self.interval = interval self.dataset_type = dataset_type.lower() self.base_dir = base_dir self.categories = categories self.use_set_lists = use_set_lists self.metadata_file = metadata_file self.num_samples = num_samples self.num_iter = num_iter self.cfg = cfg self.temperature = temperature self.save_individual = save_individual self.seed = seed self.max_dataset_samples = max_dataset_samples self.test_dataset = None self.test_samples = None def before_train(self, runner): """Initialize test dataset before training starts.""" if self.base_dir is None: runner.logger.warning("[RelposEvalHook] No base_dir provided, hook disabled") return runner.logger.info(f"[RelposEvalHook] Initializing {self.dataset_type} test dataset...") # Get model's tokenizer and template (already initialized) model = runner.model tokenizer = model.tokenizer prompt_template = model.prompt_template # Directly instantiate the dataset class dataset_kwargs = { 'base_dir': self.base_dir, 'tokenizer': tokenizer, 'prompt_template': prompt_template, 'template_map_fn': None, # Not needed for eval 'image_size': 512, 'num_view_tokens': 8, # Changed from num_relpose_tokens 'tasks': ['relpose2image'], 'task_weights': [1.0], 'max_length': 1024, 'image_length': 1088, 'min_rotation_deg': 15.0, 'max_rotation_deg': 90.0, 'max_samples': self.max_dataset_samples, } if self.dataset_type == 'co3d': dataset_kwargs['categories'] = self.categories dataset_kwargs['use_set_lists'] = self.use_set_lists self.test_dataset = CO3DRelativePoseDataset(**dataset_kwargs) elif self.dataset_type == 'objaverse': dataset_kwargs['metadata_file'] = self.metadata_file self.test_dataset = ObjaverseRelativePoseDataset(**dataset_kwargs) else: raise ValueError(f"Unknown dataset_type: {self.dataset_type}. Use 'co3d' or 'objaverse'.") # Sample fixed test indices for reproducibility np.random.seed(self.seed) dataset_size = len(self.test_dataset) self.test_indices = np.random.choice( dataset_size, size=min(self.num_samples, dataset_size), replace=False ).tolist() # Pre-load test samples self.test_samples = [] for idx in self.test_indices: sample = self.test_dataset[idx] self.test_samples.append(sample) runner.logger.info(f"[RelposEvalHook] Loaded {len(self.test_samples)} test samples") def after_train_iter(self, runner, batch_idx: int, data_batch=None, outputs=None): """Called after every training iteration.""" if self.test_dataset is None: return if self.every_n_train_iters(runner, self.interval): self._run_evaluation(runner) def _run_evaluation(self, runner): """Run relative pose evaluation and save images.""" model = runner.model iteration = runner.iter work_dir = runner.work_dir # Create output directory (include dataset type to avoid conflicts) eval_dir = os.path.join(work_dir, 'eval_images', f'iter_{iteration:06d}_{self.dataset_type}') os.makedirs(eval_dir, exist_ok=True) runner.logger.info(f"\n[RelposEvalHook] Running evaluation at iteration {iteration}") runner.logger.info(f" Samples: {len(self.test_samples)}") runner.logger.info(f" Num iter: {self.num_iter}, CFG: {self.cfg}") # Switch to eval mode model.eval() # Generate images for all test samples results = [] for i, sample in enumerate(self.test_samples): result = self._generate_from_sample(model, sample, i, eval_dir) results.append(result) # Create comparison grid grid_image = self._create_comparison_grid(results) grid_path = os.path.join(eval_dir, 'grid.jpg') grid_image.save(grid_path, quality=95) # Save relpose parameters to JSON params_path = os.path.join(eval_dir, 'relpose_params.json') params_data = [] for i, result in enumerate(results): params_data.append({ 'sample_idx': i, 'relpose_params': result['relpose_params'].cpu().tolist(), 'rotation_deg': result['rotation_deg'], }) with open(params_path, 'w') as f: json.dump(params_data, f, indent=2) runner.logger.info(f" Saved evaluation to {eval_dir}") runner.logger.info(f" Grid: {grid_path}") runner.logger.info(f" Params: {params_path}\n") # Switch back to train mode model.train() def _generate_from_sample(self, model, sample: Dict, sample_idx: int, eval_dir: str) -> Dict: """ Generate target image from a sample. Args: model: Harmon model sample: Dataset sample dict sample_idx: Sample index for naming eval_dir: Directory to save images Returns: Dict with source, generated, ground truth images and params """ with torch.no_grad(): # Extract data from sample src_pixel_values = sample['src_pixel_values'].unsqueeze(0).to( device=model.device, dtype=model.dtype ) tgt_pixel_values = sample['tgt_pixel_values'].unsqueeze(0).to( device=model.device, dtype=model.dtype ) viewpoint_params = sample['viewpoint_params'].unsqueeze(0).to( device=model.device, dtype=model.dtype ) input_ids = sample['input_ids'].unsqueeze(0).to(device=model.device) # Compute rotation angle for logging (from azimuth/elevation viewpoint params) # Note: This is now the absolute viewpoint angle difference, not relative rotation azimuth = viewpoint_params[0, 0].cpu().float().item() elevation = viewpoint_params[0, 1].cpu().float().item() rotation_deg = np.sqrt(azimuth**2 + elevation**2) * 180 / np.pi # Generate target image using the model's sample_relpose method # Note: sample_relpose now uses viewpoint system internally generated = model.sample_relpose( src_image=src_pixel_values, viewpoint_params=viewpoint_params, # Actually viewpoint params (azimuth, elevation) input_ids=input_ids, num_iter=self.num_iter, cfg=self.cfg, temperature=self.temperature, progress=False ) # Convert tensors to PIL Images src_image = self._tensor_to_pil(src_pixel_values[0]) gen_image = self._tensor_to_pil(generated[0]) gt_image = self._tensor_to_pil(tgt_pixel_values[0]) # Save individual images if requested if self.save_individual: src_image.save(os.path.join(eval_dir, f'sample_{sample_idx:02d}_src.jpg')) gen_image.save(os.path.join(eval_dir, f'sample_{sample_idx:02d}_gen.jpg')) gt_image.save(os.path.join(eval_dir, f'sample_{sample_idx:02d}_gt.jpg')) return { 'src_image': src_image, 'gen_image': gen_image, 'gt_image': gt_image, 'relpose_params': viewpoint_params[0], # Actually viewpoint params 'rotation_deg': float(rotation_deg), } def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image: """ Convert tensor to PIL Image. Args: tensor: (C, H, W) tensor in range [-1, 1] Returns: PIL Image """ # Denormalize from [-1, 1] to [0, 255] tensor = (tensor + 1.0) / 2.0 tensor = torch.clamp(tensor, 0, 1) # Convert to float32 (NumPy doesn't support bfloat16) tensor = tensor.to(dtype=torch.float32) # Convert to numpy and rearrange to HWC array = tensor.cpu().numpy() array = np.transpose(array, (1, 2, 0)) # CHW -> HWC array = (array * 255).astype(np.uint8) return Image.fromarray(array) def _create_comparison_grid(self, results: List[Dict]) -> Image.Image: """ Create a comparison grid: Source | Generated | Ground Truth (one row per sample). Args: results: List of result dicts with images Returns: Grid image as PIL Image """ num_samples = len(results) if num_samples == 0: return Image.new('RGB', (512, 512)) # Get image size (assume all images same size) img_width, img_height = results[0]['src_image'].size # Create grid: 3 columns (src, gen, gt) × num_samples rows grid_width = img_width * 3 grid_height = img_height * num_samples grid = Image.new('RGB', (grid_width, grid_height)) # Paste images for row_idx, result in enumerate(results): y = row_idx * img_height # Source grid.paste(result['src_image'], (0, y)) # Generated grid.paste(result['gen_image'], (img_width, y)) # Ground truth grid.paste(result['gt_image'], (img_width * 2, y)) return grid