# viewtoken-harmon-demo/src/models/harmon_dev.py
# Author: XinxuanLu
# Commit: "Initial demo" (becf13a, verified)
import torch
import torch.nn.functional as F
from torch.nn.modules.module import T
from mmengine.model import BaseModel
from torch.autograd.function import Function
from mmengine.logging import print_log
from xtuner.model.utils import guess_load_checkpoint
from xtuner.utils import IMAGE_TOKEN_INDEX
from .harmon import Harmon
from .viewpoint_mlp import ViewpointPredictionHead
from .mar.lora import apply_lora_to_mar
try:
from peft import get_peft_model, LoraConfig
HAS_PEFT = True
except ImportError:
HAS_PEFT = False
class _ScaleGradient(Function):
@staticmethod
def forward(ctx, input, scale):
ctx.scale = scale
return input
@staticmethod
def backward(ctx, grad_output):
return grad_output * ctx.scale, None
class HarmonDev(Harmon, BaseModel):
def __init__(self,
             grad_scale=0.1,
             loss_weights=None,
             pretrained_pth=None,
             freeze_llm=False,
             gradient_checkpointing=True,
             viewpoint_prediction_head=None,
             lora=None,
             lora_mar=None,
             **kwargs
             ):
    """Build the model, load optional pretrained weights and attach optional adapters.

    Args:
        grad_scale: Factor stored on the instance; the loss methods hand it
            to ``_ScaleGradient`` (``None`` disables the scaling there).
        loss_weights: Mapping from task / data-type name to loss weight.
            ``None`` (the default) means weight 1.0 for each of the six
            built-in tasks. A fresh dict is created per instance so that
            instances never share (and mutate) one default object.
        pretrained_pth: Optional checkpoint path; loaded non-strictly.
        freeze_llm: Freeze every LLM parameter. Ignored when ``lora`` is
            given (the ``elif`` below is only reached without LoRA).
        gradient_checkpointing: Enable (True) or disable (False) gradient
            checkpointing on the LLM and the MAR module.
        viewpoint_prediction_head: Config dict or ready module for the
            image2viewpoint head. Only used when ``use_viewpoint_tokens``
            is set and the ``'image2viewpoint'`` loss weight is positive.
        lora: ``LoraConfig`` instance or a dict of its keyword arguments;
            when given, the LLM is wrapped with peft.
        lora_mar: Keyword arguments forwarded to ``apply_lora_to_mar``.
        **kwargs: Forwarded to the parent constructor.

    Raises:
        ImportError: If ``lora`` is given but peft is not installed.
    """
    super().__init__(**kwargs)

    if loss_weights is None:
        loss_weights = {
            'image2text': 1.0, 'text2image': 1.0, 'recon': 1.0,
            'viewpoint2image': 1.0, 'image2viewpoint': 1.0,
            'relpose2image': 1.0,
        }

    self.grad_scale = grad_scale
    self.loss_weights = loss_weights
    self.debuged = False

    # Load the pretrained checkpoint BEFORE initializing viewpoint tokens so
    # that embedding sizes still match the checkpoint during loading.
    print("Initializing HarmonDev model...")
    if pretrained_pth is not None:
        pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
        self.load_state_dict(pretrained_state_dict, strict=False)
        print_log(f'Load pretrained weight from {pretrained_pth}')

    # Bring the VAE, MAR and projection layers to the LLM's dtype.
    llm_dtype = self.llm.dtype
    self.vae = self.vae.to(dtype=llm_dtype)
    self.mar = self.mar.to(dtype=llm_dtype)
    self.proj_in = self.proj_in.to(dtype=llm_dtype)
    self.proj_out = self.proj_out.to(dtype=llm_dtype)

    # Apply LoRA to MAR (after checkpoint loading + dtype setup).
    if lora_mar is not None:
        apply_lora_to_mar(self.mar, **lora_mar)

    # NOW initialize viewpoint tokens (after loading the checkpoint).
    # The original comments state this adds new tokens and resizes embeddings.
    self._init_viewpoint_tokens()

    # Convert the viewpoint MLP, when present, to the LLM dtype.
    if hasattr(self, 'viewpoint_mlp') and self.viewpoint_mlp is not None:
        self.viewpoint_mlp = self.viewpoint_mlp.to(dtype=self.llm.dtype)

    # The prediction head is optional and only needed for image2viewpoint.
    wants_head = (
        self.use_viewpoint_tokens
        and self.loss_weights.get('image2viewpoint', 0.0) > 0
        and viewpoint_prediction_head is not None
    )
    if wants_head:
        if isinstance(viewpoint_prediction_head, dict):
            self.viewpoint_head = ViewpointPredictionHead(**viewpoint_prediction_head)
        else:
            self.viewpoint_head = viewpoint_prediction_head
        self.viewpoint_head = self.viewpoint_head.to(dtype=self.llm.dtype)
    else:
        self.viewpoint_head = None

    # Apply LoRA to the LLM (after loading the checkpoint, before freezing).
    if lora is not None:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O`` and the code below would then fail with a NameError.
        if not HAS_PEFT:
            raise ImportError(
                'peft is required for LoRA. Install with: pip install peft')
        lora_config = LoraConfig(**lora) if isinstance(lora, dict) else lora
        self.llm = get_peft_model(self.llm, lora_config)
        # Keep a handle on the inner transformer of the wrapped model
        # (original note: PeftModelForCausalLM -> LoraModel ->
        # Qwen2ForCausalLM -> Qwen2Model) for forward_mae_encoder.
        self._llm_base_model = self.llm.get_base_model().model
        self.llm.print_trainable_parameters()
        print_log('Applied LoRA to LLM')
        # With LoRA, the base model is expected to be frozen by peft while
        # the adapters remain trainable, so freeze_llm is not consulted.
    elif freeze_llm:
        self.llm.requires_grad_(False)

    if gradient_checkpointing:
        self.gradient_checkpointing_enable()
    else:
        self.gradient_checkpointing_disable()
def gradient_checkpointing_disable(self):
    """Turn gradient checkpointing off on the LLM first, then on the MAR module."""
    for module in (self.llm, self.mar):
        module.gradient_checkpointing_disable()
def gradient_checkpointing_enable(self):
    """Turn gradient checkpointing on for the LLM first, then for the MAR module."""
    for module in (self.llm, self.mar):
        module.gradient_checkpointing_enable()
def state_dict(self, *args, **kwargs):
    """Return the parent state dict with every key containing ``'vae.'`` removed.

    Entries are deleted in place, so the object returned by the parent
    (its mapping type and any attributes attached to it) is what the
    caller receives.
    """
    weights = super().state_dict(*args, **kwargs)
    # Collect the matching keys first; a dict must not change size while
    # it is being iterated.
    vae_keys = [name for name in weights if 'vae.' in name]
    for name in vae_keys:
        del weights[name]
    return weights
def train(self: T, mode: bool = True) -> T:
    """Switch the model to train/eval mode while always keeping the VAE in eval mode.

    Args:
        mode: Forwarded to the parent ``train``.

    Returns:
        ``self``.
    """
    super().train(mode=mode)
    # Undo the mode change for the VAE regardless of ``mode``.
    self.vae.train(False)
    return self
def text2image_loss(self, data_dict):
    """Compute the MAR generation loss for a text-conditioned batch.

    Args:
        data_dict: Must provide ``'pixel_values'``, ``'input_ids'`` and
            ``'attention_mask'`` tensors.

    Returns:
        The value returned by ``self.mar.forward_loss`` for the masked
        latent tokens.
    """
    pixels = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
    latents = self.encode(pixels)  # b m n c
    bsz, rows, cols, _ = latents.shape
    num_tokens = rows * cols
    # Detached copy of the latents, flattened over the grid, as the target.
    target = latents.clone().detach().view(bsz, num_tokens, -1)

    orders = self.mar.sample_orders(bsz=bsz, seq_len=num_tokens)
    mask = self.mar.random_masking(latents.flatten(1, 2), orders)

    encoded = self.forward_mae_encoder(
        latents,
        mask,
        input_ids=data_dict['input_ids'].to(self.device),
        attention_mask=data_dict['attention_mask'].to(self.device),
    )
    decoded = self.mar.forward_mae_decoder(encoded, mask, image_shape=(rows, cols))
    return self.mar.forward_loss(z=decoded, target=target, mask=mask)
def image2text_loss(self, data_dict):
    """Compute the next-token cross-entropy loss for image-conditioned text.

    The single image placeholder at sequence index 3 is expanded to one
    placeholder per visual token, the visual features are written into
    those slots, and the LLM output is scored against the shifted labels.

    Args:
        data_dict: Must provide ``'input_ids'``, ``'attention_mask'``,
            ``'labels'`` and ``'pixel_values'`` tensors.

    Returns:
        Scalar cross-entropy over all positions whose shifted label is >= 0.

    Raises:
        ValueError: If ``'pixel_values'`` is missing/None, or if a sample
            does not have ``IMAGE_TOKEN_INDEX`` at index 3.
    """
    input_ids = data_dict['input_ids'].to(self.device)
    attention_mask = data_dict['attention_mask'].to(self.device)
    labels = data_dict['labels'].to(self.device)

    pixel_values = data_dict.get('pixel_values', None)
    if pixel_values is None:
        # Text-only batches were already rejected (via ``assert False``);
        # raise explicitly so the guard also holds under ``python -O``.
        raise ValueError(
            "image2text_loss requires 'pixel_values'; no image found in this batch")

    x = pixel_values.to(dtype=self.dtype, device=self.device)
    x = self.encode(x)  # b m n c
    _, z_enc = self.extract_visual_feature(x)
    if self.grad_scale is not None:
        z_enc = _ScaleGradient.apply(z_enc, self.grad_scale)

    # One vectorized check instead of a Python loop over the batch.
    if not bool((input_ids[:, 3] == IMAGE_TOKEN_INDEX).all()):
        raise ValueError(
            "Expected IMAGE_TOKEN_INDEX at position 3 of every sample in input_ids")

    batch_size = input_ids.shape[0]
    # Number of placeholder slots must equal the number of visual tokens
    # per sample, otherwise the masked assignment below cannot succeed.
    num_image_tokens = z_enc.shape[1]

    input_ids = torch.cat(
        (
            input_ids[:, :3],
            torch.full((batch_size, num_image_tokens), IMAGE_TOKEN_INDEX,
                       dtype=torch.long, device=input_ids.device),
            input_ids[:, 4:],
        ), dim=1
    )
    attention_mask = torch.cat(
        (
            attention_mask[:, :3],
            torch.ones(batch_size, num_image_tokens,
                       dtype=torch.long, device=attention_mask.device),
            attention_mask[:, 4:],
        ), dim=1
    )
    # Repeat each sample's own placeholder label across its image slots
    # (previously sample 0's label was reused for the whole batch).
    labels = torch.cat(
        (
            labels[:, :3],
            labels[:, 3:4].expand(-1, num_image_tokens),
            labels[:, 4:],
        ), dim=1
    )

    is_image = input_ids == IMAGE_TOKEN_INDEX
    inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
    inputs_embeds[is_image] = z_enc.flatten(0, 1)
    inputs_embeds[~is_image] = self.llm.get_input_embeddings()(input_ids[~is_image])

    output = self.llm_model(inputs_embeds=inputs_embeds,
                            attention_mask=attention_mask,
                            return_dict=True)

    # Shift by one: the hidden state at position t predicts the label at t+1.
    last_hidden_state = output.last_hidden_state[:, :-1]
    labels = labels[:, 1:]
    supervised = labels >= 0
    logits = self.llm.get_output_embeddings()(last_hidden_state[supervised])
    return F.cross_entropy(input=logits, target=labels[supervised])
def recon_loss(self, data_dict):
    """Compute the MAR loss for reconstructing an image conditioned on its own visual features.

    The image placeholder at sequence index 3 is expanded to one slot per
    visual token, filled with the (optionally gradient-scaled) visual
    features, and the resulting embeddings condition the masked encoder.

    Args:
        data_dict: Must provide ``'pixel_values'``, ``'input_ids'`` and
            ``'attention_mask'`` tensors.

    Returns:
        The value returned by ``self.mar.forward_loss``.

    Raises:
        ValueError: If a sample does not have ``IMAGE_TOKEN_INDEX`` at index 3.
    """
    x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
    x = self.encode(x)  # b m n c
    b, m, n, _ = x.shape
    gt_latents = x.clone().detach().view(b, m * n, -1)

    input_ids = data_dict['input_ids'].to(self.device)
    attention_mask = data_dict['attention_mask'].to(self.device)

    _, z_enc = self.extract_visual_feature(x)
    if self.grad_scale is not None:
        z_enc = _ScaleGradient.apply(z_enc, self.grad_scale)

    # One vectorized check instead of a Python loop over the batch.
    if not bool((input_ids[:, 3] == IMAGE_TOKEN_INDEX).all()):
        raise ValueError(
            "Expected IMAGE_TOKEN_INDEX at position 3 of every sample in input_ids")

    batch_size = input_ids.shape[0]
    # Slot count follows the feature tensor so it always matches the
    # masked assignment below (was a hard-coded 1088).
    num_image_tokens = z_enc.shape[1]

    input_ids = torch.cat(
        (
            input_ids[:, :3],
            torch.full((batch_size, num_image_tokens), IMAGE_TOKEN_INDEX,
                       dtype=torch.long, device=input_ids.device),
            input_ids[:, 4:],
        ), dim=1
    )
    attention_mask = torch.cat(
        (
            attention_mask[:, :3],
            torch.ones(batch_size, num_image_tokens,
                       dtype=torch.long, device=attention_mask.device),
            attention_mask[:, 4:],
        ), dim=1
    )

    is_image = input_ids == IMAGE_TOKEN_INDEX
    inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
    inputs_embeds[is_image] = z_enc.flatten(0, 1)
    inputs_embeds[~is_image] = self.llm.get_input_embeddings()(input_ids[~is_image])

    orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
    mask = self.mar.random_masking(x.flatten(1, 2), orders)
    x_enc = self.forward_mae_encoder(
        x,
        mask,
        inputs_embeds=inputs_embeds,
        input_ids=None,
        attention_mask=attention_mask
    )
    z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n))
    return self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
def viewpoint2image_loss(self, data_dict):
    """Compute the MAR generation loss for a prompt carrying viewpoint tokens.

    Text + viewpoint tokens -> image generation.

    Args:
        data_dict: Must provide ``'pixel_values'``, ``'viewpoint_params'``,
            ``'input_ids'`` and ``'attention_mask'``. ``'viewpoint_valid_mask'``
            and ``'num_objects'`` are optional and passed through to
            ``inject_viewpoint_embeddings`` (``None`` when absent).

    Returns:
        The value returned by ``self.mar.forward_loss``.
    """
    pixels = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
    latents = self.encode(pixels)  # (batch, m, n, c)
    bsz, rows, cols, _ = latents.shape
    num_tokens = rows * cols
    target = latents.clone().detach().view(bsz, num_tokens, -1)

    params = data_dict['viewpoint_params'].to(dtype=self.dtype, device=self.device)

    # Optional per-batch extras: moved to the model device only when present.
    validity = data_dict.get('viewpoint_valid_mask', None)
    if validity is not None:
        validity = validity.to(device=self.device)
    object_counts = data_dict.get('num_objects', None)
    if object_counts is not None:
        object_counts = object_counts.to(device=self.device)

    token_ids = data_dict['input_ids'].to(self.device)
    attn = data_dict['attention_mask'].to(self.device)

    # Start from the plain token embeddings, then let the viewpoint
    # injection produce the embeddings actually used for conditioning.
    text_embeds = self.llm.get_input_embeddings()(token_ids)
    cond_embeds = self.inject_viewpoint_embeddings(
        token_ids, params, text_embeds, validity, num_objects=object_counts)

    orders = self.mar.sample_orders(bsz=bsz, seq_len=num_tokens)
    mask = self.mar.random_masking(latents.flatten(1, 2), orders)

    encoded = self.forward_mae_encoder(
        latents, mask,
        inputs_embeds=cond_embeds,
        attention_mask=attn,
    )
    decoded = self.mar.forward_mae_decoder(encoded, mask, image_shape=(rows, cols))
    return self.mar.forward_loss(z=decoded, target=target, mask=mask)
def image2viewpoint_loss(self, data_dict):
    """Regress viewpoint parameters from an image + text prompt.

    The single image placeholder of each sample is located dynamically,
    expanded to one slot per visual token and filled with the visual
    features. The LLM hidden state at the last sequence position is fed
    to ``self.viewpoint_head`` and compared with the ground truth using a
    loss that depends on ``self.viewpoint_mlp.viewpoint_param_type``.

    Args:
        data_dict: Must provide ``'pixel_values'``, ``'input_ids'``,
            ``'attention_mask'`` and ``'viewpoint_params'`` tensors.

    Returns:
        Scalar loss; a constant zero tensor when no prediction head exists.

    Raises:
        ValueError: If no image placeholder is present, a sample does not
            contain exactly one, the placeholder position differs between
            samples, or the viewpoint parameter type is unknown.
    """
    if self.viewpoint_head is None:
        return torch.tensor(0.0, device=self.device, dtype=self.dtype)

    x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
    x = self.encode(x)  # (batch, m, n, c)
    # Visual features of the full (unmasked) image.
    _, z_enc = self.extract_visual_feature(x)
    # NOTE: unlike image2text_loss, no _ScaleGradient is applied here; the
    # original author left it disabled to allow stronger learning of the
    # image -> viewpoint mapping.

    input_ids = data_dict['input_ids'].to(self.device)
    attention_mask = data_dict['attention_mask'].to(self.device)

    # --- Locate the image placeholder -------------------------------------
    is_image = input_ids == IMAGE_TOKEN_INDEX
    if not bool(is_image.any()):
        raise ValueError(
            f"IMAGE_TOKEN_INDEX ({IMAGE_TOKEN_INDEX}) not found in input_ids. "
            f"Check dataset tokenization."
        )
    per_sample_count = is_image.sum(dim=1)
    offenders = (per_sample_count != 1).nonzero(as_tuple=False)
    if offenders.numel() > 0:
        i = int(offenders[0])
        raise ValueError(
            f"Sample {i} has {int(per_sample_count[i])} IMAGE_TOKENs (expected 1)"
        )
    # The batch is spliced with a single column slice, so every sample
    # must carry its placeholder at the same index. Previously index 3 was
    # assumed without being checked.
    positions = is_image.int().argmax(dim=1)
    pos = int(positions[0])
    if not bool((positions == pos).all()):
        raise ValueError(
            "IMAGE_TOKEN must be at the same position in every sample of the batch"
        )

    # --- Expand the placeholder to one slot per visual token --------------
    batch_size = input_ids.shape[0]
    num_image_tokens = z_enc.shape[1]  # was a hard-coded 1088
    input_ids = torch.cat((
        input_ids[:, :pos],
        torch.full((batch_size, num_image_tokens), IMAGE_TOKEN_INDEX,
                   dtype=torch.long, device=input_ids.device),
        input_ids[:, pos + 1:],
    ), dim=1)
    attention_mask = torch.cat((
        attention_mask[:, :pos],
        torch.ones(batch_size, num_image_tokens,
                   dtype=torch.long, device=attention_mask.device),
        attention_mask[:, pos + 1:],
    ), dim=1)

    is_image = input_ids == IMAGE_TOKEN_INDEX
    inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
    inputs_embeds[is_image] = z_enc.flatten(0, 1)
    inputs_embeds[~is_image] = self.llm.get_input_embeddings()(input_ids[~is_image])

    output = self.llm_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        return_dict=True
    )

    # Hidden state at the last sequence position, (batch, hidden_size).
    # NOTE(review): with right-padded batches this may be a padding
    # position - confirm the collate function's padding side.
    pooled_output = output.last_hidden_state[:, -1, :]

    # Output width depends on the viewpoint parameter type (see below).
    predicted = self.viewpoint_head(pooled_output)
    gt_viewpoint = data_dict['viewpoint_params'].to(self.device, dtype=self.dtype)

    def _sin_cos(azimuth, elevation):
        # [sin(az), cos(az), sin(el), cos(el)] stacked on the last axis.
        return torch.stack([
            torch.sin(azimuth),
            torch.cos(azimuth),
            torch.sin(elevation),
            torch.cos(elevation),
        ], dim=-1)

    param_type = self.viewpoint_mlp.viewpoint_param_type
    if param_type == 'spherical':
        # predicted: 4 values [sin(az), cos(az), sin(el), cos(el)];
        # ground truth columns 0/1 are azimuth/elevation angles.
        gt_sin_cos = _sin_cos(gt_viewpoint[:, 0], gt_viewpoint[:, 1])
        loss = F.mse_loss(predicted, gt_sin_cos)
    elif param_type in ('rotation_translation', 'relative_rotation_translation'):
        # Columns 0-8 are the rotation part, columns 9-11 the translation;
        # each gets its own MSE and the two are averaged with equal weight.
        loss_rot = F.mse_loss(predicted[:, :9], gt_viewpoint[:, :9])
        loss_trans = F.mse_loss(predicted[:, 9:12], gt_viewpoint[:, 9:12])
        loss = 0.5 * loss_rot + 0.5 * loss_trans
    elif param_type == 'factorized':
        # predicted: 7 values [sin(az), cos(az), sin(el), cos(el),
        # radius_norm, pitch, yaw]; ground truth: 5 values
        # [azimuth, elevation, radius_norm, pitch, yaw].
        gt_full = torch.cat([
            _sin_cos(gt_viewpoint[:, 0], gt_viewpoint[:, 1]),
            gt_viewpoint[:, 2].unsqueeze(-1),
            gt_viewpoint[:, 3].unsqueeze(-1),
            gt_viewpoint[:, 4].unsqueeze(-1),
        ], dim=-1)
        loss = F.mse_loss(predicted, gt_full)
    elif param_type == 'rotation_factorized':
        # Placeholder: plain MSE over all parameters, no component weighting.
        loss = F.mse_loss(predicted, gt_viewpoint)
    elif param_type == 'plucker':
        # Columns 0-2: direction, columns 3-5: moment. Only the ground-truth
        # direction is L2-normalized before comparison.
        gt_direction = F.normalize(gt_viewpoint[:, :3], p=2, dim=-1)
        loss_direction = F.mse_loss(predicted[:, :3], gt_direction)
        loss_moment = F.mse_loss(predicted[:, 3:6], gt_viewpoint[:, 3:6])
        loss = loss_direction + loss_moment
    else:
        raise ValueError(f"Unknown viewpoint_param_type: {self.viewpoint_mlp.viewpoint_param_type}")
    return loss
def relpose2image_loss(self, data_dict):
    """Compute the MAR loss for generating a target view from a source view.

    Source image + viewpoint tokens -> target image generation, using the
    same dual-conditioning pattern as ``recon_loss``:

    - Source image: encoded and written into the image placeholder slots.
      Its features pass through ``_ScaleGradient`` when ``grad_scale`` is
      not None.
    - Viewpoint tokens: injected by ``inject_viewpoint_embeddings``.
    - Target image: randomly masked and used as the generation target.

    Args:
        data_dict: Must provide ``'src_pixel_values'``,
            ``'tgt_pixel_values'``, ``'input_ids'``, ``'attention_mask'``
            and ``'viewpoint_params'`` tensors.

    Returns:
        The value returned by ``self.mar.forward_loss`` on the masked
        target tokens.

    Raises:
        ValueError: If a sample does not have ``IMAGE_TOKEN_INDEX`` at index 3.
    """
    # === STEP 1: Encode SOURCE image for conditioning ===
    src = data_dict['src_pixel_values'].to(dtype=self.dtype, device=self.device)
    src_latents = self.encode(src)  # (batch, m, n, c)
    _, z_src = self.extract_visual_feature(src_latents)
    if self.grad_scale is not None:
        z_src = _ScaleGradient.apply(z_src, self.grad_scale)

    # === STEP 2: Encode TARGET image for generation ===
    tgt = data_dict['tgt_pixel_values'].to(dtype=self.dtype, device=self.device)
    tgt_latents = self.encode(tgt)  # (batch, m, n, c)
    batch_size, m, n, _ = tgt_latents.shape
    gt_latents = tgt_latents.clone().detach().view(batch_size, m * n, -1)

    # === STEP 3: Get viewpoint parameters ===
    viewpoint_params = data_dict['viewpoint_params'].to(dtype=self.dtype, device=self.device)

    # === STEP 4: Expand the image placeholder to one slot per source token ===
    input_ids = data_dict['input_ids'].to(self.device)
    attention_mask = data_dict['attention_mask'].to(self.device)

    # One vectorized check instead of a Python loop over the batch.
    at_slot = input_ids[:, 3]
    mismatched = at_slot != IMAGE_TOKEN_INDEX
    if bool(mismatched.any()):
        raise ValueError(
            f"Expected IMAGE_TOKEN_INDEX at position 3, got {at_slot[mismatched][0].item()}")

    # Slot count follows the feature tensor so it always matches the
    # masked assignment below (was a hard-coded 1088).
    num_image_tokens = z_src.shape[1]
    input_ids = torch.cat([
        input_ids[:, :3],
        torch.full((batch_size, num_image_tokens), IMAGE_TOKEN_INDEX,
                   dtype=torch.long, device=input_ids.device),
        input_ids[:, 4:],
    ], dim=1)
    attention_mask = torch.cat([
        attention_mask[:, :3],
        torch.ones(batch_size, num_image_tokens,
                   dtype=torch.long, device=attention_mask.device),
        attention_mask[:, 4:],
    ], dim=1)

    # === STEP 5: Build inputs_embeds: source image + viewpoint tokens + text ===
    is_image = input_ids == IMAGE_TOKEN_INDEX
    inputs_embeds = z_src.new_zeros(batch_size, input_ids.shape[1], self.llm.config.hidden_size)
    # 5a. Source visual features go into the placeholder slots, ordinary
    #     token embeddings everywhere else.
    inputs_embeds[is_image] = z_src.flatten(0, 1)
    inputs_embeds[~is_image] = self.llm.get_input_embeddings()(input_ids[~is_image])
    # 5b. Inject viewpoint token embeddings (no valid mask, no object count).
    inputs_embeds = self.inject_viewpoint_embeddings(
        input_ids,
        viewpoint_params,
        inputs_embeds,
        valid_mask=None,
        num_objects=None
    )

    # === STEP 6: Apply random masking to the TARGET latents ===
    orders = self.mar.sample_orders(bsz=batch_size, seq_len=m * n)
    mask = self.mar.random_masking(tgt_latents.flatten(1, 2), orders)

    # === STEP 7: MAR encoder with dual conditioning ===
    x_enc = self.forward_mae_encoder(
        tgt_latents,
        mask,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask
    )

    # === STEP 8/9: MAR decoder and loss ===
    z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n))
    return self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
def forward(self, data, data_samples=None, mode='loss'):
    """Entry point; only ``mode='loss'`` is implemented.

    Args:
        data: Passed to ``compute_loss`` as ``data_dict``.
        data_samples: Accepted but not used here.
        mode: Must be ``'loss'``.

    Returns:
        The dict produced by ``compute_loss``.

    Raises:
        NotImplementedError: For any other ``mode``.
    """
    if mode != 'loss':
        raise NotImplementedError
    return self.compute_loss(data_dict=data)
def compute_loss(self, data_dict):
    """Compute one weighted loss per entry of ``data_dict``.

    Each key is matched against the known task names by substring, in the
    order listed below (first match wins), and its batch is handed to the
    corresponding ``*_loss`` method.

    The weight is looked up under the exact ``data_type`` key first. If it
    is absent, the matched task name is tried, so a suffixed data type
    such as ``'text2image_foo'`` picks up the ``'text2image'`` weight
    instead of silently falling back to 1.0. If neither key exists the
    weight is 1.0.

    Args:
        data_dict: Mapping from data-type name to the batch for that task.

    Returns:
        Dict mapping ``'loss_<data_type>'`` to the weighted loss.

    Raises:
        NotImplementedError: If a key matches none of the task names.
    """
    # Ordered (task name, handler) pairs; order matters for substring matching.
    task_handlers = (
        ('text2image', self.text2image_loss),
        ('image2text', self.image2text_loss),
        ('recon', self.recon_loss),
        ('viewpoint2image', self.viewpoint2image_loss),
        ('image2viewpoint', self.image2viewpoint_loss),
        ('relpose2image', self.relpose2image_loss),
    )

    losses = {}
    for data_type, batch_data in data_dict.items():
        for task, handler in task_handlers:
            if task in data_type:
                break
        else:
            raise NotImplementedError(f"Unknown data type: {data_type}")

        loss = handler(batch_data)
        weight = self.loss_weights.get(
            data_type, self.loss_weights.get(task, 1.0))
        losses[f'loss_{data_type}'] = loss * weight
    return losses