| |
| |
| |
|
|
| import copy |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple, Union |
| import os |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.utils import logging |
| from configuration_bailingmm import BailingMMConfig |
| from modeling_utils import patch_continuous_features, build_modality_mask |
|
|
| |
| from funasr.models.sanm.encoder import SANMEncoder |
| from modeling_bailing_moe import BailingMoeForCausalLM |
| from modeling_utils import Transpose, encode_audio_segments |
|
|
| |
| from qwen2_5_vit import Qwen2_5_VisionTransformer |
|
|
| |
| from modeling_bailing_talker import BailingTalkerForConditionalGeneration |
|
|
| |
| from modeling_whisper_encoder import WhisperAudioEncoder |
|
|
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Name of the configuration class as a string; it is not referenced anywhere else in the
# visible part of this file.
_CONFIG_FOR_DOC = "BailingMMConfig"
|
|
|
|
@dataclass
class BailingMMCausalLMOutputWithPast(ModelOutput):
    """
    Output container for the BailingMM causal language model.

    Only the four fields declared below are carried by this class; attention weights and
    rope deltas are not part of it.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
class BailingMMNativeForConditionalGeneration(PreTrainedModel):
    """
    Multimodal conditional-generation model built around `BailingMoeForCausalLM`.

    Depending on the configuration, the constructor also creates a vision transformer, a SANM
    audio encoder, a Whisper audio encoder and a talker module, plus the projection layers that
    map encoder features to the language model hidden size. Image generation modules are
    attached separately by `load_image_gen_modules` (called from `from_pretrained`).
    """

    config_class = BailingMMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BailingAudioModel"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def __init__(
        self,
        config: BailingMMConfig,
    ):
        """Build the language model and every encoder/projector enabled in `config`."""
        super().__init__(config)
        self.config: BailingMMConfig = config
        self.vision = None
        self.audio = None
        self.whisper_encoder = None
        self.talker = None
        # Projectors default to None so a disabled modality is caught by the explicit checks in
        # the feature extractors instead of surfacing as an AttributeError.
        self.linear_proj = None
        self.linear_proj_audio = None
        self.linear_proj_whisper = None

        # NOTE: the attribute name is misspelled ("dytpe"); it is kept unchanged because code
        # outside this class may read it.
        self.llm_dytpe = torch.bfloat16

        if self.config.vision_config:
            self.vision = Qwen2_5_VisionTransformer(self.config.vision_config)

        if self.config.audio_config:
            self.audio = SANMEncoder(**self.config.audio_config.audio_encoder_config_sanm)

        if self.config.whisper_config:
            self.whisper_encoder = WhisperAudioEncoder(**self.config.whisper_config.whisper_encoder_config)

        self.model = BailingMoeForCausalLM(self.config.llm_config)
        hidden_size = self.model.config.hidden_size

        # The image projector reads the vision encoder's embedding size, so it can only be
        # built when a vision encoder exists.
        if self.vision is not None:
            mlp_modules_img = [nn.Linear(self.vision.image_emb_dim, hidden_size)]
            for _ in range(1, self.config.mlp_depth):
                mlp_modules_img.append(nn.GELU())
                mlp_modules_img.append(nn.Linear(hidden_size, hidden_size))
            self.linear_proj = nn.Sequential(*mlp_modules_img)

        if self.audio is not None:
            self.linear_proj_audio = self._build_audio_projector(
                self.config.audio_config.audio_encoder_output_size,
                self.config.audio_config,
            )

        if self.whisper_encoder is not None:
            self.linear_proj_whisper = self._build_audio_projector(
                self.whisper_encoder.audio_emb_dim,
                self.config.whisper_config,
            )

        if self.config.talker_config:
            self.config.talker_config._name_or_path = f'{self.config._name_or_path}/talker'
            self.talker = BailingTalkerForConditionalGeneration(self.config.talker_config)
        self.post_init()
        self.loaded_image_gen_modules = False

    def _build_audio_projector(self, in_channels, ds_config):
        """Build Conv1d -> Transpose -> (GELU, Linear) * (mlp_depth - 1) -> Transpose.

        `ds_config` supplies `ds_kernel_size` and `ds_stride` for the convolution.
        """
        hidden_size = self.model.config.hidden_size
        layers = [
            torch.nn.Conv1d(
                in_channels,
                hidden_size,
                kernel_size=ds_config.ds_kernel_size,
                stride=ds_config.ds_stride,
                padding=ds_config.ds_kernel_size // 2,
            ),
            Transpose(-1, -2),
        ]
        for _ in range(1, self.config.mlp_depth):
            layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_size, hidden_size))
        layers.append(Transpose(-1, -2))
        return nn.Sequential(*layers)

    def extract_image_feature(self, pixel_values, grid_thw):
        """Encode pixels with the vision tower, project them and L2-normalise the last dim."""
        assert self.vision is not None, "vision encoder is not configured"
        assert self.linear_proj is not None, "image projector is not configured"
        # NOTE(review): SOURCE has lost its indentation; the extent of the autocast region
        # (encoder + float cast + projection) is a reconstruction - confirm against upstream.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            image_embeds = self.vision(pixel_values, grid_thw=grid_thw)
            image_embeds = image_embeds.float()
            image_embeds = self.linear_proj(image_embeds)
        image_embeds = F.normalize(image_embeds, dim=-1)
        return image_embeds

    def extract_audio_feature(self, audio_feats, audio_feats_lengths, use_whisper_encoder=False):
        """Encode audio features with the SANM or Whisper encoder and its projector.

        Returns:
            `(audio_embeds, audio_embeds_lengths)`; the embeddings are cast back to the dtype
            of `audio_feats` and normalised along dim 2 when
            `config.audio_config.norm_query_embeds` is truthy.
        """
        if not use_whisper_encoder:
            assert self.audio is not None
            assert self.linear_proj_audio is not None
            encoder = self.audio
            proj_layer = self.linear_proj_audio
        else:
            assert self.whisper_encoder is not None
            assert self.linear_proj_whisper is not None
            encoder = self.whisper_encoder
            proj_layer = self.linear_proj_whisper
        audio_embeds, _, audio_embeds_lengths = encode_audio_segments(
            encoder=encoder,
            proj_layer=proj_layer,
            wav_feats=audio_feats,
            wav_feats_lengths=audio_feats_lengths,
            audio_config=self.config.audio_config,
            whisper_config=self.config.whisper_config,
            use_whisper_encoder=use_whisper_encoder,
        )
        if self.config.audio_config.norm_query_embeds:
            audio_embeds = F.normalize(audio_embeds, dim=2)
        return audio_embeds.to(audio_feats.dtype), audio_embeds_lengths

    def prompt_wrap_vision(self, input_ids, inputs_embeds, vision_embeds, image_token_id=None):
        """Scatter `vision_embeds` into the positions of the image patch token.

        Args:
            input_ids: token ids used to locate the patch-token positions.
            inputs_embeds: text embeddings to be patched.
            vision_embeds: features to insert; a 3-D tensor is flattened to 2-D first.
            image_token_id: optional override of `config.llm_config.image_patch_token`
                for this call only.

        Returns:
            `(inputs_embeds, image_router_mask)`, where the mask marks the patched positions.

        Raises:
            ValueError: if the number of patch tokens differs from the number of features.
        """
        if vision_embeds is None or input_ids is None:
            # NOTE(review): this path returns a single tensor while the normal path returns a
            # pair; kept unchanged for backward compatibility, callers that unpack must guard.
            return inputs_embeds

        if len(vision_embeds.shape) == 3:
            vision_embeds = vision_embeds.reshape(-1, vision_embeds.shape[-1])

        # Resolve the token id locally; writing the override back into the shared config
        # would silently change every later call.
        patch_token_id = (
            image_token_id if image_token_id is not None else self.config.llm_config.image_patch_token
        )
        token_positions = input_ids == patch_token_id
        n_image_tokens = token_positions.sum().item()
        n_image_features = vision_embeds.shape[0]

        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )

        image_router_mask = token_positions.to(inputs_embeds.device)
        image_mask = image_router_mask.unsqueeze(-1).expand_as(inputs_embeds)
        image_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        return inputs_embeds, image_router_mask

    def prompt_wrap_audio(self, input_ids, inputs_embeds, audio_embeds, audio_embeds_lengths, placeholder_audio_loc_lens):
        """Patch audio features into `inputs_embeds` and build the matching modality mask.

        `input_ids` is accepted for signature symmetry but is not used.
        """
        inputs_embeds = patch_continuous_features(
            input_embeddings=inputs_embeds,
            placeholder_loc_lens=placeholder_audio_loc_lens,
            encoded_feats=audio_embeds,
            encoded_feat_lens=audio_embeds_lengths,
        )
        audio_router_mask = build_modality_mask(
            placeholder_audio_loc_lens, inputs_embeds.shape[:-1]
        ).to(inputs_embeds.device)
        return inputs_embeds, audio_router_mask

    def prompt_wrap_navit(self, input_ids, query_embeds_image=None, query_embeds_video=None, query_embeds_audio=None,
                          query_embeds_audio_lengths=None, placeholder_audio_loc_lens=None, target_embeds=None):
        """Embed `input_ids` and insert image, video and audio features.

        Returns:
            `(inputs_embeds, image_mask, audio_mask)` when at least one of the optional
            embeddings is given.
        """
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if query_embeds_image is None and query_embeds_video is None and query_embeds_audio is None and target_embeds is None:
            # NOTE(review): single-tensor return kept for backward compatibility; the in-file
            # callers only reach this method when some embedding is present.
            return inputs_embeds

        image_mask = None
        audio_mask = None
        if query_embeds_image is not None:
            inputs_embeds, image_mask = self.prompt_wrap_vision(input_ids, inputs_embeds, query_embeds_image)
        if query_embeds_video is not None:
            # Videos are matched against the same patch token as images, and this call
            # overwrites the mask produced for images.
            inputs_embeds, image_mask = self.prompt_wrap_vision(input_ids, inputs_embeds, query_embeds_video)
        if query_embeds_audio is not None:
            inputs_embeds, audio_mask = self.prompt_wrap_audio(
                input_ids, inputs_embeds, query_embeds_audio, query_embeds_audio_lengths, placeholder_audio_loc_lens,
            )
        return inputs_embeds, image_mask, audio_mask

    def _embed_multimodal_inputs(self, input_ids, image_embeds, video_embeds, audio_embeds,
                                 audio_embeds_lengths, audio_placeholder_loc_lens):
        """Return `(embeddings, image_mask, audio_mask)` for `input_ids`.

        Token ids are clipped into the embedding table range. Plain token embeddings (with
        both masks `None`) are used when no modality features are given or when the
        sequence has length 1.
        """
        embed_tokens = self.model.get_input_embeddings()
        safe_ids = input_ids.clip(0, embed_tokens.weight.shape[0] - 1)
        text_only = image_embeds is None and video_embeds is None and audio_embeds is None
        if text_only or input_ids.size(1) == 1:
            return embed_tokens(safe_ids), None, None
        return self.prompt_wrap_navit(
            safe_ids, image_embeds, video_embeds, audio_embeds,
            audio_embeds_lengths, audio_placeholder_loc_lens, None,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        audio_feats: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        audio_feats_lengths: Optional[torch.LongTensor] = None,
        audio_placeholder_loc_lens: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_whisper_encoder: bool = False
    ) -> Union[Tuple, BailingMMCausalLMOutputWithPast]:
        """Run the language model on text plus optional image, video and audio inputs.

        Exactly one of `input_ids` / `inputs_embeds` must be given, and raw modality inputs
        cannot be combined with `inputs_embeds`.

        Returns:
            A `BailingMMCausalLMOutputWithPast`, or its tuple form when `return_dict` is False.

        Raises:
            ValueError: on an invalid combination of inputs.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if (pixel_values is not None or pixel_values_videos is not None or audio_feats is not None) and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values/pixel_values_videos/audio_feats and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is not None:
            # Caller-provided embeddings are used as-is; previously they were ignored and
            # the missing input_ids were dereferenced.
            words_embeddings, image_mask, audio_mask = inputs_embeds, None, None
        else:
            image_embeds, video_embeds, audio_embeds, audio_embeds_lengths = None, None, None, None
            if pixel_values is not None:
                image_embeds = self.extract_image_feature(pixel_values, grid_thw=image_grid_thw)
            if pixel_values_videos is not None:
                video_embeds = self.extract_image_feature(pixel_values_videos, grid_thw=video_grid_thw)
            if audio_feats is not None:
                audio_embeds, audio_embeds_lengths = self.extract_audio_feature(
                    audio_feats, audio_feats_lengths, use_whisper_encoder=use_whisper_encoder
                )
            words_embeddings, image_mask, audio_mask = self._embed_multimodal_inputs(
                input_ids, image_embeds, video_embeds, audio_embeds,
                audio_embeds_lengths, audio_placeholder_loc_lens,
            )

        # Always request a ModelOutput from the language model: its fields are read by name
        # below, which would fail on a plain tuple.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=words_embeddings,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            image_mask=image_mask,
            audio_mask=audio_mask,
        )

        result = BailingMMCausalLMOutputWithPast(
            loss=outputs.loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
        )
        return result if return_dict else result.to_tuple()

    def append_input_ids_with_multiscale_learnable_tokens(
        self,
        text_ids,
        attention_mask,
        scales,
        start_token_id,
        end_token_id,
        patch_token_id,
    ):
        """Append one `start, patch * scale**2, end` segment per scale to a single sequence.

        Returns:
            `(text_ids, attention_mask, gen_mask)`. The attention mask is extended with
            ones; `gen_mask` is 1 exactly on the appended patch tokens and 0 elsewhere.
        """
        assert text_ids.shape[0] == 1
        assert attention_mask.shape == text_ids.shape
        gen_mask = torch.zeros_like(attention_mask)

        id_parts, attn_parts, gen_parts = [text_ids], [attention_mask], [gen_mask]
        for scale in scales:
            n_patches = scale ** 2
            id_parts.append(torch.tensor(
                [[start_token_id] + [patch_token_id] * n_patches + [end_token_id]],
                dtype=text_ids.dtype, device=text_ids.device,
            ))
            attn_parts.append(torch.ones(
                (1, n_patches + 2), dtype=attention_mask.dtype, device=attention_mask.device,
            ))
            gen_parts.append(torch.tensor(
                [[0] + [1] * n_patches + [0]],
                dtype=gen_mask.dtype, device=gen_mask.device,
            ))

        text_ids = torch.cat(id_parts, dim=1)
        attention_mask = torch.cat(attn_parts, dim=1)
        gen_mask = torch.cat(gen_parts, dim=1)
        assert text_ids.shape == attention_mask.shape
        assert attention_mask.shape == gen_mask.shape
        return text_ids, attention_mask, gen_mask

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        audio_feats: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        audio_feats_lengths: Optional[torch.LongTensor] = None,
        audio_placeholder_loc_lens: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        image_gen: Optional[bool] = False,
        image_gen_steps: Optional[int] = 30,
        image_gen_seed: Optional[int] = 0,
        image_gen_cfg: Optional[float] = 3.5,
        image_gen_height: Optional[int] = 512,
        image_gen_width: Optional[int] = 512,
        **generate_kwargs,
    ):
        """Generate text, or an image when `image_gen` is true.

        `use_whisper_encoder` may be passed through `generate_kwargs`; every other extra
        keyword is forwarded to the language model's `generate`. In image mode the result
        of `self.diffusion_loss.sample` is returned and `audio_feats` is not used.
        """
        # Consume the wrapper-only flag unconditionally so it is never forwarded to the
        # language model's generate().
        use_whisper_encoder = generate_kwargs.pop('use_whisper_encoder', False)

        image_embeds, video_embeds, audio_embeds, audio_embeds_lengths = None, None, None, None
        if pixel_values is not None:
            image_embeds = self.extract_image_feature(pixel_values, grid_thw=image_grid_thw)
        if pixel_values_videos is not None:
            video_embeds = self.extract_image_feature(pixel_values_videos, grid_thw=video_grid_thw)

        if image_gen:
            return self._generate_image(
                input_ids, attention_mask, position_ids, use_cache,
                image_embeds, video_embeds, audio_placeholder_loc_lens,
                image_gen_steps, image_gen_seed, image_gen_cfg,
                image_gen_height, image_gen_width,
            )

        # NOTE(review): SOURCE has lost its indentation; running the language model's
        # generate() inside the autocast region is a reconstruction - confirm upstream.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            if audio_feats is not None:
                audio_embeds, audio_embeds_lengths = self.extract_audio_feature(
                    audio_feats, audio_feats_lengths, use_whisper_encoder=use_whisper_encoder
                )
            words_embeddings, image_mask, audio_mask = self._embed_multimodal_inputs(
                input_ids, image_embeds, video_embeds, audio_embeds,
                audio_embeds_lengths, audio_placeholder_loc_lens,
            )
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=words_embeddings,
                use_cache=use_cache,
                image_mask=image_mask,
                audio_mask=audio_mask,
                **generate_kwargs,
            )
        return outputs

    def _generate_image(self, input_ids, attention_mask, position_ids, use_cache,
                        image_embeds, video_embeds, audio_placeholder_loc_lens,
                        steps, seed, cfg, height, width):
        """Image branch of `generate`.

        Appends the learnable query tokens, runs the language model once, passes the hidden
        states of the last scale through `proj_in` -> `connector` -> `proj_out`, and samples
        from `self.diffusion_loss`.
        """
        assert self.loaded_image_gen_modules is True, "image generation modules are not loaded"
        patch_token_id = self.config.llm_config.image_patch_token
        # Use the same scale list as the query-token embeddings below so both stay aligned.
        input_ids, attention_mask, gen_mask = self.append_input_ids_with_multiscale_learnable_tokens(
            input_ids,
            attention_mask,
            self.img_gen_scales,
            patch_token_id + 1,
            patch_token_id + 2,
            patch_token_id,
        )
        query_tokens_embeds = torch.cat(
            [self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.img_gen_scales],
            dim=0,
        )
        if image_embeds is None:
            image_embeds = query_tokens_embeds
        else:
            image_embeds = torch.cat([image_embeds, query_tokens_embeds], dim=0)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            assert video_embeds is None, "video input is not supported together with image_gen"
            words_embeddings, image_mask, audio_mask = self._embed_multimodal_inputs(
                input_ids, image_embeds, None, None, None, audio_placeholder_loc_lens,
            )
            outputs = self.model.forward(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=words_embeddings,
                use_cache=use_cache,
                image_mask=image_mask,
                audio_mask=audio_mask,
                output_hidden_states=True,
            )
            hidden_states = outputs.hidden_states[-1]
            select_mask = (
                gen_mask.unsqueeze(-1)
                .expand(gen_mask.shape[0], gen_mask.shape[1], hidden_states.shape[-1])
                .to(hidden_states.device)
                .bool()
            )
            # Keep only the hidden states at the appended query-token positions.
            hidden_states_gen = torch.masked_select(hidden_states, select_mask).view(
                hidden_states.shape[0], -1, hidden_states.shape[-1]
            )

        scale_ends = self.scale_indices
        assert scale_ends[-1] == hidden_states_gen.shape[1]
        # Only the last (largest-index) scale is used for sampling.
        last_start = ([0] + scale_ends[:-1])[-1]
        last_end = scale_ends[-1]
        scale_hidden = hidden_states_gen[:, last_start:last_end, :]

        scale_embeds = self.proj_in(scale_hidden)
        seq_shape = scale_embeds.shape
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            scale_embeds = self.connector(
                inputs_embeds=scale_embeds,
                attention_mask=torch.ones(seq_shape[0], 1, seq_shape[1], seq_shape[1]).to(scale_embeds.device),
                output_hidden_states=True,
            ).hidden_states[-1]
        scale_embeds = self.proj_out(scale_embeds)
        scale_embeds = torch.nn.functional.normalize(scale_embeds, dim=-1)

        return self.diffusion_loss.sample(
            scale_embeds,
            steps=steps,
            seed=seed,
            cfg=cfg,
            height=height,
            width=width,
        )

    def load_image_gen_modules(self, inference_model_path):
        """Load query tokens, the diffusion module, the connector and the in/out projections.

        `inference_model_path` is used as a local directory when it exists; otherwise it is
        treated as a Hugging Face Hub repo id for `mlp/model.safetensors`. Sets
        `loaded_image_gen_modules` to True on success.
        """
        from transformers import AutoModelForCausalLM
        from diffusion.sana_loss import SANALoss
        from safetensors.torch import load_file

        if os.path.exists(inference_model_path):
            safetensors_path = os.path.join(inference_model_path, 'mlp', 'model.safetensors')
        else:
            from huggingface_hub import hf_hub_download
            safetensors_path = hf_hub_download(
                repo_id=inference_model_path,
                filename="model.safetensors",
                subfolder="mlp",
            )
        temp_state_dict = load_file(safetensors_path)

        hidden_size = self.model.config.hidden_size
        self.query_tokens_dict = nn.ParameterDict()
        self.img_gen_scales = [4, 8, 16]
        for scale in self.img_gen_scales:
            # Placeholder values; they are overwritten by load_state_dict below.
            self.query_tokens_dict[f"{scale}x{scale}"] = nn.Parameter(
                torch.nn.functional.normalize(torch.randn(scale * scale, hidden_size), dim=-1)
            )
        self.query_tokens_dict.to(self.model.dtype).to(self.model.device)
        self.query_tokens_dict.load_state_dict(
            {
                f"{scale}x{scale}": temp_state_dict[f"query_tokens_dict.{scale}x{scale}"]
                for scale in self.img_gen_scales
            },
            strict=True,
        )

        # Cumulative end offset of each scale inside the concatenated query tokens.
        self.scale_indices = []
        running_total = 0
        for scale in self.img_gen_scales:
            running_total += scale * scale
            self.scale_indices.append(running_total)

        diffusion_mlp_state_dict = {
            key[len("mlp."):]: tensor
            for key, tensor in temp_state_dict.items()
            if key.startswith("mlp.")
        }
        self.diffusion_loss = SANALoss(
            model_path=inference_model_path,
            scheduler_path=inference_model_path,
            vision_dim=hidden_size,
            mlp_state_dict=diffusion_mlp_state_dict,
            trainable_params="None",
        )
        self.diffusion_loss.to(self.model.device)

        self.connector = AutoModelForCausalLM.from_pretrained(inference_model_path, subfolder='connector')
        # The connector attends bidirectionally over the query tokens.
        for layer in self.connector.model.layers:
            layer.self_attn.is_causal = False
        self.connector.to(self.model.device)

        self.proj_in = nn.Linear(hidden_size, self.connector.config.hidden_size)
        self.proj_out = nn.Linear(self.connector.config.hidden_size, hidden_size)
        self.proj_in.load_state_dict(
            {'weight': temp_state_dict['proj_in.weight'], 'bias': temp_state_dict['proj_in.bias']},
            strict=True,
        )
        self.proj_out.load_state_dict(
            {'weight': temp_state_dict['proj_out.weight'], 'bias': temp_state_dict['proj_out.bias']},
            strict=True,
        )
        self.proj_in.to(self.model.device)
        self.proj_out.to(self.model.device)
        self.loaded_image_gen_modules = True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs,
    ):
        """Load the model, then attach the image generation modules from the same location.

        When `pretrained_model_name_or_path` is None there is no location to load the image
        generation modules from, so that step is skipped and `loaded_image_gen_modules`
        stays False.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )
        if pretrained_model_name_or_path is not None:
            model.load_image_gen_modules(pretrained_model_name_or_path)
        return model