# Source path: viewtoken-harmon-demo/src/hooks/relpose_eval_hook.py
# Hosting-page header residue (not part of the module): XinxuanLu's picture
# Initial demo
# becf13a verified
"""
Relative Pose Evaluation Hook for Harmon Training.
Performs relative pose-conditioned image generation at regular intervals during training
to visualize and evaluate model progress.
"""
import os
import json
import torch
import numpy as np
from PIL import Image
from typing import List, Optional, Dict
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
# Import dataset classes
from src.datasets.relative_pose.co3d_relpose import CO3DRelativePoseDataset
from src.datasets.relative_pose.objaverse_relpose import ObjaverseRelativePoseDataset
@HOOKS.register_module()
class RelativePoseEvaluationHook(Hook):
    """Hook to evaluate relative pose-conditioned image generation during training.

    Before training starts, a fixed set of samples is drawn from a CO3D or
    Objaverse relative-pose dataset. Every ``interval`` training iterations
    the hook calls ``model.sample_relpose`` on each of those samples and
    writes the source, generated and ground-truth images, a comparison grid
    (``grid.jpg``) and the conditioning parameters (``relpose_params.json``)
    to ``<work_dir>/eval_images/iter_<iteration>_<dataset_type>/``.

    Args:
        interval (int): Evaluate every N training iterations. Default: 1000
        dataset_type (str): 'co3d' or 'objaverse' (case-insensitive).
            Default: 'co3d'
        base_dir (str, optional): Path to dataset directory. When ``None``
            the hook logs a warning and stays disabled. Default: None
        categories (list): For CO3D, list of categories to load. None = all.
            Default: None
        use_set_lists (bool): For CO3D, use filtered train frames.
            Default: True
        metadata_file (str): For Objaverse, metadata JSON filename.
            Default: 'metadata.json'
        num_samples (int): Number of test pairs to evaluate. Default: 4
        num_iter (int): Number of sampling iterations for generation.
            Default: 64
        cfg (float): Classifier-free guidance scale. Default: 3.0
        temperature (float): Sampling temperature. Default: 1.0
        save_individual (bool): Whether to save individual images.
            Default: True
        seed (int): Random seed used to pick the fixed test indices.
            Default: 42
        max_dataset_samples (int): Limit dataset size for faster loading.
            Default: 100
    """

    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1000,
                 dataset_type: str = 'co3d',
                 base_dir: Optional[str] = None,
                 categories: Optional[List[str]] = None,
                 use_set_lists: bool = True,
                 metadata_file: str = 'metadata.json',
                 num_samples: int = 4,
                 num_iter: int = 64,
                 cfg: float = 3.0,
                 temperature: float = 1.0,
                 save_individual: bool = True,
                 seed: int = 42,
                 max_dataset_samples: int = 100):
        super().__init__()
        self.interval = interval
        self.dataset_type = dataset_type.lower()
        self.base_dir = base_dir
        self.categories = categories
        self.use_set_lists = use_set_lists
        self.metadata_file = metadata_file
        self.num_samples = num_samples
        self.num_iter = num_iter
        self.cfg = cfg
        self.temperature = temperature
        self.save_individual = save_individual
        self.seed = seed
        self.max_dataset_samples = max_dataset_samples

        # Populated by before_train(); they stay None while the hook is disabled.
        self.test_dataset = None
        self.test_indices = None
        self.test_samples = None

    def before_train(self, runner):
        """Build the test dataset and pre-load a fixed set of samples.

        Args:
            runner: The mmengine runner; its ``model`` must expose
                ``tokenizer`` and ``prompt_template``.

        Raises:
            ValueError: If ``dataset_type`` is neither 'co3d' nor 'objaverse'.
        """
        if self.base_dir is None:
            runner.logger.warning("[RelposEvalHook] No base_dir provided, hook disabled")
            return

        runner.logger.info(f"[RelposEvalHook] Initializing {self.dataset_type} test dataset...")

        # Reuse the tokenizer and prompt template the model already holds.
        model = runner.model
        dataset_kwargs = {
            'base_dir': self.base_dir,
            'tokenizer': model.tokenizer,
            'prompt_template': model.prompt_template,
            'template_map_fn': None,  # Not needed for eval
            'image_size': 512,
            'num_view_tokens': 8,  # Changed from num_relpose_tokens
            'tasks': ['relpose2image'],
            'task_weights': [1.0],
            'max_length': 1024,
            'image_length': 1088,
            'min_rotation_deg': 15.0,
            'max_rotation_deg': 90.0,
            'max_samples': self.max_dataset_samples,
        }

        if self.dataset_type == 'co3d':
            dataset_kwargs['categories'] = self.categories
            dataset_kwargs['use_set_lists'] = self.use_set_lists
            self.test_dataset = CO3DRelativePoseDataset(**dataset_kwargs)
        elif self.dataset_type == 'objaverse':
            dataset_kwargs['metadata_file'] = self.metadata_file
            self.test_dataset = ObjaverseRelativePoseDataset(**dataset_kwargs)
        else:
            raise ValueError(f"Unknown dataset_type: {self.dataset_type}. Use 'co3d' or 'objaverse'.")

        # Sample fixed test indices for reproducibility. A private RandomState
        # yields the same draw as seeding the global NumPy RNG with this seed,
        # but leaves the global RNG state (shared with training code) untouched.
        rng = np.random.RandomState(self.seed)
        dataset_size = len(self.test_dataset)
        self.test_indices = rng.choice(
            dataset_size,
            size=min(self.num_samples, dataset_size),
            replace=False,
        ).tolist()

        # Pre-load the samples once so every evaluation sees identical inputs.
        self.test_samples = [self.test_dataset[idx] for idx in self.test_indices]
        runner.logger.info(f"[RelposEvalHook] Loaded {len(self.test_samples)} test samples")

    def after_train_iter(self, runner, batch_idx: int, data_batch=None, outputs=None):
        """Run the evaluation every ``interval`` training iterations.

        Does nothing when the hook was disabled in ``before_train``.
        """
        if self.test_dataset is None:
            return
        if self.every_n_train_iters(runner, self.interval):
            self._run_evaluation(runner)

    def _run_evaluation(self, runner):
        """Generate images for all test samples and write them to disk.

        Writes ``grid.jpg`` and ``relpose_params.json`` (plus the individual
        images when ``save_individual`` is set) into a per-iteration directory
        under ``runner.work_dir``. The model is put in eval mode for the
        duration and its previous train/eval mode is restored afterwards,
        even if generation or saving raises.
        """
        model = runner.model
        # NOTE(review): runner.iter is presumably the 0-based index of the
        # iteration that just finished, so directory names may be one less
        # than the matching checkpoint iteration - confirm before changing.
        iteration = runner.iter

        # Create output directory (include dataset type to avoid conflicts)
        eval_dir = os.path.join(
            runner.work_dir, 'eval_images', f'iter_{iteration:06d}_{self.dataset_type}')
        os.makedirs(eval_dir, exist_ok=True)

        runner.logger.info(f"\n[RelposEvalHook] Running evaluation at iteration {iteration}")
        runner.logger.info(f" Samples: {len(self.test_samples)}")
        runner.logger.info(f" Num iter: {self.num_iter}, CFG: {self.cfg}")

        was_training = model.training
        model.eval()
        try:
            results = [
                self._generate_from_sample(model, sample, sample_idx, eval_dir)
                for sample_idx, sample in enumerate(self.test_samples)
            ]

            # Create comparison grid
            grid_path = os.path.join(eval_dir, 'grid.jpg')
            self._create_comparison_grid(results).save(grid_path, quality=95)

            # Save relpose parameters to JSON
            params_path = os.path.join(eval_dir, 'relpose_params.json')
            params_data = [
                {
                    'sample_idx': sample_idx,
                    'relpose_params': result['relpose_params'].cpu().tolist(),
                    'rotation_deg': result['rotation_deg'],
                }
                for sample_idx, result in enumerate(results)
            ]
            with open(params_path, 'w', encoding='utf-8') as f:
                json.dump(params_data, f, indent=2)
        finally:
            # Never leave the model stuck in eval mode if something above fails.
            model.train(was_training)

        runner.logger.info(f" Saved evaluation to {eval_dir}")
        runner.logger.info(f" Grid: {grid_path}")
        runner.logger.info(f" Params: {params_path}\n")

    def _generate_from_sample(self, model, sample: Dict, sample_idx: int, eval_dir: str) -> Dict:
        """Generate the target image for one dataset sample.

        Args:
            model: Model exposing ``device``, ``dtype`` and ``sample_relpose``.
            sample: Dataset sample dict with the keys ``src_pixel_values``,
                ``tgt_pixel_values``, ``viewpoint_params`` and ``input_ids``.
            sample_idx: Sample index, used in the saved file names.
            eval_dir: Directory to save images into.

        Returns:
            Dict with ``src_image``, ``gen_image`` and ``gt_image`` (PIL
            images), ``relpose_params`` (the un-batched viewpoint tensor) and
            ``rotation_deg`` (float).
        """
        with torch.no_grad():
            # Add a batch dimension and move everything onto the model.
            float_kwargs = dict(device=model.device, dtype=model.dtype)
            src_pixel_values = sample['src_pixel_values'].unsqueeze(0).to(**float_kwargs)
            tgt_pixel_values = sample['tgt_pixel_values'].unsqueeze(0).to(**float_kwargs)
            viewpoint_params = sample['viewpoint_params'].unsqueeze(0).to(**float_kwargs)
            input_ids = sample['input_ids'].unsqueeze(0).to(device=model.device)

            # Angle magnitude for logging only: norm of the first two viewpoint
            # params, scaled by 180/pi (presumably azimuth/elevation in
            # radians - TODO confirm against the dataset).
            azimuth = viewpoint_params[0, 0].cpu().float().item()
            elevation = viewpoint_params[0, 1].cpu().float().item()
            rotation_deg = np.sqrt(azimuth**2 + elevation**2) * 180 / np.pi

            generated = model.sample_relpose(
                src_image=src_pixel_values,
                viewpoint_params=viewpoint_params,
                input_ids=input_ids,
                num_iter=self.num_iter,
                cfg=self.cfg,
                temperature=self.temperature,
                progress=False,
            )

            # Convert tensors to PIL Images
            src_image = self._tensor_to_pil(src_pixel_values[0])
            gen_image = self._tensor_to_pil(generated[0])
            gt_image = self._tensor_to_pil(tgt_pixel_values[0])

        if self.save_individual:
            prefix = os.path.join(eval_dir, f'sample_{sample_idx:02d}')
            src_image.save(f'{prefix}_src.jpg')
            gen_image.save(f'{prefix}_gen.jpg')
            gt_image.save(f'{prefix}_gt.jpg')

        return {
            'src_image': src_image,
            'gen_image': gen_image,
            'gt_image': gt_image,
            # Key name kept for compatibility; holds the viewpoint params.
            'relpose_params': viewpoint_params[0],
            'rotation_deg': float(rotation_deg),
        }

    def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
        """Convert an image tensor to a PIL Image.

        Args:
            tensor: (C, H, W) tensor in range [-1, 1]; out-of-range values
                are clamped.

        Returns:
            PIL Image built from the corresponding (H, W, C) uint8 array.
        """
        # Cast to float32 first: NumPy has no bfloat16, and doing the
        # arithmetic in float32 avoids extra low-precision rounding.
        tensor = tensor.detach().to(device='cpu', dtype=torch.float32)
        # Denormalize from [-1, 1] to [0, 1]
        tensor = torch.clamp((tensor + 1.0) / 2.0, 0, 1)
        array = np.transpose(tensor.numpy(), (1, 2, 0))  # CHW -> HWC
        return Image.fromarray((array * 255).astype(np.uint8))

    def _create_comparison_grid(self, results: List[Dict]) -> Image.Image:
        """Create a comparison grid: Source | Generated | Ground Truth.

        One row per sample. The cell size is taken from the first source
        image; all images are assumed to share that size.

        Args:
            results: List of result dicts holding ``src_image``,
                ``gen_image`` and ``gt_image``.

        Returns:
            Grid image as PIL Image (a blank 512x512 image when ``results``
            is empty).
        """
        if not results:
            return Image.new('RGB', (512, 512))

        img_width, img_height = results[0]['src_image'].size
        columns = ('src_image', 'gen_image', 'gt_image')
        grid = Image.new('RGB', (img_width * len(columns), img_height * len(results)))

        for row_idx, result in enumerate(results):
            y = row_idx * img_height
            for col_idx, key in enumerate(columns):
                grid.paste(result[key], (col_idx * img_width, y))
        return grid