Spaces:
Running on Zero
Running on Zero
"""
Relative Pose Evaluation Hook for Harmon Training.
Performs relative pose-conditioned image generation at regular intervals during training
to visualize and evaluate model progress.
"""
import json
import os
from typing import Dict, List, Optional

import numpy as np
import torch
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from PIL import Image

# Import dataset classes
from src.datasets.relative_pose.co3d_relpose import CO3DRelativePoseDataset
from src.datasets.relative_pose.objaverse_relpose import ObjaverseRelativePoseDataset
class RelativePoseEvaluationHook(Hook):
    """Hook to evaluate relative pose-conditioned image generation during training.

    A fixed set of samples is drawn from the dataset once, before training
    starts. Every ``interval`` training iterations the model generates a
    target image for each sample from its source image and viewpoint
    parameters. Source, generated and ground-truth images are written to
    ``<work_dir>/eval_images/iter_<iteration>_<dataset_type>/`` together with
    a comparison grid (``grid.jpg``) and the parameters used
    (``relpose_params.json``).

    Args:
        interval (int): Evaluate every N training iterations. Default: 1000
        dataset_type (str): 'co3d' or 'objaverse'. Default: 'co3d'
        base_dir (str): Path to dataset directory. ``None`` disables the hook.
        categories (list): For CO3D, list of categories to load. None = all.
            Default: None
        use_set_lists (bool): For CO3D, use filtered train frames. Default: True
        metadata_file (str): For Objaverse, metadata JSON filename.
            Default: 'metadata.json'
        num_samples (int): Number of test pairs to evaluate. Default: 4
        num_iter (int): Number of sampling iterations for generation. Default: 64
        cfg (float): Classifier-free guidance scale. Default: 3.0
        temperature (float): Sampling temperature. Default: 1.0
        save_individual (bool): Whether to save individual images. Default: True
        seed (int): Random seed for reproducible sampling. Default: 42
        max_dataset_samples (int): Limit dataset size for faster loading.
            Default: 100
    """

    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1000,
                 dataset_type: str = 'co3d',
                 base_dir: Optional[str] = None,
                 categories: Optional[List[str]] = None,
                 use_set_lists: bool = True,
                 metadata_file: str = 'metadata.json',
                 num_samples: int = 4,
                 num_iter: int = 64,
                 cfg: float = 3.0,
                 temperature: float = 1.0,
                 save_individual: bool = True,
                 seed: int = 42,
                 max_dataset_samples: int = 100):
        super().__init__()
        self.interval = interval
        self.dataset_type = dataset_type.lower()
        self.base_dir = base_dir
        self.categories = categories
        self.use_set_lists = use_set_lists
        self.metadata_file = metadata_file
        self.num_samples = num_samples
        self.num_iter = num_iter
        self.cfg = cfg
        self.temperature = temperature
        self.save_individual = save_individual
        self.seed = seed
        self.max_dataset_samples = max_dataset_samples
        # Populated by before_train(); test_dataset stays None while the hook
        # is disabled, which after_train_iter() uses as its guard.
        self.test_dataset = None
        self.test_indices = None
        self.test_samples = None

    @staticmethod
    def _unwrap_model(model):
        """Return the wrapped module if ``model`` is a registered model wrapper.

        Wrappers such as DistributedDataParallel do not expose custom
        attributes of the wrapped model (``tokenizer``, ``sample_relpose``,
        ...), so the hook always works on the inner module.
        """
        return model.module if is_model_wrapper(model) else model

    def before_train(self, runner):
        """Build the test dataset and pre-load a fixed set of test samples.

        Raises:
            ValueError: If ``dataset_type`` is neither 'co3d' nor 'objaverse'.
        """
        if self.base_dir is None:
            runner.logger.warning("[RelposEvalHook] No base_dir provided, hook disabled")
            return
        runner.logger.info(f"[RelposEvalHook] Initializing {self.dataset_type} test dataset...")

        # Reuse the tokenizer and prompt template the model already holds.
        model = self._unwrap_model(runner.model)
        dataset_kwargs = {
            'base_dir': self.base_dir,
            'tokenizer': model.tokenizer,
            'prompt_template': model.prompt_template,
            'template_map_fn': None,  # Not needed for eval
            'image_size': 512,
            'num_view_tokens': 8,  # Formerly named num_relpose_tokens
            'tasks': ['relpose2image'],
            'task_weights': [1.0],
            'max_length': 1024,
            'image_length': 1088,
            'min_rotation_deg': 15.0,
            'max_rotation_deg': 90.0,
            'max_samples': self.max_dataset_samples,
        }
        if self.dataset_type == 'co3d':
            dataset_kwargs['categories'] = self.categories
            dataset_kwargs['use_set_lists'] = self.use_set_lists
            dataset = CO3DRelativePoseDataset(**dataset_kwargs)
        elif self.dataset_type == 'objaverse':
            dataset_kwargs['metadata_file'] = self.metadata_file
            dataset = ObjaverseRelativePoseDataset(**dataset_kwargs)
        else:
            raise ValueError(f"Unknown dataset_type: {self.dataset_type}. Use 'co3d' or 'objaverse'.")

        dataset_size = len(dataset)
        sample_count = min(self.num_samples, dataset_size)
        if sample_count <= 0:
            # Nothing to evaluate: keep the hook disabled rather than writing
            # an empty grid at every interval.
            runner.logger.warning(
                "[RelposEvalHook] No test samples available "
                f"(dataset size: {dataset_size}, num_samples: {self.num_samples}), hook disabled")
            return

        # Seed the global NumPy RNG only while picking and loading samples,
        # then restore it so the evaluation seed does not leak into training.
        # Loading happens inside the seeded region in case the dataset's
        # __getitem__ also draws from the global RNG.
        rng_state = np.random.get_state()
        try:
            np.random.seed(self.seed)
            self.test_indices = np.random.choice(
                dataset_size,
                size=sample_count,
                replace=False
            ).tolist()
            self.test_samples = [dataset[idx] for idx in self.test_indices]
        finally:
            np.random.set_state(rng_state)

        # Assign last: a non-None test_dataset is what enables the hook.
        self.test_dataset = dataset
        runner.logger.info(f"[RelposEvalHook] Loaded {len(self.test_samples)} test samples")

    def after_train_iter(self, runner, batch_idx: int, data_batch=None, outputs=None):
        """Run the evaluation every ``interval`` training iterations."""
        if self.test_dataset is None:
            return
        if self.every_n_train_iters(runner, self.interval):
            self._run_evaluation(runner)

    def _run_evaluation(self, runner):
        """Generate images for all test samples and save images, grid and params."""
        model = self._unwrap_model(runner.model)
        # runner.iter is still the zero-based index of the iteration that just
        # finished; +1 gives the completed-iteration count that the interval
        # check is based on, so directories read iter_001000, not iter_000999.
        iteration = runner.iter + 1

        # Include the dataset type in the directory name to avoid conflicts.
        eval_dir = os.path.join(
            runner.work_dir, 'eval_images', f'iter_{iteration:06d}_{self.dataset_type}')
        os.makedirs(eval_dir, exist_ok=True)

        runner.logger.info(f"\n[RelposEvalHook] Running evaluation at iteration {iteration}")
        runner.logger.info(f" Samples: {len(self.test_samples)}")
        runner.logger.info(f" Num iter: {self.num_iter}, CFG: {self.cfg}")

        # Evaluate in eval mode and always put the model back into the mode it
        # was in, even if generation raises.
        was_training = model.training
        model.eval()
        try:
            results = [
                self._generate_from_sample(model, sample, i, eval_dir)
                for i, sample in enumerate(self.test_samples)
            ]
        finally:
            model.train(was_training)

        grid_path = os.path.join(eval_dir, 'grid.jpg')
        self._create_comparison_grid(results).save(grid_path, quality=95)

        params_path = os.path.join(eval_dir, 'relpose_params.json')
        params_data = [
            {
                'sample_idx': i,
                # Cast to float32 first so the values are plain Python floats
                # regardless of the model dtype.
                'relpose_params': result['relpose_params'].float().cpu().tolist(),
                'rotation_deg': result['rotation_deg'],
            }
            for i, result in enumerate(results)
        ]
        with open(params_path, 'w', encoding='utf-8') as f:
            json.dump(params_data, f, indent=2)

        runner.logger.info(f" Saved evaluation to {eval_dir}")
        runner.logger.info(f" Grid: {grid_path}")
        runner.logger.info(f" Params: {params_path}\n")

    def _generate_from_sample(self, model, sample: Dict, sample_idx: int, eval_dir: str) -> Dict:
        """Generate the target image for one dataset sample.

        Args:
            model: Model exposing ``device``, ``dtype`` and ``sample_relpose``.
            sample: Dataset sample with ``src_pixel_values``,
                ``tgt_pixel_values``, ``viewpoint_params`` and ``input_ids``.
            sample_idx: Sample index, used in the saved file names.
            eval_dir: Directory the individual images are written to.

        Returns:
            Dict with ``src_image``, ``gen_image`` and ``gt_image`` (PIL
            images), ``relpose_params`` (the viewpoint parameter tensor of the
            sample) and ``rotation_deg`` (float).
        """
        with torch.no_grad():
            # Add a batch dimension and move everything to the model.
            src_pixel_values = sample['src_pixel_values'].unsqueeze(0).to(
                device=model.device, dtype=model.dtype
            )
            tgt_pixel_values = sample['tgt_pixel_values'].unsqueeze(0).to(
                device=model.device, dtype=model.dtype
            )
            viewpoint_params = sample['viewpoint_params'].unsqueeze(0).to(
                device=model.device, dtype=model.dtype
            )
            input_ids = sample['input_ids'].unsqueeze(0).to(device=model.device)

            # Logged angle: magnitude of the first two viewpoint parameters
            # (treated as azimuth/elevation in radians) converted to degrees.
            # This is a viewpoint angle magnitude, not a relative rotation.
            azimuth = viewpoint_params[0, 0].cpu().float().item()
            elevation = viewpoint_params[0, 1].cpu().float().item()
            rotation_deg = float(np.degrees(np.hypot(azimuth, elevation)))

            generated = model.sample_relpose(
                src_image=src_pixel_values,
                viewpoint_params=viewpoint_params,
                input_ids=input_ids,
                num_iter=self.num_iter,
                cfg=self.cfg,
                temperature=self.temperature,
                progress=False
            )

            images = {
                'src': self._tensor_to_pil(src_pixel_values[0]),
                'gen': self._tensor_to_pil(generated[0]),
                'gt': self._tensor_to_pil(tgt_pixel_values[0]),
            }

        if self.save_individual:
            for tag, image in images.items():
                image.save(os.path.join(eval_dir, f'sample_{sample_idx:02d}_{tag}.jpg'))

        return {
            'src_image': images['src'],
            'gen_image': images['gen'],
            'gt_image': images['gt'],
            'relpose_params': viewpoint_params[0],  # Actually viewpoint params
            'rotation_deg': rotation_deg,
        }

    def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
        """Convert a (C, H, W) tensor in range [-1, 1] to a PIL image.

        Values outside [-1, 1] are clamped. Quantisation to 8 bits rounds to
        the nearest level instead of truncating, which would darken the image
        by up to one level.
        """
        # Detach so the helper also works on tensors that require grad, and
        # cast to float32 because NumPy has no bfloat16.
        scaled = ((tensor.detach().to(dtype=torch.float32) + 1.0) / 2.0).clamp(0, 1)
        array = scaled.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
        return Image.fromarray(np.round(array * 255.0).astype(np.uint8))

    def _create_comparison_grid(self, results: List[Dict]) -> Image.Image:
        """Build a grid with one row per sample: Source | Generated | Ground Truth.

        Args:
            results: Result dicts holding ``src_image``, ``gen_image`` and
                ``gt_image`` PIL images.

        Returns:
            The grid as a PIL image; a blank 512x512 image if ``results`` is
            empty.
        """
        if not results:
            return Image.new('RGB', (512, 512))

        # Cell size comes from the first source image; all images are assumed
        # to share it.
        img_width, img_height = results[0]['src_image'].size
        column_keys = ('src_image', 'gen_image', 'gt_image')
        grid = Image.new('RGB', (img_width * len(column_keys), img_height * len(results)))
        for row_idx, result in enumerate(results):
            for col_idx, key in enumerate(column_keys):
                grid.paste(result[key], (col_idx * img_width, row_idx * img_height))
        return grid