import torch
import torch.nn.functional as F
from torch.nn.modules.module import T
from mmengine.model import BaseModel
from torch.autograd.function import Function
from mmengine.logging import print_log
from xtuner.model.utils import guess_load_checkpoint
import os
# from .skywork_unipic import SkyworkUnipic  # alternative (non-SigLIP) backbone
from .skywork_unipic_siglip import SkyworkUnipic
from xtuner.utils import IMAGE_TOKEN_INDEX
import torch.distributed as dist
import json
from einops import rearrange

# Placeholder token marking where visual features are spliced into the LLM
# input sequence.
# NOTE(review): the reviewed copy of this file passed an empty string to
# `add_tokens` / `convert_tokens_to_ids`. An empty token cannot be registered
# (HF tokenizers skip empty strings in `add_tokens`) and maps to no real id.
# "<image>" is the placeholder used by xtuner-style data pipelines — confirm
# against the dataset code.
_IMAGE_TOKEN = "<image>"

# Per-task loss weights used when `loss_weights` is not supplied.
_DEFAULT_LOSS_WEIGHTS = {
    "image2text": 0.01,
    "text2image": 1.0,
    "image_edit": 1.0,
    "contrastive": 0.1,
}


def _load_state_dict_with_ds(module_to_load, state_dict, start_prefix="", strict=True):
    """Load `state_dict` into a module whose parameters may be ZeRO-3 partitioned.

    The module tree is walked recursively. For each sub-module that has
    parameters present in `state_dict`, those parameters are gathered with
    `deepspeed.zero.GatheredParameters(..., modifier_rank=0)`. Rank 0 loads
    them, and they are re-partitioned when the context exits.

    Args:
        module_to_load: Root module to load into.
        state_dict: Mapping from parameter/buffer names to tensors.
        start_prefix: Key prefix corresponding to `module_to_load`.
        strict: Forwarded to `Module._load_from_state_dict`.

    Raises:
        ImportError: If deepspeed is not installed.
        RuntimeError: If any sub-module reported load errors.
    """
    try:
        import deepspeed
    except ImportError as err:
        raise ImportError(
            "deepspeed is not installed. Please install deepspeed to use this feature."
        ) from err

    # Copy state_dict so _load_from_state_dict can modify it, while keeping the
    # `_metadata` attribute torch attaches to state dicts.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []
    missing_keys = []
    unexpected_keys = []

    def load(module: torch.nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, strict,
                missing_keys, unexpected_keys, error_msgs)
        # Keys of this module and its children all start with `prefix`; only
        # load here when at least one such key is present.
        if any(key.startswith(prefix) for key in state_dict):
            # In sharded models each shard holds only part of the full
            # state_dict, so only gather parameters present in this one.
            named_parameters = dict(
                module.named_parameters(prefix=prefix[:-1], recurse=False)
            )
            params_to_gather = [
                named_parameters[k] for k in state_dict.keys() if k in named_parameters
            ]
            if params_to_gather:
                # ZeRO-3 leaves placeholders in the model params: gather
                # (unpartition) this layer's params, load on rank 0, and let
                # the context manager re-partition them on exit.
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    if deepspeed.comm.get_rank() == 0:
                        module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(module_to_load, state_dict, start_prefix)

    if missing_keys:
        print_log(f"[WARNING] Missing keys: {missing_keys}")
    if unexpected_keys:
        print_log(f"[WARNING] Unexpected keys: {unexpected_keys}")
    if error_msgs:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                module_to_load.__class__.__name__, "\n\t".join(error_msgs)
            )
        )


class _ScaleGradient(Function):
    """Identity in the forward pass; multiplies the incoming gradient by `scale`."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # `scale` is a plain number, so it receives no gradient.
        return grad_output * ctx.scale, None


class SkyworkUnipicDev(SkyworkUnipic, BaseModel):
    """Training wrapper around `SkyworkUnipic`.

    Loads pretrained weights, optionally freezes sub-modules, and computes the
    text-to-image, image-to-text and image-edit training losses.
    """

    def __init__(
        self,
        grad_scale=0.1,
        loss_weights=None,
        pretrained_pth=None,
        mar_path=None,
        siglip_proj_path=None,
        freeze_llm=False,
        freeze_mar=False,
        freeze_mar_decoder=False,
        freeze_siglip_proj=False,
        gradient_checkpointing=True,
        **kwargs,
    ):
        """Build the model.

        Args:
            grad_scale: Factor applied (via `_ScaleGradient`) to gradients
                flowing back through the visual features; None disables it.
            loss_weights: Mapping of data type / task name to loss multiplier.
                Defaults to `_DEFAULT_LOSS_WEIGHTS`.
            pretrained_pth: SkyworkUnipic checkpoint (file or directory).
                Only when this is set are `siglip_proj_path` and `mar_path`
                loaded as well.
            mar_path: Optional MAR checkpoint.
            siglip_proj_path: Optional SigLIP2 projector checkpoint.
            freeze_llm / freeze_mar / freeze_mar_decoder / freeze_siglip_proj:
                Disable gradients for the corresponding sub-modules.
            gradient_checkpointing: Enable (True) or disable (False) gradient
                checkpointing on the LLM and MAR.
            **kwargs: Forwarded to the parent constructor.
        """
        if loss_weights is None:
            loss_weights = dict(_DEFAULT_LOSS_WEIGHTS)
        super().__init__(**kwargs)

        self.grad_scale = grad_scale
        self.loss_weights = loss_weights
        self.pretrained_pth = pretrained_pth
        self.mar_path = mar_path
        self.siglip_proj_path = siglip_proj_path

        # === Load pretrained weights ===
        if pretrained_pth:
            self.load_hf_weights(
                skywork_unipic_ckpt=pretrained_pth,
                siglip_proj_path=siglip_proj_path,
                mar_path=mar_path,
            )

        # === Freeze modules ===
        if freeze_llm:
            self.llm.requires_grad_(False)
        if freeze_mar:
            self.mar.requires_grad_(False)
        if freeze_mar_decoder:
            self._freeze_mar_decoder()
        if freeze_siglip_proj:
            self.siglip2_proj.requires_grad_(False)

        # === Gradient checkpointing ===
        if gradient_checkpointing:
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

    def _freeze_mar_decoder(self):
        """Freeze only the MAR decoder parts (embed, blocks, norm, learned pos-embeds)."""
        self.mar.decoder_embed.requires_grad_(False)
        for block in self.mar.decoder_blocks:
            block.requires_grad_(False)
        self.mar.decoder_norm.requires_grad_(False)
        for attr_name in ("decoder_pos_embed_learned", "diffusion_pos_embed_learned"):
            pos_embed = getattr(self.mar, attr_name)
            if isinstance(pos_embed, torch.nn.Parameter):
                pos_embed.requires_grad = False

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------
    def load_hf_weights(self, skywork_unipic_ckpt: str = None,
                        siglip_proj_path: str = None, mar_path: str = None):
        """Load SkyworkUnipic (optional) + SigLIP2 projector + MAR weights.

        Each source is optional; a falsy path skips it with a log message.

        Returns:
            An empty dict (kept for backward compatibility with earlier callers).
        """
        if skywork_unipic_ckpt:
            self._load_unipic_weights(skywork_unipic_ckpt)
        else:
            print_log("[INFO] Skipping SkyworkUnipic checkpoint loading")

        if siglip_proj_path:
            self._load_siglip_proj_weights(siglip_proj_path)
        else:
            print_log("[INFO] No SigLIP2 checkpoint provided, skipping")

        if mar_path:
            self._load_mar_weights(mar_path)
        else:
            print_log("[INFO] No MAR checkpoint provided, skipping")

        return {}

    @staticmethod
    def _log_load_result(module_name, missing, unexpected):
        """Log how many keys were missing / unexpected when loading a module."""
        print_log(f"[INFO] Loaded {module_name}. "
                  f"missing={len(missing)}, unexpected={len(unexpected)}")

    @staticmethod
    def _read_unipic_state(ckpt_path, device="cpu"):
        """Read a checkpoint from a file, a sharded HF directory, or `pytorch_model.bin`."""
        if os.path.isfile(ckpt_path):
            return torch.load(ckpt_path, map_location=device)

        index_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
        if os.path.exists(index_path):
            with open(index_path, 'r', encoding='utf-8') as f:
                index = json.load(f)
            state = {}
            for shard in sorted(set(index["weight_map"].values())):
                shard_path = os.path.join(ckpt_path, shard)
                state.update(torch.load(shard_path, map_location=device))
            return state

        return torch.load(os.path.join(ckpt_path, "pytorch_model.bin"), map_location=device)

    def _load_unipic_weights(self, ckpt_path):
        """Load a SkyworkUnipic checkpoint, skipping keys whose shape differs from the model."""
        print_log(f"[INFO] Loading SkyworkUnipic checkpoint from: {ckpt_path}")
        ckpt_state = self._read_unipic_state(ckpt_path)

        # `self.state_dict()` drops keys containing 'vae.', so any VAE weights
        # in the checkpoint are not loaded here.
        model_dict = self.state_dict()
        filtered_checkpoint = {}
        shape_mismatch_keys = []
        for k, v in ckpt_state.items():
            if k not in model_dict:
                continue
            if v.shape == model_dict[k].shape:
                filtered_checkpoint[k] = v
            else:
                shape_mismatch_keys.append((k, v.shape, model_dict[k].shape))

        missing, unexpected = self.load_state_dict(filtered_checkpoint, strict=False)

        # Report keys that were skipped because of a shape mismatch.
        if shape_mismatch_keys:
            print_log("以下 key 因形状不匹配被跳过:")
            for k, checkpoint_shape, model_shape in shape_mismatch_keys:
                print_log(f" - {k}:")
                print_log(f" checkpoint 中的形状: {checkpoint_shape}")
                print_log(f" 当前模型的形状: {model_shape}")
        else:
            print_log("所有 key 形状匹配,未跳过任何参数")

        self._log_load_result("SkyworkUnipic", missing, unexpected)

    def _load_siglip_proj_weights(self, siglip_proj_path):
        """Load the SigLIP2 projector weights (non-strict)."""
        print_log(f"[INFO] Loading SigLIP2 weights from: {siglip_proj_path}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # trusted checkpoints.
        siglip_state = torch.load(siglip_proj_path, map_location="cpu", weights_only=False)
        # Unwrap checkpoints saved as {"model": {...}}.
        if isinstance(siglip_state, dict) and "model" in siglip_state:
            siglip_state = siglip_state["model"]
        missing, unexpected = self.siglip2_proj.load_state_dict(siglip_state, strict=False)
        self._log_load_result("SigLIP2", missing, unexpected)

    def _load_mar_weights(self, mar_path):
        """Load MAR weights (non-strict), preferring EMA weights when present."""
        print_log(f"[INFO] Loading MAR weights from: {mar_path}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # trusted checkpoints.
        mar_state = torch.load(mar_path, map_location="cpu", weights_only=False)
        # Accept {"model_ema": ...}, {"model": ...} or a bare state dict.
        if isinstance(mar_state, dict) and "model_ema" in mar_state:
            mar_state = mar_state["model_ema"]
        elif isinstance(mar_state, dict) and "model" in mar_state:
            mar_state = mar_state["model"]
        # If any key carries a "mar." prefix, keep only those keys and strip it.
        prefix = "mar."
        if any(k.startswith(prefix) for k in mar_state):
            mar_state = {
                k[len(prefix):]: v for k, v in mar_state.items() if k.startswith(prefix)
            }
        missing, unexpected = self.mar.load_state_dict(mar_state, strict=False)
        self._log_load_result("MAR", missing, unexpected)

    # ------------------------------------------------------------------
    # Module utilities
    # ------------------------------------------------------------------
    def gradient_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.mar.gradient_checkpointing_disable()

    def gradient_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.mar.gradient_checkpointing_enable()

    def state_dict(self, *args, **kwargs):
        """Return the parent state dict without VAE entries (keys containing 'vae.')."""
        state_dict = super().state_dict(*args, **kwargs)
        return {k: v for k, v in state_dict.items() if 'vae.' not in k}

    def train(self: T, mode: bool = True) -> T:
        """Set train/eval mode on the model while always keeping the VAE in eval mode."""
        super().train(mode=mode)
        self.vae.train(mode=False)
        return self

    # ------------------------------------------------------------------
    # Shared loss helpers
    # ------------------------------------------------------------------
    def _image_token_index(self):
        """Return the tokenizer id of the image placeholder, registering it on first use."""
        token_idx = getattr(self, "_cached_image_token_idx", None)
        if token_idx is None:
            self.tokenizer.add_tokens([_IMAGE_TOKEN], special_tokens=True)
            token_idx = self.tokenizer.convert_tokens_to_ids(_IMAGE_TOKEN)
            self._cached_image_token_idx = token_idx
        return token_idx

    def _merge_image_embeds(self, input_ids, image_embeds, image_token_idx):
        """Build LLM `inputs_embeds` from visual features and text token embeddings.

        `image_embeds` is flattened over its first two dims. It is written, in
        order, into the positions where `input_ids == image_token_idx`; all
        other positions get the LLM's input embeddings of their token ids. The
        number of placeholder positions should equal the flattened row count.
        """
        inputs_embeds = image_embeds.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
        image_pos = input_ids == image_token_idx
        inputs_embeds[image_pos] = image_embeds.flatten(0, 1)
        inputs_embeds[~image_pos] = self.llm.get_input_embeddings()(input_ids[~image_pos])
        return inputs_embeds

    def _masked_reconstruction_loss(self, x, x_con=None, **encoder_kwargs):
        """MAE-style masked reconstruction loss on latent `x` (unpacked as b, m, n, c).

        `encoder_kwargs` are forwarded to `forward_mae_encoder`. `x_con` is
        passed to the MAR decoder only when it is not None.
        """
        b, m, n, _ = x.shape
        gt_latents = x.clone().detach().view(b, m * n, -1)
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(x.flatten(1, 2), orders)
        x_enc = self.forward_mae_encoder(x, mask, **encoder_kwargs)
        decoder_kwargs = {} if x_con is None else {"x_con": x_con}
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), **decoder_kwargs)
        return self.mar.forward_loss(z=z, target=gt_latents, mask=mask)

    # ------------------------------------------------------------------
    # Task losses
    # ------------------------------------------------------------------
    def text2image_loss(self, data_dict):
        """Masked latent reconstruction of `pixel_values`, conditioned on the text prompt."""
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x = self.encode(x)  # b m n c
        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)
        return self._masked_reconstruction_loss(
            x, input_ids=input_ids, attention_mask=attention_mask
        )

    def image2text_loss(self, data_dict):
        """Next-token cross-entropy on `labels` (labels < 0 are ignored).

        When `pixel_values` is present, visual features replace the image
        placeholder tokens in the input sequence.
        """
        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)
        labels = data_dict['labels'].to(self.device)
        pixel_values = data_dict.get('pixel_values', None)

        if pixel_values is None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            # Run the visual branch on a dummy latent and add it with weight 0,
            # keeping those parameters in the autograd graph (presumably to
            # avoid unused-parameter errors under distributed training).
            _, z_null = self.extract_visual_feature(
                torch.zeros(1, 16, 16, self.token_embed_dim,
                            dtype=self.dtype, device=self.device)
            )
            loss_null = z_null.mean() * 0.0
            print_log("No image found in this batch!")
        else:
            x = pixel_values.to(dtype=self.dtype, device=self.device)
            x = self.encode(x)  # b m n c
            _, z_enc = self.extract_visual_feature(x)
            if self.grad_scale is not None:
                z_enc = _ScaleGradient.apply(z_enc, self.grad_scale)
            inputs_embeds = self._merge_image_embeds(
                input_ids, z_enc, self._image_token_index()
            )
            loss_null = 0.0

        output = self.llm_model(inputs_embeds=inputs_embeds,
                                attention_mask=attention_mask, return_dict=True)
        # Shift by one: hidden state at t predicts the label at t + 1.
        last_hidden_state = output.last_hidden_state[:, :-1]
        labels = labels[:, 1:]
        valid = labels >= 0
        logits = self.llm.get_output_embeddings()(last_hidden_state[valid])
        loss_i2t = F.cross_entropy(input=logits, target=labels[valid])
        return loss_i2t + loss_null

    def image_edit_loss_contrastive(self, data_dict, temperature=0.07):
        """Image-edit reconstruction loss plus a contrastive term on source features.

        The contrastive term assumes the source batch is laid out as
        (repeat, edit) pairs: even rows are "repeat" samples and odd rows are
        "edit" samples. It is weighted by `loss_weights["contrastive"]`,
        defaulting to 0.1 when that key is absent.

        Raises:
            ValueError: If there are fewer source images than target images.
        """
        # Step 1: encode source and target images in one pass.
        x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        if len(x_src) < len(x):
            raise ValueError(
                f"expected at least as many source images ({len(x_src)}) "
                f"as target images ({len(x)})"
            )
        x_src, x = self.encode(torch.cat([x_src, x])).split([len(x_src), len(x)], dim=0)

        # Step 2: build the LLM context from source features + instruction.
        attention_mask = data_dict['attention_mask'].to(self.device)
        input_ids = data_dict['input_ids'].to(self.device)
        x_con, z_src = self.extract_visual_feature(x_src)
        if self.grad_scale is not None:
            z_src = _ScaleGradient.apply(z_src, self.grad_scale)
            x_con = _ScaleGradient.apply(x_con, self.grad_scale)
        # Same placeholder id as `image_edit_loss`, which consumes the same data
        # format (previously xtuner's IMAGE_TOKEN_INDEX was used here).
        inputs_embeds = self._merge_image_embeds(input_ids, z_src, self._image_token_index())

        # Step 3: reconstruction loss.
        rec_loss = self._masked_reconstruction_loss(
            x, x_con=x_con, inputs_embeds=inputs_embeds, attention_mask=attention_mask
        )

        # Step 4: contrastive loss between repeat (even) and edit (odd) samples.
        pooled = F.normalize(z_src.mean(dim=1), dim=-1)  # global mean-pool over dim 1
        repeat_z = pooled[::2]
        edit_z = pooled[1::2]
        logits = torch.matmul(edit_z, repeat_z.T) / temperature
        targets = torch.arange(logits.size(0), device=logits.device)
        contrastive_loss = F.cross_entropy(logits, targets)

        weight = self.loss_weights.get("contrastive", _DEFAULT_LOSS_WEIGHTS["contrastive"])
        return rec_loss + weight * contrastive_loss

    def image_edit_loss(self, data_dict):
        """Masked reconstruction of edited images conditioned on source images + instruction.

        Multi-turn editing is also supported: the source and target batches are
        encoded together and split back by their respective lengths.
        """
        x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x_src, x = self.encode(torch.cat([x_src, x], dim=0)).split(
            [len(x_src), len(x)], dim=0
        )

        # Context: source image features spliced into the instruction tokens.
        attention_mask = data_dict['attention_mask'].to(self.device)
        input_ids = data_dict['input_ids'].to(self.device)
        x_con, z_src = self.extract_visual_feature(x_src)
        if self.grad_scale is not None:
            z_src = _ScaleGradient.apply(z_src, self.grad_scale)
            x_con = _ScaleGradient.apply(x_con, self.grad_scale)
        inputs_embeds = self._merge_image_embeds(input_ids, z_src, self._image_token_index())

        # MAE-style reconstruction of the target latents.
        return self._masked_reconstruction_loss(
            x, x_con=x_con, inputs_embeds=inputs_embeds, attention_mask=attention_mask
        )

    # Earlier experimental image-edit loss variants that were kept here as
    # commented-out code were removed; see version control history.

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------
    def forward(self, data, data_samples=None, mode='loss'):
        if mode == 'loss':
            return self.compute_loss(data_dict=data)
        raise NotImplementedError

    def compute_loss(self, data_dict):
        """Compute the weighted loss for every sub-batch in `data_dict`.

        Keys are dispatched by substring, checked in this order: 'text2image',
        'image2text', 'image_edit'. The loss weight is taken from the first
        available source:
        1. `loss_weights[data_type]`;
        2. the matched task's weight `loss_weights[task]`;
        3. 1.0.

        Raises:
            NotImplementedError: If a key matches none of the known tasks.
        """
        losses = {}
        for data_type, batch_data in data_dict.items():
            if 'text2image' in data_type:
                task, loss_fn = 'text2image', self.text2image_loss
            elif 'image2text' in data_type:
                task, loss_fn = 'image2text', self.image2text_loss
            elif 'image_edit' in data_type:
                task, loss_fn = 'image_edit', self.image_edit_loss
            else:
                raise NotImplementedError(f"Unknown data_type: {data_type}")

            loss = loss_fn(batch_data)
            if data_type in self.loss_weights:
                weight = self.loss_weights[data_type]
            else:
                weight = self.loss_weights.get(task, 1.0)
            losses[f'loss_{data_type}'] = loss * weight
        return losses