Spaces:
Running on Zero
Running on Zero
| """ | |
| Viewpoint Evaluation Hook for Harmon Training. | |
| Performs viewpoint-conditioned image generation at regular intervals during training | |
| to visualize and evaluate model progress. | |
| """ | |
| import os | |
| import re | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from typing import List, Optional, Union | |
| from mmengine.hooks import Hook | |
| from mmengine.registry import HOOKS | |
| from mmengine.model import is_model_wrapper | |
| from mmengine.dist import master_only | |
| import torch.distributed as dist | |
| from src.datasets.camera_utils import CameraTransformUtils, compute_angular_offset | |
| try: | |
| import wandb | |
| HAS_WANDB = True | |
| except ImportError: | |
| HAS_WANDB = False | |
| class ViewpointEvaluationHook(Hook): | |
| """ | |
| Hook to evaluate viewpoint-conditioned image generation during training. | |
| Generates images for a grid of viewpoint angles (azimuth × elevation) | |
| and saves them to visualize training progress. | |
| Args: | |
| interval (int): Evaluate every N training iterations. Default: 1000 | |
| prompts (Union[str, List[str]]): Text prompt(s) for image generation. | |
| Can be a single string or list of strings. Default: ["a 3D object"] | |
| prompt (str, optional): Deprecated. Use 'prompts' instead for backward compatibility. | |
| azimuths (List[float]): List of azimuth angles in degrees. | |
| Default: [0, 45, 90, 135, 180, 225, 270, 315] | |
| elevations (List[float]): List of elevation angles in degrees. | |
| Default: [10, 30] | |
| radius (float): Camera distance from origin for camera matrix creation. Default: 5.0 | |
| num_iter (int): Number of sampling iterations for generation. Default: 64 | |
| cfg (float): Classifier-free guidance scale. Default: 3.0 | |
| temperature (float): Sampling temperature. Default: 1.0 | |
| save_individual (bool): Whether to save individual images. Default: True | |
| front_bg_indicator (bool): If True, add "real background" to prompts. Default: False | |
| """ | |
| priority = 'NORMAL' | |
| def __init__(self, | |
| interval: int = 1000, | |
| prompts: Optional[Union[str, List[str]]] = None, | |
| prompt: Optional[str] = None, # Backward compatibility | |
| azimuths: List[float] = [0, 45, 90, 135, 180, 225, 270, 315], | |
| elevations: List[float] = [10, 30], | |
| radius: float = 5.0, | |
| num_iter: int = 64, | |
| cfg: float = 3.0, | |
| temperature: float = 1.0, | |
| save_individual: bool = True, | |
| num_view_tokens: int = 2, | |
| viewpoint_param_type: str = 'spherical', | |
| view_token_placement: str = 'surround', | |
| front_bg_indicator: bool = False, | |
| dtype: torch.dtype = torch.bfloat16): | |
| super().__init__() | |
| self.interval = interval | |
| self.viewpoint_param_type = viewpoint_param_type | |
| self.num_view_tokens = num_view_tokens | |
| self.dtype = dtype | |
| self.front_bg_indicator = front_bg_indicator | |
| # Validate and store view token placement | |
| if view_token_placement not in ['surround', 'front', 'random']: | |
| raise ValueError(f"view_token_placement must be 'surround', 'front' or 'random', got '{view_token_placement}'") | |
| self.view_token_placement = view_token_placement | |
| # Handle backward compatibility: prompt vs prompts | |
| if prompts is not None: | |
| # Convert single string to list if needed | |
| self.prompts = [prompts] if isinstance(prompts, str) else prompts | |
| elif prompt is not None: | |
| # Backward compatibility with old 'prompt' parameter | |
| self.prompts = [prompt] | |
| else: | |
| # Default value | |
| self.prompts = ["a 3D object"] | |
| # Determine object counts from prompt format | |
| # String: 1 object (e.g., 'lion') | |
| # List: N objects (e.g., ['lion', 'girl']) | |
| self.object_counts = [] | |
| for p in self.prompts: | |
| if isinstance(p, list): | |
| self.object_counts.append(len(p)) | |
| else: | |
| self.object_counts.append(1) | |
| self.azimuths = azimuths | |
| self.elevations = elevations | |
| self.radius = radius | |
| self.num_iter = num_iter | |
| self.cfg = cfg | |
| self.temperature = temperature | |
| self.save_individual = save_individual | |
| def after_train_iter(self, runner, batch_idx: int, data_batch=None, outputs=None): | |
| """Called after every training iteration.""" | |
| if self.every_n_train_iters(runner, self.interval): | |
| self._run_evaluation(runner) | |
| # Barrier so non-master ranks wait for rank 0 to finish eval. | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| def _sanitize_prompt_name(self, prompt: str, max_length: int = 50) -> str: | |
| """ | |
| Convert prompt to a valid filename. | |
| Args: | |
| prompt: Text prompt | |
| max_length: Maximum length of the sanitized name | |
| Returns: | |
| Sanitized filename-safe string | |
| """ | |
| # Convert to lowercase and replace spaces with underscores | |
| sanitized = prompt.lower().replace(' ', '_') | |
| # Keep only alphanumeric characters and underscores | |
| sanitized = re.sub(r'[^a-z0-9_]', '', sanitized) | |
| # Truncate to max length | |
| sanitized = sanitized[:max_length] | |
| return sanitized | |
def _create_camera_matrix(self, azimuth_deg: float, elevation_deg: float, radius: float, target: np.ndarray = None):
    """
    Build a look-at camera from spherical coordinates.

    The camera sits at distance ``radius`` from the origin, at the given
    azimuth/elevation, and is oriented towards ``target``.

    Args:
        azimuth_deg: Azimuth angle in degrees
        elevation_deg: Elevation angle in degrees
        radius: Distance of the camera from the origin
        target: Point the camera looks at; the origin when None

    Returns:
        R: rotation returned by ``CameraTransformUtils.create_lookat_rotation``
           (described in this file as Blender convention)
        T: translation vector computed as ``-R @ C`` for camera position C
    """
    az = azimuth_deg * np.pi / 180.0
    el = elevation_deg * np.pi / 180.0

    # Length of the camera offset projected onto the xy-plane
    planar = radius * np.cos(el)
    cam_pos = np.array(
        [planar * np.cos(az), planar * np.sin(az), radius * np.sin(el)],
        dtype=np.float32,
    )

    # Default look-at target is the origin
    look_at = np.array([0, 0, 0], dtype=np.float32) if target is None else target

    rotation = CameraTransformUtils.create_lookat_rotation(cam_pos, look_at)
    translation = -rotation @ cam_pos
    return rotation, translation
@master_only
def _run_evaluation(self, runner):
    """Run viewpoint evaluation and save images for all prompts.

    For every configured prompt, one image is generated per
    (elevation, azimuth) pair; elevations form the grid rows and azimuths the
    columns. Individual images (optional) and the grid are written under
    ``<work_dir>/eval_images/iter_<iteration>/<prompt_name>/`` and the grid
    is logged to wandb when a run is active.

    Note: Only runs on rank 0 (enforced by ``master_only``) to avoid
    redundant computation and file conflicts. The model is put into eval
    mode for the duration and switched back to train mode afterwards, even
    if generation raises.

    Args:
        runner: The training runner; ``model``, ``iter``, ``work_dir`` and
            ``logger`` are read from it.
    """
    model = runner.model
    iteration = runner.iter
    work_dir = runner.work_dir

    # Create base output directory
    base_eval_dir = os.path.join(work_dir, 'eval_images', f'iter_{iteration:06d}')
    os.makedirs(base_eval_dir, exist_ok=True)

    runner.logger.info(f"\n[ViewpointEvalHook] Running evaluation at iteration {iteration}")
    runner.logger.info(f" Viewpoint param type: {self.viewpoint_param_type}")
    runner.logger.info(f" Radius: {self.radius}")
    runner.logger.info(f" Number of prompts: {len(self.prompts)}")
    runner.logger.info(f" Azimuths: {self.azimuths}")
    runner.logger.info(f" Elevations: {self.elevations}")

    # Switch to eval mode; the finally-clause guarantees training mode is
    # restored if image generation or file I/O fails part-way through.
    model.eval()
    try:
        for prompt_idx, prompt in enumerate(self.prompts):
            num_objects = self.object_counts[prompt_idx]

            # Convert prompt to string for logging/directory
            if isinstance(prompt, list):
                prompt_str = ' and '.join(prompt)
            else:
                prompt_str = prompt
            runner.logger.info(f"\n [{prompt_idx + 1}/{len(self.prompts)}] Evaluating prompt: '{prompt_str}' ({num_objects} object(s))")

            # Create prompt-specific directory. An empty sanitized name would
            # collapse onto the iteration directory, so fall back to the index.
            prompt_name = self._sanitize_prompt_name(prompt_str) or f'prompt_{prompt_idx}'
            prompt_dir = os.path.join(base_eval_dir, prompt_name)
            os.makedirs(prompt_dir, exist_ok=True)

            # Generate images for all viewpoint combinations
            all_images = []
            for elevation in self.elevations:
                row_images = []
                for idx, azimuth in enumerate(self.azimuths):
                    image = self._generate_image(model, prompt, azimuth, elevation, idx)
                    row_images.append(image)
                    # Save individual image if requested
                    if self.save_individual:
                        img_path = os.path.join(prompt_dir, f'az{int(azimuth):03d}_el{int(elevation):03d}.jpg')
                        image.save(img_path)
                all_images.append(row_images)

            # Create and save grid
            grid_image = self._create_grid(all_images)
            grid_path = os.path.join(prompt_dir, 'grid.jpg')
            grid_image.save(grid_path)

            # Log to wandb if available
            if HAS_WANDB and wandb.run is not None:
                wandb.log({
                    f"eval/{prompt_name}": wandb.Image(grid_image, caption=prompt_str),
                }, step=iteration)
            runner.logger.info(f" Saved to {prompt_dir}/")

        runner.logger.info(f"\n All evaluation images saved to {base_eval_dir}\n")
    finally:
        # Switch back to train mode
        model.train()
def _generate_image(self, model, prompt: Union[str, List[str]], azimuth: float, elevation: float, idx=0) -> Image.Image:
    """
    Generate a single image for the given viewpoint.

    Args:
        model: Harmon model (possibly wrapped; it is unwrapped via ``.module``)
        prompt: Text description (str for 1 object, List[str] for 2 objects)
        azimuth: Azimuth angle in degrees
        elevation: Elevation angle in degrees
        idx: Position of this azimuth in the azimuth list; only consulted by
            the 'factorized' parameterization

    Returns:
        PIL Image

    Raises:
        ValueError: If ``self.viewpoint_param_type`` is not recognized.
    """
    # Detect number of objects from prompt type
    num_objects = len(prompt) if isinstance(prompt, list) else 1

    with torch.no_grad():
        # Unwrap DDP model if needed to access device
        actual_model = model.module if is_model_wrapper(model) else model
        device = actual_model.device

        # Build caption with view tokens
        caption_with_tokens = self._build_caption_with_tokens(prompt)

        # Per-object viewpoint vector, then batch/pack it for the model
        per_object = self._compute_viewpoint_vector(azimuth, elevation, idx)
        viewpoint_params, valid_mask, num_objects_tensor = self._pack_viewpoint_params(
            per_object, num_objects, device
        )

        # Build conditional prompt (without applying template; handled by model)
        conditional_input = self._build_conditional_input(caption_with_tokens)

        # Prepare conditional/unconditional text conditions using model helper
        class_info = actual_model.prepare_text_conditions(
            conditional_input,
            cfg_prompt="Generate an image."
        )
        input_ids = class_info['input_ids']
        attention_mask = class_info['attention_mask']

        # Convert to embeddings
        inputs_embeds = actual_model.llm.get_input_embeddings()(input_ids).to(dtype=self.dtype)

        # Inject viewpoint embeddings only into the conditional branch
        # (row 0 is treated as the conditional row throughout this method).
        cond_inputs_embeds = actual_model.inject_viewpoint_embeddings(
            input_ids[:1],
            viewpoint_params,
            inputs_embeds[:1].clone(),
            valid_mask,
            num_objects=num_objects_tensor
        )

        if self.cfg != 1.0:
            # Replace conditional row and keep the remaining rows untouched
            inputs_embeds = torch.cat([cond_inputs_embeds, inputs_embeds[1:]], dim=0)
        else:
            # No guidance: only the conditional row is sampled
            inputs_embeds = cond_inputs_embeds
            attention_mask = attention_mask[:1]

        # Generate image
        images = actual_model.sample(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_iter=self.num_iter,
            cfg=self.cfg,
            temperature=self.temperature,
            progress=True,
            image_shape=(32, 32),
        )

        # Convert to PIL Image
        image = self._tensor_to_pil(images[0])
    return image

def _compute_viewpoint_vector(self, azimuth: float, elevation: float, idx: int = 0) -> np.ndarray:
    """Return the 1-D viewpoint parameter vector of ONE object for the configured parameterization."""
    mode = self.viewpoint_param_type

    if mode == 'spherical':
        # [azimuth, elevation] in radians
        return np.array([azimuth * np.pi / 180.0, elevation * np.pi / 180.0])

    if mode == 'azimuth_only':
        # [azimuth] in radians (no elevation)
        return np.array([azimuth * np.pi / 180.0])

    if mode == 'rotation_translation':
        # [rot_9d (9), translation (3)] of the camera itself
        target = np.array([0.4, 0.0, 0.4], dtype=np.float32)
        R, T = self._create_camera_matrix(azimuth, elevation, self.radius, target)
        T = T / 7.0
        return np.concatenate([R.flatten(), T])

    if mode == 'relative_rotation_translation':
        # Pose relative to a default camera at azimuth 0, elevation 0, radius 4
        target = np.array([0.2, 0.2, 0.2], dtype=np.float32)
        R, T = self._create_camera_matrix(azimuth, elevation, self.radius, target)
        R_default, T_default = self._create_camera_matrix(0, 0, 4)
        R_rel = R @ R_default.T
        T_rel = T - R_rel @ T_default
        # Scale translation down for stability
        T_rel = T_rel / 7.0
        return np.concatenate([R_rel.flatten(), T_rel])

    if mode == 'factorized':
        # [azimuth, elevation, radius_norm, pitch, yaw]
        target = np.array([0.2, 0.2, 0.2], dtype=np.float32)
        R, T = self._create_camera_matrix(azimuth, elevation, self.radius, target)
        # Camera position in world coordinates: C = -R^T @ T
        camera_position = -R.T @ T
        radius = np.linalg.norm(camera_position)
        # Normalize radius to [-1, 1] for range [3, 8]
        radius_normalized = (radius - 5.5) / 2.5

        # NOTE(review): pitch/yaw are computed from the camera here but are
        # then unconditionally replaced by the fixed per-column values below.
        # That looks like a leftover experiment; behaviour is preserved as
        # found - confirm which source of pitch/yaw is intended.
        R_torch = torch.from_numpy(R).float()
        T_torch = torch.from_numpy(T).float()
        angular_offset = compute_angular_offset(R_torch, T_torch, normalizer=1.0)
        pitch = angular_offset[0].item()
        yaw = angular_offset[1].item()

        azimuth_rad = azimuth * np.pi / 180.0
        if azimuth_rad > np.pi:
            azimuth_rad -= 2 * np.pi  # Convert to [-pi, pi]
        elevation_rad = elevation * np.pi / 180.0

        # Fixed (pitch, yaw) per azimuth column; columns >= 4 share one value
        fixed_offsets = {
            0: (0.15, 0.15),
            1: (-0.15, -0.15),
            2: (0.15, -0.15),
            3: (-0.15, 0.15),
        }
        pitch, yaw = fixed_offsets.get(idx, (0.2, 0.2))

        return np.array(
            [azimuth_rad, elevation_rad, radius_normalized, pitch, yaw],
            dtype=np.float32,
        )

    if mode == 'rotation_factorized':
        # [R_rel_9d (9), azimuth, elevation, radius_normalized]
        target = np.array([0.3, 0.3, 0.3], dtype=np.float32)
        R_actual, T = self._create_camera_matrix(azimuth, elevation, self.radius, target)
        # Camera position in world coordinates: C = -R^T @ T
        camera_position = -R_actual.T @ T
        radius = np.linalg.norm(camera_position)
        # Normalize radius to [-1, 1] for range [3, 8]
        radius_normalized = (radius - 5.5) / 2.5

        # Canonical rotation: camera at the same position looking at the origin
        target_pos = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        up_vector = np.array([0.0, 0.0, 1.0], dtype=np.float32)
        R_canonical = CameraTransformUtils.create_lookat_rotation(
            camera_position, target_pos, up_vector
        )
        # Rotation of the actual camera relative to the canonical one
        R_rel = R_canonical.T @ R_actual

        azimuth_rad = azimuth * np.pi / 180.0
        if azimuth_rad > np.pi:
            azimuth_rad -= 2 * np.pi  # Convert to [-pi, pi]
        elevation_rad = elevation * np.pi / 180.0
        return np.concatenate([R_rel.flatten(), [azimuth_rad, elevation_rad, radius_normalized]])

    if mode == 'plucker':
        # [direction (3), moment (3)]
        target = np.array([0.2, 0.2, 0.2], dtype=np.float32)
        R, T = self._create_camera_matrix(azimuth, elevation, self.radius, target)
        # Camera position in world coordinates: o = -R.T @ T
        camera_position = -R.T @ T
        # Viewing direction taken as the negated third row of R
        direction = -R[2, :]
        # Moment vector: m = o x d
        moment = np.cross(camera_position, direction)
        return np.concatenate([direction, moment])

    raise ValueError(f"Unknown viewpoint_param_type: {self.viewpoint_param_type}")

def _pack_viewpoint_params(self, per_object: np.ndarray, num_objects: int, device):
    """Turn a per-object vector into the batched (params, valid_mask, num_objects) triple passed to the model."""
    per_object = np.asarray(per_object)
    if num_objects == 1:
        flat = per_object
        num_objects_tensor = None
    else:
        # Any multi-object prompt is packed as exactly two objects that
        # share the viewpoint.
        # NOTE(review): in 'azimuth_only' mode the second object receives
        # the NEGATED azimuth, although the original inline comment said
        # "same azimuth for both objects". Preserved as found - confirm.
        second = -per_object if self.viewpoint_param_type == 'azimuth_only' else per_object
        flat = np.concatenate([per_object, second])
        num_objects_tensor = torch.tensor([2], dtype=torch.long).to(device=device)

    # float32 first, then the configured dtype; leading batch dimension of 1
    viewpoint_params = torch.tensor(
        flat,
        dtype=torch.float32
    ).to(device=device, dtype=self.dtype).unsqueeze(0)
    # Every parameter slot is marked valid
    valid_mask = torch.ones(flat.shape[0], dtype=torch.bool).to(device=device).unsqueeze(0)
    return viewpoint_params, valid_mask, num_objects_tensor

def _build_conditional_input(self, caption_with_tokens: str) -> str:
    """Wrap the token-annotated caption in the fixed evaluation prompt template."""
    if self.front_bg_indicator:
        prefix = "Generate an image: real background, "
    else:
        prefix = "Generate an image: "
    if "a fighter jet with attached missiles" in caption_with_tokens:
        # Special case: this prompt is passed through without the scene template
        template = "{}"
    else:
        template = "A photo of {} on a desert landscape with cactis in the background"
    return prefix + template.format(caption_with_tokens)
| def _build_caption_with_tokens(self, prompt: Union[str, List[str]]) -> str: | |
| """Build caption with view tokens inserted. | |
| Args: | |
| prompt: Single string ('lion') or list of strings (['lion', 'girl']) | |
| Returns: | |
| Caption with view tokens inserted | |
| """ | |
| view_tokens = [f"<view_token_{i}>" for i in range(self.num_view_tokens)] | |
| if isinstance(prompt, list): | |
| # Multi-object: duplicate all tokens for each object | |
| view_token_str = "".join(view_tokens) | |
| # Build: "<all_tokens> a lion and <all_tokens> a girl" | |
| caption_with_tokens = f"{view_token_str} a {prompt[0]} and {view_token_str} a {prompt[1]}" | |
| else: | |
| # Single object (existing logic) | |
| if self.view_token_placement == 'surround': | |
| # Surround mode: tokens split around prompt | |
| half_num = self.num_view_tokens // 2 | |
| caption_with_tokens = ( | |
| "".join(view_tokens[:half_num]) + " " + | |
| prompt + " " + | |
| "".join(view_tokens[half_num:]) | |
| ) | |
| elif self.view_token_placement == 'front' or self.view_token_placement == 'random': | |
| # Front mode: all tokens at the front | |
| caption_with_tokens = "".join(view_tokens) + " " + prompt | |
| else: | |
| raise ValueError(f"Invalid view_token_placement: {self.view_token_placement}") | |
| return caption_with_tokens | |
def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
    """
    Turn an image tensor into a PIL Image.

    Args:
        tensor: (C, H, W) tensor in range [-1, 1]

    Returns:
        PIL Image built from the 8-bit, channel-last version of the tensor
    """
    # Map [-1, 1] onto [0, 1]; anything outside is clipped
    unit = torch.clamp((tensor + 1.0) / 2.0, 0, 1)

    # NumPy has no bfloat16, so go through float32 before leaving torch
    pixels = unit.to(dtype=torch.float32).cpu().numpy()

    # CHW -> HWC, then scale to 8-bit
    pixels = np.transpose(pixels, (1, 2, 0))
    pixels = (pixels * 255).astype(np.uint8)
    return Image.fromarray(pixels)
def _create_grid(self, images: List[List[Image.Image]]) -> Image.Image:
    """
    Tile images into a single grid image.

    Args:
        images: List of rows, each row is a list of PIL Images

    Returns:
        RGB grid image; the tile size is taken from the first image and the
        column count from the first row
    """
    n_rows = len(images)
    n_cols = len(images[0])

    # Every tile is assumed to match the size of the first one
    tile_w, tile_h = images[0][0].size

    canvas = Image.new('RGB', (n_cols * tile_w, n_rows * tile_h))
    for r, row in enumerate(images):
        top = r * tile_h
        for c, tile in enumerate(row):
            canvas.paste(tile, (c * tile_w, top))
    return canvas