Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn.modules.module import T | |
| from mmengine.model import BaseModel | |
| from torch.autograd.function import Function | |
| from mmengine.logging import print_log | |
| from xtuner.model.utils import guess_load_checkpoint | |
| from xtuner.utils import IMAGE_TOKEN_INDEX | |
| from .harmon import Harmon | |
| from .viewpoint_mlp import ViewpointPredictionHead | |
| from .mar.lora import apply_lora_to_mar | |
| try: | |
| from peft import get_peft_model, LoraConfig | |
| HAS_PEFT = True | |
| except ImportError: | |
| HAS_PEFT = False | |
| class _ScaleGradient(Function): | |
| def forward(ctx, input, scale): | |
| ctx.scale = scale | |
| return input | |
| def backward(ctx, grad_output): | |
| return grad_output * ctx.scale, None | |
class HarmonDev(Harmon, BaseModel):
    """Harmon model trained with text, image and viewpoint objectives.

    ``compute_loss`` routes every entry of the incoming data dict to one of
    six task losses: ``text2image``, ``image2text``, ``recon``,
    ``viewpoint2image``, ``image2viewpoint`` and ``relpose2image``.
    """

    # Index of the single IMAGE_TOKEN_INDEX placeholder inside a tokenized
    # prompt. Every image-conditioned loss splices visual features in here.
    _IMAGE_TOKEN_POSITION = 3

    # Task names in matching priority order. ``compute_loss`` picks the first
    # name that occurs as a substring of the data-type key.
    _TASKS = (
        'text2image',
        'image2text',
        'recon',
        'viewpoint2image',
        'image2viewpoint',
        'relpose2image',
    )

    def __init__(self,
                 grad_scale=0.1,
                 loss_weights=None,
                 pretrained_pth=None,
                 freeze_llm=False,
                 gradient_checkpointing=True,
                 viewpoint_prediction_head=None,
                 lora=None,
                 lora_mar=None,
                 **kwargs
                 ):
        """Build the model, load weights, then attach viewpoint/LoRA parts.

        Args:
            grad_scale: Factor applied to gradients flowing back into the
                visual features; ``None`` disables the scaling.
            loss_weights: Mapping from data-type key to loss weight. ``None``
                selects a weight of 1.0 for each of the six tasks.
            pretrained_pth: Optional checkpoint path, loaded non-strictly.
            freeze_llm: Freeze the LLM parameters (ignored when ``lora`` is
                given, since peft already freezes the base model).
            gradient_checkpointing: Enable or disable gradient checkpointing
                on both the LLM and MAR.
            viewpoint_prediction_head: ``ViewpointPredictionHead`` kwargs as a
                dict, or an already constructed module, or ``None``.
            lora: ``LoraConfig`` kwargs as a dict, or a ready config object.
            lora_mar: Keyword arguments forwarded to ``apply_lora_to_mar``.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            ImportError: ``lora`` was requested but ``peft`` is not installed.
        """
        super().__init__(**kwargs)

        # A dict literal as the default value would be shared by every
        # instance, so the default is created per call instead.
        if loss_weights is None:
            loss_weights = {
                'image2text': 1.0, 'text2image': 1.0, 'recon': 1.0,
                'viewpoint2image': 1.0, 'image2viewpoint': 1.0,
                'relpose2image': 1.0,
            }
        self.grad_scale = grad_scale
        self.loss_weights = loss_weights
        self.debuged = False

        # Load the pretrained checkpoint BEFORE the viewpoint tokens are
        # initialised so the embedding sizes still match the checkpoint.
        print("Initializing HarmonDev model...")
        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print_log(f'Load pretrained weight from {pretrained_pth}')

        # Bring the VAE, MAR and both projection layers to the LLM dtype.
        self.vae = self.vae.to(dtype=self.llm.dtype)
        self.mar = self.mar.to(dtype=self.llm.dtype)
        self.proj_in = self.proj_in.to(dtype=self.llm.dtype)
        self.proj_out = self.proj_out.to(dtype=self.llm.dtype)

        # LoRA on MAR goes in after checkpoint loading and dtype setup.
        if lora_mar is not None:
            apply_lora_to_mar(self.mar, **lora_mar)

        # Only now add the viewpoint tokens (this resizes the embeddings).
        self._init_viewpoint_tokens()
        if getattr(self, 'viewpoint_mlp', None) is not None:
            self.viewpoint_mlp = self.viewpoint_mlp.to(dtype=self.llm.dtype)

        # The prediction head is optional: it is only built when viewpoint
        # tokens are in use and the image2viewpoint task has positive weight.
        wants_head = (
            self.use_viewpoint_tokens
            and self.loss_weights.get('image2viewpoint', 0.0) > 0
            and viewpoint_prediction_head is not None
        )
        if wants_head:
            if isinstance(viewpoint_prediction_head, dict):
                head = ViewpointPredictionHead(**viewpoint_prediction_head)
            else:
                head = viewpoint_prediction_head
            self.viewpoint_head = head.to(dtype=self.llm.dtype)
        else:
            self.viewpoint_head = None

        # LoRA on the LLM: after checkpoint loading, instead of freezing.
        if lora is not None:
            if not HAS_PEFT:
                raise ImportError(
                    'peft is required for LoRA. Install with: pip install peft')
            lora_config = LoraConfig(**lora) if isinstance(lora, dict) else lora
            self.llm = get_peft_model(self.llm, lora_config)
            # Keep a handle on the wrapped transformer for forward_mae_encoder:
            # PeftModelForCausalLM -> LoraModel -> Qwen2ForCausalLM -> Qwen2Model
            self._llm_base_model = self.llm.get_base_model().model
            self.llm.print_trainable_parameters()
            print_log('Applied LoRA to LLM')
        elif freeze_llm:
            self.llm.requires_grad_(False)

        if gradient_checkpointing:
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

    def gradient_checkpointing_disable(self):
        """Turn gradient checkpointing off on the LLM and on MAR."""
        self.llm.gradient_checkpointing_disable()
        self.mar.gradient_checkpointing_disable()

    def gradient_checkpointing_enable(self):
        """Turn gradient checkpointing on for the LLM and for MAR."""
        self.llm.gradient_checkpointing_enable()
        self.mar.gradient_checkpointing_enable()

    def state_dict(self, *args, **kwargs):
        """Return the parent state dict with every ``'vae.'`` entry removed."""
        state_dict = super().state_dict(*args, **kwargs)
        # Delete in place so the OrderedDict type and its ``_metadata``
        # attribute survive; iterate over a snapshot of the keys.
        for key in [k for k in state_dict if 'vae.' in k]:
            del state_dict[key]
        return state_dict

    def train(self: T, mode: bool = True) -> T:
        """Switch train/eval mode while always keeping the VAE in eval mode."""
        super().train(mode=mode)
        self.vae.train(mode=False)
        return self

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _scale_visual_grad(self, features):
        """Apply ``grad_scale`` to gradients of ``features`` (no-op if None)."""
        if self.grad_scale is None:
            return features
        return _ScaleGradient.apply(features, self.grad_scale)

    @staticmethod
    def _detached_targets(latents):
        """Return a detached copy of 4-D ``latents`` with dims 1 and 2 merged."""
        b, m, n, _ = latents.shape
        return latents.clone().detach().view(b, m * n, -1)

    @staticmethod
    def _angles_to_sincos(azimuth, elevation):
        """Stack ``[sin(az), cos(az), sin(el), cos(el)]`` along the last dim."""
        return torch.stack([
            torch.sin(azimuth),
            torch.cos(azimuth),
            torch.sin(elevation),
            torch.cos(elevation),
        ], dim=-1)

    def _splice_image_features(self, z_enc, input_ids, attention_mask,
                               labels=None):
        """Expand the image placeholder and build the matching embeddings.

        The placeholder at ``_IMAGE_TOKEN_POSITION`` is replaced by
        ``z_enc.shape[1]`` copies of ``IMAGE_TOKEN_INDEX``. ``attention_mask``
        (and ``labels`` when given) are stretched to the same length, and
        ``z_enc`` is written into the image slots of a fresh embedding tensor.

        Args:
            z_enc: Visual features; dims 0 and 1 are flattened onto the image
                slots (presumably (batch, tokens, hidden) - TODO confirm).
            input_ids: Token ids holding one placeholder per row.
            attention_mask: Mask aligned with ``input_ids``.
            labels: Optional labels aligned with ``input_ids``. The value at
                the placeholder is repeated per sample over the image span.

        Returns:
            ``(input_ids, attention_mask, inputs_embeds, labels)``, all
            expanded; ``labels`` stays ``None`` when it was not supplied.

        Raises:
            ValueError: Some row does not hold the placeholder at the
                expected position.
        """
        pos = self._IMAGE_TOKEN_POSITION
        at_slot = input_ids[:, pos] == IMAGE_TOKEN_INDEX
        # One vectorized check instead of one assert per sample.
        if not bool(at_slot.all()):
            bad = int((~at_slot).nonzero()[0])
            raise ValueError(
                f"Expected IMAGE_TOKEN_INDEX at position {pos}, "
                f"got {input_ids[bad, pos].item()} in sample {bad}")

        # Take the span length from the features rather than hard-coding it.
        num_image_tokens = z_enc.shape[1]

        image_ids = torch.full(
            (input_ids.shape[0], num_image_tokens), IMAGE_TOKEN_INDEX,
            dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat(
            (input_ids[:, :pos], image_ids, input_ids[:, pos + 1:]), dim=1)

        image_attention = torch.ones(
            attention_mask.shape[0], num_image_tokens,
            dtype=torch.long, device=attention_mask.device)
        attention_mask = torch.cat(
            (attention_mask[:, :pos], image_attention,
             attention_mask[:, pos + 1:]), dim=1)

        if labels is not None:
            # Repeat each sample's own placeholder label across its span.
            image_labels = labels[:, pos:pos + 1].expand(-1, num_image_tokens)
            labels = torch.cat(
                (labels[:, :pos], image_labels, labels[:, pos + 1:]), dim=1)

        is_image = input_ids == IMAGE_TOKEN_INDEX
        inputs_embeds = z_enc.new_zeros(
            *input_ids.shape, self.llm.config.hidden_size)
        inputs_embeds[is_image] = z_enc.flatten(0, 1)
        inputs_embeds[~is_image] = self.llm.get_input_embeddings()(
            input_ids[~is_image])
        return input_ids, attention_mask, inputs_embeds, labels

    def _masked_generation_loss(self, latents, gt_latents, **conditioning):
        """Run the MAR masking -> encoder -> decoder -> loss pipeline.

        Args:
            latents: Encoded image latents, 4-D with dims 1 and 2 forming the
                token grid.
            gt_latents: Targets handed to ``mar.forward_loss``.
            **conditioning: Passed unchanged to ``forward_mae_encoder``
                (``input_ids`` / ``inputs_embeds`` / ``attention_mask``).

        Returns:
            The value returned by ``self.mar.forward_loss``.
        """
        b, m, n, _ = latents.shape
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(latents.flatten(1, 2), orders)
        x_enc = self.forward_mae_encoder(latents, mask, **conditioning)
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n))
        return self.mar.forward_loss(z=z, target=gt_latents, mask=mask)

    # ------------------------------------------------------------------
    # Task losses
    # ------------------------------------------------------------------
    def text2image_loss(self, data_dict):
        """Masked image-generation loss conditioned on the text prompt.

        Reads ``pixel_values``, ``input_ids`` and ``attention_mask``.
        """
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x = self.encode(x)  # b m n c
        gt_latents = self._detached_targets(x)
        return self._masked_generation_loss(
            x, gt_latents,
            input_ids=data_dict['input_ids'].to(self.device),
            attention_mask=data_dict['attention_mask'].to(self.device))

    def image2text_loss(self, data_dict):
        """Next-token cross-entropy on text conditioned on the input image.

        Reads ``input_ids``, ``attention_mask``, ``labels`` and
        ``pixel_values``. Positions whose shifted label is negative are
        excluded from the loss.

        Raises:
            ValueError: The batch carries no ``pixel_values`` (text-only
                batches are not supported), or the image placeholder is not
                at the expected position.
        """
        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)
        labels = data_dict['labels'].to(self.device)

        pixel_values = data_dict.get('pixel_values', None)
        if pixel_values is None:
            # Explicit raise: a bare ``assert False`` disappears under -O.
            raise ValueError(
                "image2text_loss requires 'pixel_values'; "
                "no image found in this batch")

        x = pixel_values.to(dtype=self.dtype, device=self.device)
        x = self.encode(x)  # b m n c
        _, z_enc = self.extract_visual_feature(x)
        z_enc = self._scale_visual_grad(z_enc)

        _, attention_mask, inputs_embeds, labels = self._splice_image_features(
            z_enc, input_ids, attention_mask, labels)

        output = self.llm_model(inputs_embeds=inputs_embeds,
                                attention_mask=attention_mask,
                                return_dict=True)

        # Shift by one: the hidden state at position t predicts token t + 1.
        hidden = output.last_hidden_state[:, :-1]
        targets = labels[:, 1:]
        supervised = targets >= 0
        logits = self.llm.get_output_embeddings()(hidden[supervised])
        return F.cross_entropy(input=logits, target=targets[supervised])

    def recon_loss(self, data_dict):
        """Masked reconstruction loss conditioned on the image's own features.

        Reads ``pixel_values``, ``input_ids`` and ``attention_mask``.

        Raises:
            ValueError: The image placeholder is not at the expected position.
        """
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x = self.encode(x)  # b m n c
        gt_latents = self._detached_targets(x)

        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)

        _, z_enc = self.extract_visual_feature(x)
        z_enc = self._scale_visual_grad(z_enc)

        _, attention_mask, inputs_embeds, _ = self._splice_image_features(
            z_enc, input_ids, attention_mask)

        return self._masked_generation_loss(
            x, gt_latents,
            inputs_embeds=inputs_embeds,
            input_ids=None,
            attention_mask=attention_mask)

    def viewpoint2image_loss(self, data_dict):
        """
        Generate images conditioned on viewpoint parameters.
        Text + viewpoint tokens -> Image generation.

        Reads ``pixel_values``, ``viewpoint_params``, ``input_ids`` and
        ``attention_mask``, plus the optional ``viewpoint_valid_mask`` and
        ``num_objects`` entries.
        """
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x = self.encode(x)  # (batch, m, n, c)
        gt_latents = self._detached_targets(x)

        viewpoint_params = data_dict['viewpoint_params'].to(
            dtype=self.dtype, device=self.device)

        # Both of these are optional; move them to the device when present.
        valid_mask = data_dict.get('viewpoint_valid_mask', None)
        if valid_mask is not None:
            valid_mask = valid_mask.to(device=self.device)
        num_objects = data_dict.get('num_objects', None)
        if num_objects is not None:
            num_objects = num_objects.to(device=self.device)

        # Prompts already contain the viewpoint tokens, e.g.
        # "A <view_token_0><view_token_1>...<view_token_7> red car".
        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)

        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        inputs_embeds = self.inject_viewpoint_embeddings(
            input_ids, viewpoint_params, inputs_embeds, valid_mask,
            num_objects=num_objects)

        return self._masked_generation_loss(
            x, gt_latents,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask)

    def _viewpoint_regression_loss(self, predicted, gt_viewpoint):
        """MSE-style loss selected by ``viewpoint_mlp.viewpoint_param_type``.

        Raises:
            ValueError: The parameter type is not one of the handled modes.
        """
        param_type = self.viewpoint_mlp.viewpoint_param_type

        if param_type == 'spherical':
            # Ground truth holds [azimuth, elevation]; the prediction is
            # compared against their sin/cos encoding.
            target = self._angles_to_sincos(
                gt_viewpoint[:, 0], gt_viewpoint[:, 1])
            return F.mse_loss(predicted, target)

        if param_type in ('rotation_translation',
                          'relative_rotation_translation'):
            # First 9 columns are treated as rotation, the next 3 as
            # translation; the two MSE terms are averaged.
            loss_rot = F.mse_loss(predicted[:, :9], gt_viewpoint[:, :9])
            loss_trans = F.mse_loss(predicted[:, 9:12], gt_viewpoint[:, 9:12])
            return 0.5 * loss_rot + 0.5 * loss_trans

        if param_type == 'factorized':
            # Columns 0-1 (azimuth, elevation) are sin/cos encoded; columns
            # 2-4 (radius, pitch, yaw per the original notes) are used as-is.
            target = torch.cat([
                self._angles_to_sincos(gt_viewpoint[:, 0], gt_viewpoint[:, 1]),
                gt_viewpoint[:, 2:5],
            ], dim=-1)
            return F.mse_loss(predicted, target)

        if param_type == 'rotation_factorized':
            # Placeholder: plain MSE over every parameter.
            return F.mse_loss(predicted, gt_viewpoint)

        if param_type == 'plucker':
            # Columns 0-2 are the direction, 3-5 the moment. The ground-truth
            # direction is L2-normalised before comparison.
            gt_direction = F.normalize(gt_viewpoint[:, :3], p=2, dim=-1)
            loss_direction = F.mse_loss(predicted[:, :3], gt_direction)
            loss_moment = F.mse_loss(predicted[:, 3:6], gt_viewpoint[:, 3:6])
            return loss_direction + loss_moment

        raise ValueError(f"Unknown viewpoint_param_type: {param_type}")

    def image2viewpoint_loss(self, data_dict):
        """
        Predict viewpoint parameters from input images.
        Image + text prompt -> Viewpoint parameter prediction.

        Returns a zero tensor when no viewpoint head was configured.
        Gradient scaling is deliberately NOT applied to the visual features
        for this task.

        Raises:
            ValueError: ``input_ids`` holds no image token, a sample does not
                hold exactly one, the token is not at the expected position,
                or the viewpoint parameter type is unknown.
        """
        if self.viewpoint_head is None:
            return torch.tensor(0.0, device=self.device, dtype=self.dtype)

        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x = self.encode(x)  # (batch, m, n, c)
        _, z_enc = self.extract_visual_feature(x)

        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)

        # Validate the placeholder layout before splicing.
        is_image = input_ids == IMAGE_TOKEN_INDEX
        if not bool(is_image.any()):
            raise ValueError(
                f"IMAGE_TOKEN_INDEX ({IMAGE_TOKEN_INDEX}) not found in input_ids. "
                f"Check dataset tokenization."
            )
        per_sample = is_image.sum(dim=1)
        offenders = (per_sample != 1).nonzero()
        if offenders.numel() > 0:
            i = int(offenders[0])
            raise ValueError(
                f"Sample {i} has {int(per_sample[i])} IMAGE_TOKENs (expected 1)"
            )

        _, attention_mask, inputs_embeds, _ = self._splice_image_features(
            z_enc, input_ids, attention_mask)

        output = self.llm_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Use the hidden state at the last sequence position as the summary.
        pooled_output = output.last_hidden_state[:, -1, :]
        predicted = self.viewpoint_head(pooled_output)
        gt_viewpoint = data_dict['viewpoint_params'].to(
            self.device, dtype=self.dtype)
        return self._viewpoint_regression_loss(predicted, gt_viewpoint)

    def relpose2image_loss(self, data_dict):
        """
        Generate target images from source images + viewpoint parameters.
        Source image + viewpoint tokens -> Target image generation.

        Uses dual conditioning (like recon_loss):
        - Source image: encoded and injected as visual features, with
          gradient scaling applied when ``grad_scale`` is set
        - Viewpoint tokens: encoded as parametric embeddings
        - Target image: masked and used for generation

        Args:
            data_dict: {
                'src_pixel_values': source images
                'tgt_pixel_values': target images (ground truth)
                'input_ids': prompts with viewpoint tokens
                'attention_mask': mask aligned with input_ids
                'viewpoint_params': viewpoint parameters (layout depends on
                    the configured viewpoint type - TODO confirm)
            }

        Returns:
            loss: Diffusion loss on masked target image tokens

        Raises:
            ValueError: The image placeholder is not at the expected position.
        """
        # Source image -> conditioning features.
        src = data_dict['src_pixel_values'].to(
            dtype=self.dtype, device=self.device)
        src_latents = self.encode(src)
        _, z_src = self.extract_visual_feature(src_latents)
        z_src = self._scale_visual_grad(z_src)

        # Target image -> generation targets.
        tgt = data_dict['tgt_pixel_values'].to(
            dtype=self.dtype, device=self.device)
        tgt_latents = self.encode(tgt)
        gt_latents = self._detached_targets(tgt_latents)

        viewpoint_params = data_dict['viewpoint_params'].to(
            dtype=self.dtype, device=self.device)

        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)

        # Source features fill the image slots; text keeps its embeddings.
        input_ids, attention_mask, inputs_embeds, _ = (
            self._splice_image_features(z_src, input_ids, attention_mask))

        # Viewpoint embeddings go on top, keyed by the expanded ids.
        inputs_embeds = self.inject_viewpoint_embeddings(
            input_ids,
            viewpoint_params,
            inputs_embeds,
            valid_mask=None,
            num_objects=None  # Relpose dataset doesn't use multi-object
        )

        return self._masked_generation_loss(
            tgt_latents, gt_latents,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask)

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------
    def forward(self, data, data_samples=None, mode='loss'):
        """Compute the loss dict for ``data``; only ``mode='loss'`` exists.

        Raises:
            NotImplementedError: ``mode`` is anything other than ``'loss'``.
        """
        if mode != 'loss':
            raise NotImplementedError
        return self.compute_loss(data_dict=data)

    def compute_loss(self, data_dict):
        """Dispatch each ``{data_type: batch}`` entry to its task loss.

        The task is the first name in ``_TASKS`` contained in ``data_type``.

        Returns:
            Dict mapping ``'loss_<data_type>'`` to the weighted loss.

        Raises:
            NotImplementedError: ``data_type`` matches no known task.
        """
        losses = {}
        for data_type, batch_data in data_dict.items():
            task = next((t for t in self._TASKS if t in data_type), None)
            if task is None:
                raise NotImplementedError(f"Unknown data type: {data_type}")
            loss = getattr(self, f'{task}_loss')(batch_data)
            # NOTE(review): the weight is looked up by the exact data_type
            # key, so a suffixed name such as 'text2image_foo' falls back to
            # 1.0 even when a 'text2image' weight is configured.
            losses[f'loss_{data_type}'] = loss * self.loss_weights.get(data_type, 1.0)
        return losses