import torch
import torch.nn.functional as F
from typing import TypeVar

from mmengine.model import BaseModel
from mmengine.logging import print_log
from torch.autograd.function import Function

import os
import json

from .skywork_unipic_siglip import SkyworkUnipic
from xtuner.utils import IMAGE_TOKEN_INDEX

# TypeVar for the fluent `train()` override below (replaces the private `T`
# that used to be imported from torch.nn.modules.module).
T = TypeVar("T", bound=torch.nn.Module)


def _load_state_dict_with_ds(module_to_load, state_dict, start_prefix="", strict=True):
    try:
        import deepspeed
    except ImportError:
        raise ImportError("deepspeed is not installed. Please install deepspeed to use this feature.")

    # Preserve the state_dict metadata across the shallow copy.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []
    missing_keys = []
    unexpected_keys = []

    def load(module: torch.nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            named_parameters = dict(
                module.named_parameters(prefix=prefix[:-1], recurse=False)
            )
            params_to_gather = [
                named_parameters[k]
                for k in state_dict.keys()
                if k in named_parameters
            ]
            if len(params_to_gather) > 0:
                # Under ZeRO-3 the parameters are sharded, so gather them on
                # rank 0 before materializing the checkpoint values.
                with deepspeed.zero.GatheredParameters(
                    params_to_gather, modifier_rank=0
                ):
                    if deepspeed.comm.get_rank() == 0:
                        module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(module_to_load, state_dict, start_prefix)
    if len(missing_keys) > 0:
        print_log(f"[WARNING] Missing keys: {missing_keys}")
    if len(unexpected_keys) > 0:
        print_log(f"[WARNING] Unexpected keys: {unexpected_keys}")
    if error_msgs:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                module_to_load.__class__.__name__, "\n\t".join(error_msgs)
            )
        )
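

def _demo_recursive_prefix_load_walk():
    """Minimal sketch (plain PyTorch, no DeepSpeed) of the recursive walk that
    _load_state_dict_with_ds performs: ``_load_from_state_dict`` only consumes
    tensors owned directly by ``module``, so children are visited explicitly
    with a growing dotted prefix. Toy module; names here are illustrative.
    """
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    state = model.state_dict()  # keys look like "0.weight", "0.bias", "1.weight", ...
    visited = []

    def walk(module, prefix=""):
        if any(key.startswith(prefix) for key in state):
            # Parameters that would be loaded at this level (recurse=False).
            visited.append((prefix, [name for name, _ in module.named_parameters(recurse=False)]))
        for name, child in module._modules.items():
            if child is not None:
                walk(child, prefix + name + ".")

    walk(model)
    return visited

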
class _ScaleGradient(Function):
    """Identity in the forward pass; scales gradients by `scale` in backward."""

    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None
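

def _demo_scale_gradient():
    """Toy check of _ScaleGradient: the forward pass is the identity while the
    backward pass multiplies incoming gradients by ``scale``. Standalone sketch.
    """
    x = torch.randn(2, 3, requires_grad=True)
    y = _ScaleGradient.apply(x, 0.1)
    y.sum().backward()
    # d(sum(y))/dx is all-ones, so the scaled gradient is 0.1 everywhere.
    assert torch.allclose(x.grad, torch.full_like(x, 0.1))

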
class SkyworkUnipicDev(SkyworkUnipic, BaseModel):
    def __init__(
        self,
        grad_scale=0.1,
        loss_weights=None,
        pretrained_pth=None,
        mar_path=None,
        siglip_proj_path=None,
        freeze_llm=False,
        freeze_mar=False,
        freeze_mar_decoder=False,
        freeze_siglip_proj=False,
        gradient_checkpointing=True,
        **kwargs,
    ):
        if loss_weights is None:
            loss_weights = {
                "image2text": 0.01,
                "text2image": 1.0,
                "image_edit": 1.0,
                "contrastive": 0.1,
            }
        super().__init__(**kwargs)

        self.grad_scale = grad_scale
        self.loss_weights = loss_weights
        self.pretrained_pth = pretrained_pth
        self.mar_path = mar_path
        self.siglip_proj_path = siglip_proj_path

        if pretrained_pth:
            self.load_hf_weights(
                skywork_unipic_ckpt=pretrained_pth,
                siglip_proj_path=siglip_proj_path,
                mar_path=mar_path
            )

        if freeze_llm:
            self.llm.requires_grad_(False)
        if freeze_mar:
            self.mar.requires_grad_(False)
        if freeze_mar_decoder:
            # Freeze only the MAR decoder path; the encoder keeps training.
            for param in self.mar.decoder_embed.parameters():
                param.requires_grad = False
            for block in self.mar.decoder_blocks:
                for param in block.parameters():
                    param.requires_grad = False
            for param in self.mar.decoder_norm.parameters():
                param.requires_grad = False
            if isinstance(self.mar.decoder_pos_embed_learned, torch.nn.Parameter):
                self.mar.decoder_pos_embed_learned.requires_grad = False
            if isinstance(self.mar.diffusion_pos_embed_learned, torch.nn.Parameter):
                self.mar.diffusion_pos_embed_learned.requires_grad = False
        if freeze_siglip_proj:
            self.siglip2_proj.requires_grad_(False)

        if gradient_checkpointing:
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

    def load_hf_weights(self,
                        skywork_unipic_ckpt: str = None,
                        siglip_proj_path: str = None,
                        mar_path: str = None):
        """Load the SkyworkUnipic checkpoint (optional) plus the SigLIP2
        projector and MAR weights."""
        device = "cpu"

        def _print_load_result(module_name, missing, unexpected):
            print_log(f"[INFO] Loaded {module_name}. missing={len(missing)}, unexpected={len(unexpected)}")

        if skywork_unipic_ckpt:
            print_log(f"[INFO] Loading SkyworkUnipic checkpoint from: {skywork_unipic_ckpt}")

            if os.path.isfile(skywork_unipic_ckpt):
                skywork_unipic_state = torch.load(skywork_unipic_ckpt, map_location=device)
            else:
                # Sharded HF checkpoints ship an index file mapping keys to shards.
                idx = os.path.join(skywork_unipic_ckpt, "pytorch_model.bin.index.json")
                if os.path.exists(idx):
                    with open(idx, 'r') as f:
                        index = json.load(f)
                    skywork_unipic_state = {}
                    for shard in sorted(set(index["weight_map"].values())):
                        shard_path = os.path.join(skywork_unipic_ckpt, shard)
                        skywork_unipic_state.update(torch.load(shard_path, map_location=device))
                else:
                    bin_path = os.path.join(skywork_unipic_ckpt, "pytorch_model.bin")
                    skywork_unipic_state = torch.load(bin_path, map_location=device)

            # Keep only tensors whose key and shape both match the live model.
            model_dict = self.state_dict()
            filtered_checkpoint = {}
            shape_mismatch_keys = []

            for k, v in skywork_unipic_state.items():
                if k in model_dict:
                    if v.shape == model_dict[k].shape:
                        filtered_checkpoint[k] = v
                    else:
                        shape_mismatch_keys.append((k, v.shape, model_dict[k].shape))

            missing, unexpected = self.load_state_dict(filtered_checkpoint, strict=False)

            if shape_mismatch_keys:
                print_log("[WARNING] The following keys were skipped due to shape mismatch:")
                for k, checkpoint_shape, model_shape in shape_mismatch_keys:
                    print_log(f"  - {k}: checkpoint shape {checkpoint_shape}, "
                              f"model shape {model_shape}")
            else:
                print_log("[INFO] All key shapes match; no parameters were skipped")

            _print_load_result("SkyworkUnipic", missing, unexpected)
        else:
            print_log("[INFO] Skipping SkyworkUnipic checkpoint loading")

        if siglip_proj_path:
            print_log(f"[INFO] Loading SigLIP2 weights from: {siglip_proj_path}")
            siglip_state = torch.load(
                siglip_proj_path, map_location="cpu", weights_only=False
            )
            if isinstance(siglip_state, dict) and "model" in siglip_state:
                siglip_state = siglip_state["model"]
            missing, unexpected = self.siglip2_proj.load_state_dict(
                siglip_state, strict=False
            )
            _print_load_result("SigLIP2", missing, unexpected)
        else:
            print_log("[INFO] No SigLIP2 checkpoint provided, skipping")

        if mar_path:
            print_log(f"[INFO] Loading MAR weights from: {mar_path}")
            mar_state = torch.load(mar_path, map_location="cpu", weights_only=False)

            # Prefer the EMA weights when present, otherwise the raw model.
            if isinstance(mar_state, dict) and "model_ema" in mar_state:
                mar_state = mar_state["model_ema"]
            elif isinstance(mar_state, dict) and "model" in mar_state:
                mar_state = mar_state["model"]

            # Strip a leading "mar." prefix if the checkpoint was saved from
            # the composite model.
            if any(k.startswith("mar.") for k in mar_state):
                filtered_mar = {
                    k.replace("mar.", "", 1): v
                    for k, v in mar_state.items()
                    if k.startswith("mar.")
                }
            else:
                filtered_mar = mar_state

            missing, unexpected = self.mar.load_state_dict(
                filtered_mar, strict=False
            )
            _print_load_result("MAR", missing, unexpected)
        else:
            print_log("[INFO] No MAR checkpoint provided, skipping")
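
    @staticmethod
    def _demo_shape_filtered_load():
        """Minimal sketch of the shape-filtering pattern used above: keep a
        checkpoint tensor only when both its key and its shape match the live
        model, then load with strict=False. Toy module; illustrative only.
        """
        model = torch.nn.Linear(4, 2)
        ckpt = {
            "weight": torch.randn(2, 4),  # key and shape match -> kept
            "bias": torch.randn(3),       # shape mismatch      -> skipped
            "extra": torch.randn(1),      # unknown key         -> skipped
        }
        model_dict = model.state_dict()
        filtered = {k: v for k, v in ckpt.items()
                    if k in model_dict and v.shape == model_dict[k].shape}
        missing, unexpected = model.load_state_dict(filtered, strict=False)
        assert missing == ["bias"] and unexpected == []
        return missing, unexpected
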
    def gradient_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.mar.gradient_checkpointing_disable()

    def gradient_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.mar.gradient_checkpointing_enable()

    def state_dict(self, *args, **kwargs):
        # Exclude the (frozen) VAE weights from saved checkpoints.
        state_dict = super().state_dict(*args, **kwargs)
        state_dict = {k: v for k, v in state_dict.items()
                      if 'vae.' not in k}
        return state_dict

    def train(self: T, mode: bool = True) -> T:
        super().train(mode=mode)
        # The VAE stays in eval mode regardless of the requested mode.
        self.vae.train(mode=False)
        return self
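
    @staticmethod
    def _demo_frozen_submodule_train_mode():
        """Toy sketch of the train() override above: switching the parent to
        train mode while re-pinning one submodule to eval. Stand-in modules.
        """
        parent = torch.nn.Module()
        parent.vae = torch.nn.BatchNorm1d(4)  # stands in for the frozen VAE
        parent.train()
        parent.vae.train(mode=False)          # what the override re-asserts
        assert parent.training and not parent.vae.training
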
    def text2image_loss(self, data_dict):
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        x = self.encode(x)
        b, m, n, _ = x.shape
        gt_latents = x.clone().detach().view(b, m * n, -1)

        # Sample a random generation order per sample and mask a subset of
        # latent tokens accordingly.
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(x.flatten(1, 2), orders)

        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)
        x_enc = self.forward_mae_encoder(x, mask, input_ids=input_ids,
                                         attention_mask=attention_mask)
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n))

        loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss
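
    @staticmethod
    def _demo_order_based_masking():
        """Standalone sketch of the order-based masking that text2image_loss
        delegates to self.mar.sample_orders / random_masking (assumed
        semantics: sample a random permutation per sample and mask a prefix of
        it). Toy sizes; the real MAR draws the mask ratio stochastically.
        """
        b, seq_len, mask_len = 2, 16, 10
        orders = torch.argsort(torch.rand(b, seq_len), dim=-1)  # random permutations
        mask = torch.zeros(b, seq_len)
        # Mark the first `mask_len` positions of each order as masked.
        mask.scatter_(-1, orders[:, :mask_len], 1.0)
        assert mask.sum(dim=-1).eq(mask_len).all()
        return mask
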
    def image2text_loss(self, data_dict):
        input_ids = data_dict['input_ids'].to(self.device)
        attention_mask = data_dict['attention_mask'].to(self.device)
        labels = data_dict['labels'].to(self.device)

        pixel_values = data_dict.get('pixel_values', None)

        if pixel_values is None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            # Text-only batch: run a dummy visual forward so the visual
            # parameters still participate in the graph, with zero loss.
            _, z_null = self.extract_visual_feature(
                torch.zeros(1, 16, 16, self.token_embed_dim,
                            dtype=self.dtype, device=self.device)
            )
            loss_null = z_null.mean() * 0.0
            print_log("[INFO] No image found in this batch!")
        else:
            x = pixel_values.to(dtype=self.dtype, device=self.device)
            x = self.encode(x)
            _, z_enc = self.extract_visual_feature(x)

            if self.grad_scale is not None:
                z_enc = _ScaleGradient.apply(z_enc, self.grad_scale)

            inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.llm.config.hidden_size)

            # Resolve the <image> token id from the tokenizer (add_tokens is a
            # no-op if the token is already registered); this intentionally
            # overrides the imported IMAGE_TOKEN_INDEX.
            self.tokenizer.add_tokens(["<image>"], special_tokens=True)
            image_token_index = self.tokenizer.convert_tokens_to_ids("<image>")

            # Scatter visual tokens into the image positions and ordinary text
            # embeddings everywhere else.
            inputs_embeds[input_ids == image_token_index] = z_enc.flatten(0, 1)
            inputs_embeds[input_ids != image_token_index] = self.llm.get_input_embeddings()(
                input_ids[input_ids != image_token_index])
            loss_null = 0.0

        output = self.llm_model(inputs_embeds=inputs_embeds,
                                attention_mask=attention_mask,
                                return_dict=True)

        # Shift for next-token prediction and drop ignored positions.
        last_hidden_state = output.last_hidden_state[:, :-1]
        labels = labels[:, 1:]
        last_hidden_state = last_hidden_state[labels >= 0]
        labels = labels[labels >= 0]
        logits = self.llm.get_output_embeddings()(last_hidden_state)

        loss_i2t = F.cross_entropy(input=logits, target=labels)

        return loss_i2t + loss_null
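
    @staticmethod
    def _demo_image_token_scatter():
        """Standalone sketch of how inputs_embeds is assembled above: visual
        tokens are written at positions whose id equals the <image> token id,
        and text positions get ordinary embeddings. Toy sizes; token id 9 is
        hypothetical.
        """
        image_token_id, hidden = 9, 8
        input_ids = torch.tensor([[1, 9, 9, 2]])  # two image positions
        z_enc = torch.randn(1, 2, hidden)         # one feature per image position
        embed = torch.nn.Embedding(10, hidden)
        inputs_embeds = z_enc.new_zeros(*input_ids.shape, hidden)
        inputs_embeds[input_ids == image_token_id] = z_enc.flatten(0, 1)
        inputs_embeds[input_ids != image_token_id] = embed(
            input_ids[input_ids != image_token_id])
        return inputs_embeds
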
    def image_edit_loss_contrastive(self, data_dict):
        x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        assert len(x_src) >= len(x)
        # Encode source and target images in one VAE pass, then split.
        x_src, x = self.encode(torch.cat([x_src, x])).split([len(x_src), len(x)], dim=0)

        attention_mask = data_dict['attention_mask'].to(self.device)
        input_ids = data_dict['input_ids'].to(self.device)

        x_con, z_src = self.extract_visual_feature(x_src)
        if self.grad_scale is not None:
            z_src = _ScaleGradient.apply(z_src, self.grad_scale)
            x_con = _ScaleGradient.apply(x_con, self.grad_scale)

        inputs_embeds = z_src.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
        # Here input_ids are assumed to use the imported IMAGE_TOKEN_INDEX.
        inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_src.flatten(0, 1)
        inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
            input_ids[input_ids != IMAGE_TOKEN_INDEX]
        )

        # Reconstruction loss on the target image (same recipe as text2image).
        b, m, n, _ = x.shape
        gt_latents = x.clone().detach().view(b, m * n, -1)
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(x.flatten(1, 2), orders)
        x_enc = self.forward_mae_encoder(x, mask,
                                         inputs_embeds=inputs_embeds,
                                         attention_mask=attention_mask)
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), x_con=x_con)
        rec_loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)

        # Contrastive (InfoNCE) term: the batch is assumed interleaved as
        # (source_0, edited_0, source_1, edited_1, ...), so even rows are the
        # anchors and the following odd rows their positives.
        z_src_flat = z_src.mean(dim=1)  # pool visual tokens per image
        z_src_flat = F.normalize(z_src_flat, dim=-1)

        repeat_z = z_src_flat[::2]
        edit_z = z_src_flat[1::2]

        logits = torch.matmul(edit_z, repeat_z.T) / 0.07  # temperature 0.07
        labels = torch.arange(logits.size(0), device=logits.device)
        contrastive_loss = F.cross_entropy(logits, labels)

        return rec_loss + self.loss_weights.get("contrastive", 0.1) * contrastive_loss
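
    @staticmethod
    def _demo_interleaved_contrastive():
        """Toy version of the InfoNCE term above, assuming the interleaved
        (source, edited, source, edited, ...) batch layout. Illustrative only.
        """
        feats = F.normalize(torch.randn(6, 8), dim=-1)  # 3 (source, edited) pairs
        src, edited = feats[::2], feats[1::2]
        logits = edited @ src.T / 0.07                  # (3, 3) similarity matrix
        labels = torch.arange(logits.size(0))           # positives on the diagonal
        return F.cross_entropy(logits, labels)
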
    def image_edit_loss(self, data_dict):
        x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)

        # Encode source and target images in one VAE pass, then split.
        x_cat = torch.cat([x_src, x], dim=0)
        x_src, x = self.encode(x_cat).split([len(x_src), len(x)], dim=0)

        attention_mask = data_dict['attention_mask'].to(self.device)
        input_ids = data_dict['input_ids'].to(self.device)

        x_con, z_src = self.extract_visual_feature(x_src)
        if self.grad_scale is not None:
            z_src = _ScaleGradient.apply(z_src, self.grad_scale)
            x_con = _ScaleGradient.apply(x_con, self.grad_scale)

        inputs_embeds = z_src.new_zeros(*input_ids.shape, self.llm.config.hidden_size)

        # Resolve the <image> token id from the tokenizer (add_tokens is a
        # no-op if the token is already registered).
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        image_token_index = self.tokenizer.convert_tokens_to_ids("<image>")

        # Scatter source-image tokens into the image positions and ordinary
        # text embeddings everywhere else.
        inputs_embeds[input_ids == image_token_index] = z_src.flatten(0, 1)
        inputs_embeds[input_ids != image_token_index] = self.llm.get_input_embeddings()(
            input_ids[input_ids != image_token_index]
        )

        b, m, n, _ = x.shape
        gt_latents = x.clone().detach().view(b, m * n, -1)
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(x.flatten(1, 2), orders)
        x_enc = self.forward_mae_encoder(x, mask,
                                         inputs_embeds=inputs_embeds,
                                         attention_mask=attention_mask)
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), x_con=x_con)

        loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
        return loss

    def forward(self, data, data_samples=None, mode='loss'):
        if mode == 'loss':
            return self.compute_loss(data_dict=data)
        else:
            raise NotImplementedError

    def compute_loss(self, data_dict):
        losses = {}
        for data_type, batch_data in data_dict.items():
            if 'text2image' in data_type:
                loss = self.text2image_loss(batch_data)
            elif 'image2text' in data_type:
                loss = self.image2text_loss(batch_data)
            elif 'image_edit' in data_type:
                loss = self.image_edit_loss(batch_data)
            else:
                raise NotImplementedError(f"Unknown data_type: {data_type}")
            # Match the weight by substring so suffixed dataset names
            # (e.g. 'text2image_xxx') still pick up their configured weight.
            weight = next((w for k, w in self.loss_weights.items() if k in data_type), 1.0)
            losses[f'loss_{data_type}'] = loss * weight
        return losses
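

def _demo_weighted_multitask_losses():
    """Toy sketch of the compute_loss aggregation above: each dataset key is
    matched to a loss-weight key by substring. Keys here are hypothetical.
    """
    loss_weights = {"image2text": 0.01, "text2image": 1.0}
    raw = {"text2image_demo": torch.tensor(2.0), "image2text_demo": torch.tensor(3.0)}
    losses = {}
    for data_type, loss in raw.items():
        weight = next((w for k, w in loss_weights.items() if k in data_type), 1.0)
        losses[f"loss_{data_type}"] = loss * weight
    return losses  # {'loss_text2image_demo': 2.0, 'loss_image2text_demo': 0.03}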