Spaces:
Sleeping
Sleeping
| from functools import partial | |
| import logging | |
| import re | |
| from typing import Optional, Tuple, Union, List | |
| from einops import rearrange | |
| from timm.layers import LayerNorm, LayerNorm2d | |
| from timm.layers.pos_embed import resample_abs_pos_embed | |
| from timm.models.regnet import RegStage | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from transformers import LlamaForCausalLM | |
| from transformers.modeling_outputs import BaseModelOutput | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.models.auto import AutoModelForCausalLM | |
| from transformers.models.qwen2_vl.configuration_qwen2_vl import ( | |
| Qwen2VLVisionConfig, | |
| ) | |
| from transformers.models.qwen2_vl.modeling_qwen2_vl import ( | |
| PatchEmbed, | |
| Qwen2VLPreTrainedModel, | |
| Qwen2VisionTransformerPretrainedModel, | |
| Qwen2VLVisionBlock, | |
| VisionRotaryEmbedding | |
| ) | |
| from configuration import KananaVVisualProjectorConfig, KananaVConfig | |
| logger = logging.getLogger("kanana-1.5-v") | |
| def build_pos_embeds( | |
| config: KananaVVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int | |
| ): | |
| # pos emb | |
| if config.pos_emb: | |
| # ✨ 수정: num_input_tokens가 음수일 때 기본값 사용 | |
| if num_input_tokens <= 0: | |
| num_input_tokens = config.pos_emb_size if hasattr(config, 'pos_emb_size') else 576 | |
| pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size)) | |
| nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02) | |
| else: | |
| pos_emb = None | |
| return pos_emb | |
| def build_eos_tokens(config: KananaVVisualProjectorConfig, output_hidden_size: int): | |
| # think tokens | |
| num_eos_tokens = config.num_eos_tokens | |
| if num_eos_tokens: | |
| eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size)) | |
| nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range) | |
| else: | |
| eos_tokens = None | |
| return eos_tokens | |
| def build_prenorm(config: KananaVVisualProjectorConfig): | |
| if getattr(config, "prenorm", False): | |
| prenorm = LayerNorm(config.encoder_hidden_size) | |
| else: | |
| prenorm = None | |
| return prenorm | |
| def build_mlp(depth: int, hidden_size: int, output_hidden_size: int): | |
| layers = [nn.Linear(hidden_size, output_hidden_size)] | |
| for _ in range(1, depth): | |
| layers.append(nn.SiLU()) | |
| layers.append(nn.Linear(output_hidden_size, output_hidden_size)) | |
| return nn.Sequential(*layers) | |
| class PatchMerge(nn.Module): | |
| def __init__(self, merge_size): | |
| super().__init__() | |
| self.merge_size = merge_size | |
| def forward(self, x, channel_last=False): | |
| if channel_last: | |
| x = rearrange(x, "B H W D -> B D H W") | |
| _, D, H, W = x.shape | |
| # 홀수 차원을 처리하기 위해 패딩 추가 | |
| pad_h = (self.merge_size - H % self.merge_size) % self.merge_size | |
| pad_w = (self.merge_size - W % self.merge_size) % self.merge_size | |
| if pad_h > 0 or pad_w > 0: | |
| print(f"🔍 PatchMerge - 패딩 추가: H={H}->{H+pad_h}, W={W}->{W+pad_w}") | |
| x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate') | |
| H, W = H + pad_h, W + pad_w | |
| merged_x = rearrange( | |
| x, "B D (H h2) (W w2) -> B (D h2 w2) H W", h2=self.merge_size, w2=self.merge_size | |
| ) | |
| return merged_x | |
| class DynamicCAbstractor(nn.Module): | |
| """Dynamic C-Abstractor based on RegBlock""" | |
| def __init__(self, config: KananaVVisualProjectorConfig, num_input_tokens: int): | |
| super().__init__() | |
| self.config = config | |
| # ✨ 수정: num_input_tokens가 음수일 때 기본값 설정 | |
| if num_input_tokens <= 0: | |
| num_input_tokens = config.pos_emb_size if hasattr(config, 'pos_emb_size') else 576 | |
| self.num_input_tokens = num_input_tokens | |
| # ✨ 추가: 누락된 속성들 설정 | |
| self.merge_size = getattr(config, 'merge_size', 2) | |
| self.pos_emb_size = getattr(config, 'pos_emb_size', 576) | |
| # ✨ 최적화: 모든 레이어를 bfloat16으로 초기화 | |
| self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size) | |
| if self.pos_emb is not None: | |
| self.pos_emb.data = self.pos_emb.data.to(torch.bfloat16) | |
| self.eos_tokens = build_eos_tokens(config, config.output_hidden_size) | |
| if self.eos_tokens is not None: | |
| self.eos_tokens.data = self.eos_tokens.data.to(torch.bfloat16) | |
| self.prenorm = build_prenorm(config) | |
| if self.prenorm is not None: | |
| self.prenorm = self.prenorm.to(torch.bfloat16) | |
| # ✨ 수정: build_net에서 self.net과 self.readout 설정 | |
| self.build_net() | |
| # ✨ 최적화: net 레이어들을 bfloat16으로 변환 | |
| if hasattr(self, 'net'): | |
| if isinstance(self.net, nn.ModuleList): | |
| for layer in self.net: | |
| layer = layer.to(torch.bfloat16) | |
| for module in layer.modules(): | |
| if hasattr(module, 'weight'): | |
| module.weight.data = module.weight.data.to(torch.bfloat16) | |
| if hasattr(module, 'bias') and module.bias is not None: | |
| module.bias.data = module.bias.data.to(torch.bfloat16) | |
| else: | |
| # self.net이 단일 모듈인 경우 | |
| self.net = self.net.to(torch.bfloat16) | |
| for module in self.net.modules(): | |
| if hasattr(module, 'weight'): | |
| module.weight.data = module.weight.data.to(torch.bfloat16) | |
| if hasattr(module, 'bias') and module.bias is not None: | |
| module.bias.data = module.bias.data.to(torch.bfloat16) | |
| # ✨ 최적화: readout 레이어를 bfloat16으로 변환 | |
| if hasattr(self, 'readout'): | |
| self.readout = self.readout.to(torch.bfloat16) | |
| for module in self.readout.modules(): | |
| if hasattr(module, 'weight'): | |
| module.weight.data = module.weight.data.to(torch.bfloat16) | |
| if hasattr(module, 'bias') and module.bias is not None: | |
| module.bias.data = module.bias.data.to(torch.bfloat16) | |
| def build_net(self): | |
| encoder_hidden_size = self.config.encoder_hidden_size | |
| hidden_size = self.config.hidden_size | |
| output_hidden_size = self.config.output_hidden_size | |
| depth = self.config.depth | |
| mlp_depth = self.config.mlp_depth | |
| RegBlock = partial( | |
| RegStage, | |
| stride=1, | |
| dilation=1, | |
| act_layer=nn.SiLU, | |
| norm_layer=LayerNorm2d, | |
| ) | |
| s1 = RegBlock( | |
| depth, | |
| encoder_hidden_size, | |
| hidden_size, | |
| ) | |
| sampler = PatchMerge(merge_size=self.merge_size) | |
| s2 = RegBlock( | |
| depth, | |
| self.merge_size**2 * hidden_size, | |
| hidden_size, | |
| ) | |
| if depth: | |
| self.net = nn.ModuleList([s1, sampler, s2]) | |
| self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size) | |
| else: | |
| self.net = sampler | |
| self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size) | |
| def forward(self, flattened_visual_embeds, grid_thw, **unused_kwargs): | |
| n_token_loc = torch.prod(grid_thw, dim=1) | |
| split_visual_embeds = torch.split(flattened_visual_embeds, n_token_loc.tolist()) | |
| flattened_visual_embeds = [] | |
| for _visual_embeds, _grid_thw in zip(split_visual_embeds, grid_thw): | |
| T, H, W = _grid_thw | |
| assert T == 1, "T must be 1. Video is not supported yet." | |
| reshaped_visual_embeds = rearrange( | |
| _visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W | |
| ) | |
| # remove temporal dim | |
| reshaped_visual_embeds = reshaped_visual_embeds[:, 0] | |
| if self.prenorm is not None: | |
| reshaped_visual_embeds = self.prenorm(reshaped_visual_embeds) | |
| if self.pos_emb is not None: | |
| # interpolate pos emb and add to visual embeds | |
| print(f"🔍 abstractor - pos_emb 형태: {self.pos_emb.shape}") | |
| print(f"🔍 abstractor - reshaped_visual_embeds 형태: {reshaped_visual_embeds.shape}") | |
| _local_pos_emb = resample_abs_pos_embed( | |
| posemb=self.pos_emb, | |
| old_size=tuple([int(self.pos_emb_size**0.5)] * 2), | |
| new_size=(H, W), | |
| num_prefix_tokens=0, | |
| ) | |
| _local_pos_emb = rearrange( | |
| _local_pos_emb, | |
| "1 (h w) d -> 1 h w d", | |
| h=H, | |
| w=W, | |
| ) | |
| print(f"🔍 abstractor - _local_pos_emb 형태: {_local_pos_emb.shape}") | |
| # 차원이 맞지 않는 경우 처리 | |
| if reshaped_visual_embeds.shape[-1] != _local_pos_emb.shape[-1]: | |
| print(f"🔍 abstractor - 차원 불일치 감지, pos_emb 건너뛰기") | |
| # pos_emb를 건너뛰고 visual_embeds만 사용 | |
| else: | |
| reshaped_visual_embeds = reshaped_visual_embeds + _local_pos_emb | |
| reshaped_visual_embeds = self._forward( | |
| reshaped_visual_embeds, | |
| input_size=(H, W), | |
| ) | |
| flattened_visual_embeds.append(reshaped_visual_embeds) | |
| reshaped_visual_embeds = torch.cat(flattened_visual_embeds, dim=0) | |
| output = BaseModelOutput(last_hidden_state=reshaped_visual_embeds) | |
| return output | |
| def _forward(self, x, input_size): | |
| h, w = input_size | |
| x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w) | |
| # 입력 채널 수가 맞지 않는 경우 처리 | |
| # RegStage의 첫 번째 블록에서 채널 수 확인 | |
| try: | |
| if hasattr(self.net[0], 'conv'): | |
| expected_channels = self.net[0].conv.in_channels | |
| elif hasattr(self.net[0], 'blocks') and len(self.net[0].blocks) > 0: | |
| expected_channels = self.net[0].blocks[0].conv1.in_channels | |
| else: | |
| # 기본값 사용 | |
| expected_channels = 1280 | |
| except: | |
| expected_channels = 1280 | |
| actual_channels = x.shape[1] | |
| if actual_channels != expected_channels: | |
| # 선형 변환으로 채널 수 조정 | |
| if not hasattr(self, 'channel_adapter'): | |
| # channel_adapter를 bfloat16으로 생성 | |
| self.channel_adapter = nn.Linear(actual_channels, expected_channels, dtype=torch.bfloat16).to(x.device) | |
| x = x.permute(0, 2, 3, 1) # (1, d, h, w) -> (1, h, w, d) | |
| # 입력을 bfloat16으로 변환 (한 번만) | |
| if x.dtype != torch.bfloat16: | |
| x = x.to(torch.bfloat16) | |
| x = self.channel_adapter(x) # 채널 수 조정 | |
| x = x.permute(0, 3, 1, 2) # (1, h, w, d) -> (1, d, h, w) | |
| # ✨ 최적화: 이미 bfloat16으로 초기화된 레이어들 사용 | |
| x = self.net[0](x) | |
| x = self.net[1](x) | |
| x = self.net[2](x) | |
| x = rearrange(x, "1 d h w -> (h w) d") | |
| # ✨ 최적화: 이미 bfloat16으로 초기화된 readout 사용 | |
| x = self.readout(x) | |
| return x | |
| class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel): | |
| config_class = Qwen2VLVisionConfig | |
| _no_split_modules = ["Qwen2VLVisionBlock"] | |
| def __init__(self, config) -> None: | |
| Qwen2VLPreTrainedModel.__init__(self, config) | |
| self.spatial_merge_size = config.spatial_merge_size | |
| self.gradient_checkpointing = False | |
| self.patch_embed = PatchEmbed( | |
| patch_size=config.patch_size, | |
| temporal_patch_size=config.temporal_patch_size, | |
| in_channels=config.in_channels, | |
| embed_dim=config.embed_dim, | |
| ) | |
| head_dim = config.embed_dim // config.num_heads | |
| self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) | |
| self.blocks = nn.ModuleList( | |
| [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] | |
| ) | |
| def forward( | |
| self, | |
| pixel_values: torch.Tensor, | |
| grid_thw: torch.Tensor, | |
| output_hidden_states: Optional[bool] = None, | |
| return_dict: Optional[bool] = None, | |
| ) -> Union[Tuple, BaseModelOutput]: | |
| assert return_dict, "Only return_dict=True is supported." | |
| encoder_states = () if output_hidden_states else None | |
| hidden_states = self.patch_embed(pixel_values) | |
| rotary_pos_emb = self.rot_pos_emb(grid_thw) | |
| emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) | |
| position_embeddings = emb.cos(), emb.sin() | |
| cu_seqlens = torch.repeat_interleave( | |
| grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] | |
| ).cumsum(dim=0, dtype=torch.int32) | |
| cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) | |
| for blk in self.blocks: | |
| if output_hidden_states: | |
| encoder_states = encoder_states + (hidden_states,) | |
| if self.gradient_checkpointing and self.training: | |
| layer_outputs = torch.utils.checkpoint.checkpoint( | |
| blk.__call__, | |
| hidden_states=hidden_states, | |
| cu_seqlens=cu_seqlens, | |
| position_embeddings=position_embeddings, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| layer_outputs = blk( | |
| hidden_states=hidden_states, | |
| cu_seqlens=cu_seqlens, | |
| position_embeddings=position_embeddings, | |
| ) | |
| hidden_states = layer_outputs | |
| if output_hidden_states: | |
| encoder_states = encoder_states + (hidden_states,) | |
| if not return_dict: | |
| return tuple(v for v in [hidden_states, encoder_states] if v is not None) | |
| return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states) | |
| def get_num_tokens(self): | |
| return -1 | |
| class KananaVPreTrainedModel(PreTrainedModel): | |
| """ | |
| An abstract class to handle weights initialization and | |
| a simple interface for downloading and loading pretrained models. | |
| """ | |
| config_class = KananaVConfig | |
| base_model_prefix = "kanana-1.5-v" | |
| supports_gradient_checkpointing = True | |
| _skip_keys_device_placement = "past_key_values" | |
| _supports_flash_attn_2 = True | |
| _supports_sdpa = True | |
| _supports_cache_class = True | |
| _supports_static_cache = False | |
| _keys_to_ignore_on_load_missing = [ | |
| r"position_ids", | |
| r"language_model.encoder.embed_tokens.weight", | |
| r"language_model.decoder.embed_tokens.weight", | |
| r"language_model.lm_head.weight", | |
| ] | |
| _no_split_modules = [ | |
| "CustomQwen2VLVE", | |
| "DynamicCAbstractor", | |
| "LlamaForCausalLM", | |
| "Parameter", | |
| ] | |
| def _init_weights(self, module): | |
| """Initialize the weights""" | |
| if ( | |
| isinstance(module, nn.Conv2d) | |
| or isinstance(module, nn.Embedding) | |
| or isinstance(module, nn.Linear) | |
| ): | |
| module.weight.data.normal_(mean=0.0, std=0.02) | |
| if hasattr(module, "bias") and module.bias is not None: | |
| module.bias.data.zero_() | |
| elif isinstance(module, nn.LayerNorm): | |
| module.bias.data.zero_() | |
| module.weight.data.fill_(1.0) | |
| elif isinstance(module, nn.Parameter): | |
| raise ValueError() | |
| class KananaVForConditionalGeneration(KananaVPreTrainedModel): | |
| config_class = KananaVConfig | |
| def __init__(self, config: KananaVConfig): | |
| super().__init__(config) | |
| logger.info("Build vision model ...") | |
| self.vision_model = CustomQwen2VLVE._from_config(config.vision_config) | |
| logger.info("Build projector ...") | |
| self.abstractor = DynamicCAbstractor(config.projector_config, | |
| num_input_tokens=self.vision_model.get_num_tokens()) | |
| logger.info("Build language model ...") | |
| self.language_model = LlamaForCausalLM._from_config(config=config.text_config) | |
| self.post_init() | |
| def forward_vision(self, pixel_values: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None): | |
| # ✨ 핵심 수정: pixel_values가 리스트일 경우와 텐서일 경우를 모두 처리 | |
| if isinstance(pixel_values, list): | |
| # 다중 이미지: 각 이미지를 처리하여 결과를 합침 | |
| visual_features_list = [] | |
| for i, pv in enumerate(pixel_values): | |
| single_image_metas = {k: v[i] for k, v in image_metas.items()} | |
| # grid_thw 처리 수정 | |
| vision_grid_thw = single_image_metas["vision_grid_thw"] | |
| if isinstance(vision_grid_thw, (list, tuple)): | |
| # 튜플을 리스트로 변환하여 텐서 생성 | |
| grid_thw = torch.tensor([list(vision_grid_thw)]).to(pv.device) | |
| else: | |
| grid_thw = torch.tensor([vision_grid_thw]).to(pv.device) | |
| # ✨ 최적화: 불필요한 dtype 변환 제거 | |
| v_outputs = self.vision_model( | |
| pixel_values=pv.unsqueeze(0), | |
| grid_thw=grid_thw, | |
| return_dict=True, output_hidden_states=True | |
| ) | |
| layer_index = self.config.projector_config.feature_layer_index | |
| visual_features_list.append(self._get_visual_feature_at(v_outputs.hidden_states, layer_index)) | |
| return visual_features_list # 리스트 형태로 반환 | |
| else: | |
| # 단일 이미지 - 이미 분리된 특징 텐서 | |
| # grid_thw가 리스트인 경우 첫 번째 요소 사용 | |
| grid_thw = image_metas["vision_grid_thw"] | |
| if isinstance(grid_thw, list): | |
| grid_thw = grid_thw[0] | |
| # grid_thw를 텐서로 변환 | |
| if not isinstance(grid_thw, torch.Tensor): | |
| if isinstance(grid_thw, (list, tuple)): | |
| # 튜플을 리스트로 변환하여 텐서 생성 | |
| grid_thw = torch.tensor([list(grid_thw)]) | |
| else: | |
| grid_thw = torch.tensor([grid_thw]) | |
| # 디바이스 정보 추가 | |
| if hasattr(pixel_values, 'device'): | |
| grid_thw = grid_thw.to(pixel_values.device) | |
| # pixel_values가 2D 특징 텐서인 경우 vision_model을 통해 처리 | |
| if len(pixel_values.shape) == 2: | |
| # 2D 특징 텐서를 vision_model이 처리할 수 있는 형태로 변환 | |
| # 다중 이미지와 동일한 방식으로 처리하되, 올바른 형태로 변환 | |
| # pixel_values를 (1, 3, H, W) 형태로 재구성 | |
| # 다중 이미지에서 사용하는 방식과 동일하게 처리 | |
| if len(pixel_values.shape) == 2: | |
| # 2D 텐서를 vision_model이 처리할 수 있는 형태로 변환 | |
| # 다중 이미지에서는 이미 올바른 형태로 전달되므로 동일하게 처리 | |
| # ✨ 최적화: 불필요한 dtype 변환 제거 | |
| v_outputs = self.vision_model( | |
| pixel_values=pixel_values, | |
| grid_thw=grid_thw, | |
| return_dict=True, output_hidden_states=True | |
| ) | |
| layer_index = self.config.projector_config.feature_layer_index | |
| return self._get_visual_feature_at(v_outputs.hidden_states, layer_index) | |
| else: | |
| return pixel_values | |
| # ✨ 최적화: 불필요한 dtype 변환 제거 | |
| v_outputs = self.vision_model( | |
| pixel_values=pixel_values, | |
| grid_thw=grid_thw, | |
| return_dict=True, output_hidden_states=True | |
| ) | |
| layer_index = self.config.projector_config.feature_layer_index | |
| return self._get_visual_feature_at(v_outputs.hidden_states, layer_index) | |
| def forward_projector(self, visual_features: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None): | |
| print(f"🔍 forward_projector - visual_features 형태: {visual_features.shape if hasattr(visual_features, 'shape') else type(visual_features)}") | |
| # ✨ 핵심 수정: visual_features가 리스트일 경우 처리 | |
| if isinstance(visual_features, list): | |
| print(f"🔍 forward_projector - 리스트 형태 처리") | |
| visual_embeds_list = [] | |
| for i, vf in enumerate(visual_features): | |
| single_image_metas = {k: v[i] for k, v in image_metas.items()} | |
| # grid_thw 처리 수정 | |
| vision_grid_thw = single_image_metas["vision_grid_thw"] | |
| if isinstance(vision_grid_thw, (list, tuple)): | |
| # 튜플을 리스트로 변환하여 텐서 생성 | |
| grid_thw = torch.tensor([list(vision_grid_thw)]).to(vf.device) | |
| else: | |
| grid_thw = torch.tensor([vision_grid_thw]).to(vf.device) | |
| print(f"🔍 forward_projector - 이미지 {i} 처리 중") | |
| print(f"🔍 forward_projector - 이미지 {i} 특징 형태: {vf.shape}") | |
| print(f"🔍 forward_projector - 이미지 {i} grid_thw: {grid_thw}") | |
| visual_embeds = self.abstractor(vf, grid_thw=grid_thw)["last_hidden_state"] | |
| print(f"🔍 forward_projector - 이미지 {i} visual_embeds 형태: {visual_embeds.shape}") | |
| visual_embeds_list.append(visual_embeds) | |
| return torch.cat(visual_embeds_list, dim=0) # 최종적으로 하나의 텐서로 합쳐서 반환 | |
| else: | |
| # 단일 이미지 | |
| print(f"🔍 forward_projector - 단일 텐서 처리") | |
| # visual_features가 이미 처리된 특징 텐서인 경우 | |
| if len(visual_features.shape) == 2: | |
| print(f"🔍 forward_projector - 이미 처리된 특징 텐서 감지") | |
| print(f"🔍 forward_projector - 특징 텐서 형태: {visual_features.shape}") | |
| # grid_thw가 리스트인 경우 첫 번째 요소 사용 | |
| grid_thw = image_metas["vision_grid_thw"] | |
| if isinstance(grid_thw, list): | |
| grid_thw = grid_thw[0] | |
| # grid_thw를 텐서로 변환 | |
| if not isinstance(grid_thw, torch.Tensor): | |
| if isinstance(grid_thw, (list, tuple)): | |
| # 튜플을 리스트로 변환하여 텐서 생성 | |
| grid_thw = torch.tensor([list(grid_thw)]) | |
| else: | |
| grid_thw = torch.tensor([grid_thw]) | |
| # 디바이스 정보 추가 | |
| if hasattr(visual_features, 'device'): | |
| grid_thw = grid_thw.to(visual_features.device) | |
| print(f"🔍 forward_projector - grid_thw: {grid_thw}") | |
| print(f"🔍 forward_projector - grid_thw 계산된 토큰 수: {torch.prod(grid_thw, dim=1)}") | |
| print(f"🔍 forward_projector - 실제 특징 텐서 토큰 수: {visual_features.shape[0]}") | |
| # grid_thw가 실제 토큰 수와 맞지 않는 경우 수정 | |
| actual_tokens = visual_features.shape[0] | |
| if torch.prod(grid_thw, dim=1).item() != actual_tokens: | |
| print(f"🔍 forward_projector - grid_thw 수정 필요") | |
| # 실제 토큰 수에 맞는 grid_thw 계산 | |
| # 이미지의 실제 비율을 고려하여 계산 | |
| T = 1 | |
| # 이미지 메타데이터에서 실제 크기 정보 사용 | |
| if 'hw_orig_resolution' in image_metas: | |
| orig_h, orig_w = image_metas['hw_orig_resolution'] | |
| if isinstance(orig_h, list): | |
| orig_h = orig_h[0] if isinstance(orig_h[0], (int, float)) else orig_h[0][0] | |
| if isinstance(orig_w, list): | |
| orig_w = orig_w[0] if isinstance(orig_w[0], (int, float)) else orig_w[0][0] | |
| # 실제 비율을 유지하면서 토큰 수에 맞게 조정 | |
| aspect_ratio = orig_w / orig_h | |
| H = int((actual_tokens / aspect_ratio) ** 0.5) | |
| W = int(actual_tokens / H) | |
| # 토큰 수가 맞지 않으면 조정 | |
| while H * W != actual_tokens and H > 1 and W > 1: | |
| if H * W > actual_tokens: | |
| H -= 1 | |
| W = int(actual_tokens / H) | |
| else: | |
| W += 1 | |
| H = int(actual_tokens / W) | |
| else: | |
| # 기본값 사용 | |
| H = int(actual_tokens ** 0.5) | |
| W = actual_tokens // H | |
| if actual_tokens % H != 0: | |
| W += 1 | |
| corrected_grid_thw = torch.tensor([[T, H, W]]) | |
| print(f"🔍 forward_projector - 수정된 grid_thw: {corrected_grid_thw}") | |
| print(f"🔍 forward_projector - 수정된 토큰 수: {torch.prod(corrected_grid_thw, dim=1)}") | |
| # 토큰 수가 맞지 않는 경우 패딩 또는 자르기 | |
| expected_tokens = torch.prod(corrected_grid_thw, dim=1).item() | |
| if expected_tokens > actual_tokens: | |
| # 패딩 | |
| padding_size = expected_tokens - actual_tokens | |
| padding = torch.zeros(padding_size, visual_features.shape[1], device=visual_features.device) | |
| visual_features = torch.cat([visual_features, padding], dim=0) | |
| print(f"🔍 forward_projector - 패딩 추가: {padding_size}개 토큰") | |
| elif expected_tokens < actual_tokens: | |
| # 자르기 | |
| visual_features = visual_features[:expected_tokens] | |
| print(f"🔍 forward_projector - 토큰 자르기: {expected_tokens}개로") | |
| grid_thw = corrected_grid_thw | |
| # 특징 텐서를 abstractor에 직접 전달 | |
| visual_embeds = self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"] | |
| print(f"🔍 forward_projector - abstractor 출력 형태: {visual_embeds.shape}") | |
| return visual_embeds | |
| else: | |
| # 일반적인 vision model 출력 | |
| grid_thw = image_metas["vision_grid_thw"] | |
| return self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"] | |
| def forward_and_project_vision(self, pixel_values, image_metas: Optional[dict] = None): | |
| visual_features = self.forward_vision(pixel_values, image_metas=image_metas) | |
| visual_embeds = self.forward_projector(visual_features, image_metas=image_metas) | |
| return visual_embeds | |
| def _get_visual_feature_at(self, v_output, layer_index): | |
| if isinstance(layer_index, list): | |
| visual_features = torch.stack(v_output, dim=1)[:, layer_index] # [B, n_scales, L, dim] | |
| else: | |
| visual_features = v_output[layer_index] # [B, L, dim] | |
| return visual_features | |
| def embed_text_tokens(self, input_ids): | |
| """Embed input_ids into text_embeds, ignoring media tokens (negative values).""" | |
| input_ids = input_ids.clone() | |
| input_ids[input_ids < 0] = 0 | |
| text_embeds = self.language_model.get_input_embeddings()(input_ids) | |
| if hasattr(self.language_model, "transformer") and hasattr( | |
| self.language_model.transformer, "word_embeddings_layernorm" | |
| ): | |
| text_embeds = self.language_model.transformer.word_embeddings_layernorm(text_embeds) | |
| return text_embeds | |
| def prepare_mm_inputs( | |
| self, | |
| input_ids: torch.FloatTensor, | |
| pixel_values: Optional[list[torch.FloatTensor]] = None, | |
| image_metas: Optional[dict] = None, | |
| attention_mask: Optional[torch.LongTensor] = None, | |
| ): | |
| """Prepare multimodal inputs from input_ids and pixel_values.""" | |
| if pixel_values is not None: | |
| # pixel_values가 리스트인 경우 각각을 변환 | |
| if isinstance(pixel_values, list): | |
| pixel_values = [pv.to(self._get_input_dtype()) for pv in pixel_values] | |
| else: | |
| pixel_values = pixel_values.to(self._get_input_dtype()) | |
| if attention_mask is None: | |
| attention_mask = input_ids.new_ones(*input_ids.shape) | |
| # Get Text Embeddings | |
| text_embeds = self.embed_text_tokens(input_ids) | |
| flattened_text_embeds = rearrange(text_embeds, "b l d -> (b l) d") | |
| flattened_input_ids = rearrange(input_ids, "b l -> (b l)") | |
| # Get Visual Embeddings | |
| if pixel_values is not None: | |
| print(f"🔍 prepare_mm_inputs - pixel_values 타입: {type(pixel_values)}") | |
| if hasattr(pixel_values, 'shape'): | |
| print(f"🔍 prepare_mm_inputs - pixel_values 형태: {pixel_values.shape}") | |
| if isinstance(pixel_values, list): | |
| print(f"🔍 prepare_mm_inputs - pixel_values 길이: {len(pixel_values)}") | |
| # 다중 이미지 처리: 각 이미지를 개별적으로 처리 | |
| if isinstance(pixel_values, list) and len(pixel_values) > 1: | |
| print(f"🔍 prepare_mm_inputs - 다중 이미지 처리 시작") | |
| visual_embeds_list = [] | |
| for i, single_pixel_values in enumerate(pixel_values): | |
| print(f"🔍 prepare_mm_inputs - 이미지 {i} 처리 중") | |
| # 각 이미지에 대한 개별 image_metas 생성 | |
| single_image_metas = {} | |
| for key, value_list in image_metas.items(): | |
| if isinstance(value_list, list): | |
| single_image_metas[key] = value_list[i] | |
| else: | |
| single_image_metas[key] = value_list | |
| # 개별 이미지 처리 | |
| single_visual_embeds = self.forward_and_project_vision( | |
| single_pixel_values.unsqueeze(0), single_image_metas | |
| ) | |
| visual_embeds_list.append(single_visual_embeds) | |
| # 모든 이미지의 visual embeds를 연결 | |
| flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0) | |
| print(f"🔍 prepare_mm_inputs - 다중 이미지 처리 완료, 연결된 embeds 크기: {flattened_visual_embeds.shape}") | |
| else: | |
| # 단일 이미지 처리 (기존 방식) | |
| print(f"🔍 prepare_mm_inputs - 단일 이미지 처리") | |
| # pixel_values가 이미 처리된 특징 텐서인 경우 (다중 이미지 결합) | |
| if hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 2: | |
| print(f"🔍 prepare_mm_inputs - 처리된 특징 텐서 감지, 다중 이미지로 분리 시도") | |
| # image_metas에서 이미지 개수 확인 | |
| num_images = 0 | |
| if isinstance(image_metas, dict) and "image_token_thw" in image_metas: | |
| num_images = len(image_metas["image_token_thw"]) | |
| print(f"🔍 prepare_mm_inputs - 감지된 이미지 개수: {num_images}") | |
| if num_images > 1: | |
| print(f"🔍 prepare_mm_inputs - {num_images}개 이미지로 분리 처리") | |
| visual_embeds_list = [] | |
| # 각 이미지의 실제 토큰 수 계산 | |
| current_idx = 0 | |
| for i in range(num_images): | |
| print(f"🔍 prepare_mm_inputs - 이미지 {i} 처리 중") | |
| # 각 이미지에 대한 개별 image_metas 생성 | |
| single_image_metas = {} | |
| for key, value_list in image_metas.items(): | |
| if isinstance(value_list, list): | |
| single_image_metas[key] = value_list[i] | |
| else: | |
| single_image_metas[key] = value_list | |
| # image_token_thw에서 실제 토큰 수 계산 | |
| if "image_token_thw" in single_image_metas: | |
| token_thw = single_image_metas["image_token_thw"] | |
| if isinstance(token_thw, (list, tuple)): | |
| tokens_per_image = token_thw[0] * token_thw[1] * token_thw[2] | |
| else: | |
| tokens_per_image = token_thw[0] * token_thw[1] * token_thw[2] | |
| print(f"🔍 prepare_mm_inputs - 이미지 {i} 실제 토큰 수: {tokens_per_image}") | |
| else: | |
| # 기본값 사용 | |
| tokens_per_image = pixel_values.shape[0] // num_images | |
| print(f"🔍 prepare_mm_inputs - 이미지 {i} 기본 토큰 수: {tokens_per_image}") | |
| # pixel_values에서 해당 이미지 부분 추출 | |
| start_idx = current_idx | |
| end_idx = current_idx + tokens_per_image | |
| single_pixel_values = pixel_values[start_idx:end_idx] | |
| print(f"🔍 prepare_mm_inputs - 이미지 {i} 특징 형태: {single_pixel_values.shape}") | |
| # 개별 이미지 처리 | |
| single_visual_embeds = self.forward_and_project_vision( | |
| single_pixel_values, single_image_metas | |
| ) | |
| visual_embeds_list.append(single_visual_embeds) | |
| current_idx += tokens_per_image | |
| # 모든 이미지의 visual embeds를 연결 | |
| flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0) | |
| print(f"🔍 prepare_mm_inputs - 다중 이미지 처리 완료, 연결된 embeds 크기: {flattened_visual_embeds.shape}") | |
| else: | |
| # 단일 이미지 처리 | |
| print(f"🔍 prepare_mm_inputs - 단일 이미지로 처리") | |
| flattened_visual_embeds = self.forward_and_project_vision( | |
| pixel_values, image_metas | |
| ) | |
| # pixel_values가 배치 형태인 경우 개별 이미지로 분리 | |
| elif hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 4 and pixel_values.shape[0] > 1: | |
| print(f"🔍 prepare_mm_inputs - 배치 형태 감지, 개별 이미지로 분리") | |
| visual_embeds_list = [] | |
| for i in range(pixel_values.shape[0]): | |
| print(f"🔍 prepare_mm_inputs - 배치 이미지 {i} 처리 중") | |
| # 각 이미지에 대한 개별 image_metas 생성 | |
| single_image_metas = {} | |
| for key, value_list in image_metas.items(): | |
| if isinstance(value_list, list): | |
| single_image_metas[key] = value_list[i] | |
| else: | |
| single_image_metas[key] = value_list | |
| # 개별 이미지 처리 | |
| if isinstance(pixel_values, list): | |
| single_pixel_values = pixel_values[i:i+1] | |
| else: | |
| # pixel_values가 텐서인 경우 | |
| single_pixel_values = pixel_values[i:i+1] | |
| single_visual_embeds = self.forward_and_project_vision( | |
| single_pixel_values, single_image_metas | |
| ) | |
| visual_embeds_list.append(single_visual_embeds) | |
| # 모든 이미지의 visual embeds를 연결 | |
| flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0) | |
| print(f"🔍 prepare_mm_inputs - 다중 이미지 처리 완료, 연결된 embeds 크기: {flattened_visual_embeds.shape}") | |
| # 각 이미지의 embeds 크기 출력 | |
| for i, embeds in enumerate(visual_embeds_list): | |
| print(f"🔍 prepare_mm_inputs - 이미지 {i} embeds 크기: {embeds.shape}") | |
| else: | |
| # 단일 이미지 처리 | |
| # image_metas가 다중 이미지 정보를 포함하는 경우 첫 번째 이미지 정보만 사용 | |
| if isinstance(image_metas, dict): | |
| single_image_metas = {} | |
| for key, value_list in image_metas.items(): | |
| if isinstance(value_list, list): | |
| single_image_metas[key] = value_list[0] # 첫 번째 이미지 정보 사용 | |
| else: | |
| single_image_metas[key] = value_list | |
| print(f"🔍 prepare_mm_inputs - 단일 이미지 처리, 첫 번째 이미지 정보 사용") | |
| else: | |
| single_image_metas = image_metas | |
| # 단일 이미지 처리 시 pixel_values가 리스트인지 확인 | |
| if isinstance(pixel_values, list): | |
| single_pixel_values = pixel_values[0] # 첫 번째 이미지만 사용 | |
| else: | |
| single_pixel_values = pixel_values | |
| flattened_visual_embeds = self.forward_and_project_vision( | |
| single_pixel_values, single_image_metas | |
| ) | |
| # dtype 일치를 위해 visual_embeds를 text_embeds와 같은 dtype으로 변환 | |
| flattened_visual_embeds = flattened_visual_embeds.to(flattened_text_embeds.dtype) | |
| # visual embeds와 -1 토큰 개수 확인 및 조정 | |
| num_visual_tokens = flattened_visual_embeds.shape[0] | |
| num_neg_one_tokens = (flattened_input_ids == -1).sum().item() | |
| print(f"🔍 prepare_mm_inputs - visual embeds 개수: {num_visual_tokens}") | |
| print(f"🔍 prepare_mm_inputs - -1 토큰 개수: {num_neg_one_tokens}") | |
| if num_visual_tokens != num_neg_one_tokens: | |
| print(f"🔍 prepare_mm_inputs - 토큰 개수 불일치, 조정 필요") | |
| if num_visual_tokens > num_neg_one_tokens: | |
| # visual embeds가 많으면 자르기 | |
| flattened_visual_embeds = flattened_visual_embeds[:num_neg_one_tokens] | |
| print(f"🔍 prepare_mm_inputs - visual embeds 자르기: {num_visual_tokens} -> {num_neg_one_tokens}") | |
| else: | |
| # visual embeds가 적으면 반복해서 사용 | |
| repeat_times = num_neg_one_tokens // num_visual_tokens | |
| remainder = num_neg_one_tokens % num_visual_tokens | |
| if repeat_times > 0: | |
| # visual embeds를 반복 | |
| repeated_embeds = flattened_visual_embeds.repeat(repeat_times, 1) | |
| if remainder > 0: | |
| # 나머지 부분 추가 | |
| remainder_embeds = flattened_visual_embeds[:remainder] | |
| repeated_embeds = torch.cat([repeated_embeds, remainder_embeds], dim=0) | |
| flattened_visual_embeds = repeated_embeds | |
| else: | |
| # visual embeds가 너무 적으면 첫 번째 토큰을 반복 | |
| first_token = flattened_visual_embeds[0:1].repeat(num_neg_one_tokens, 1) | |
| flattened_visual_embeds = first_token | |
| print(f"🔍 prepare_mm_inputs - visual embeds 반복: {num_visual_tokens} -> {num_neg_one_tokens}") | |
| flattened_text_embeds[flattened_input_ids == -1] = flattened_visual_embeds | |
| input_embeds = rearrange( | |
| flattened_text_embeds, "(b l) d -> b l d", b=input_ids.shape[0] | |
| ) | |
| return_inputs = { | |
| "inputs_embeds": input_embeds, | |
| "attention_mask": attention_mask, | |
| } | |
| return return_inputs | |
| def forward( | |
| self, | |
| pixel_values: list[torch.FloatTensor], | |
| image_metas: dict[list], | |
| input_ids: torch.FloatTensor, | |
| seq_length: Optional[torch.LongTensor] = None, | |
| attention_mask: Optional[torch.LongTensor] = None, | |
| labels: Optional[torch.LongTensor] = None, | |
| return_dict: Optional[bool] = None, | |
| ): | |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| inputs = self.prepare_mm_inputs( | |
| input_ids=input_ids, | |
| pixel_values=pixel_values, | |
| image_metas=image_metas, | |
| attention_mask=attention_mask, | |
| ) | |
| outputs = self.language_model( | |
| **inputs, | |
| labels=labels, | |
| position_ids=None, | |
| return_dict=return_dict, | |
| output_attentions=self.config.output_attentions, | |
| ) | |
| return outputs | |
| def generate( | |
| self, | |
| pixel_values: torch.FloatTensor = None, | |
| image_metas: dict[list] = None, | |
| input_ids: Optional[torch.LongTensor] = None, | |
| attention_mask: Optional[torch.LongTensor] = None, | |
| seq_length: Optional[torch.LongTensor] = None, | |
| **generate_kwargs, | |
| ) -> torch.LongTensor: | |
| """ | |
| Overrides `generate` function to be able to use the model as a conditional generator. | |
| Args: | |
| pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)): | |
| Input images to be processed. | |
| input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): | |
| The sequence used as a prompt for the generation. | |
| attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): | |
| Mask to avoid performing attention on padding token indices | |
| Returns: | |
| captions (list): A list of strings of length batch_size * num_captions. | |
| """ | |
| if input_ids is None: | |
| return self.language_model.generate(attention_mask=attention_mask, **generate_kwargs) | |
| if pixel_values is None: | |
| return self.language_model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) | |
| if ( | |
| image_metas is not None | |
| and image_metas.get("vision_grid_thw") is not None | |
| and isinstance(image_metas.get("vision_grid_thw"), torch.Tensor) | |
| ): | |
| image_metas["vision_grid_thw"] = image_metas["vision_grid_thw"].to(input_ids.device) | |
| inputs = self.prepare_mm_inputs( | |
| input_ids=input_ids, | |
| pixel_values=pixel_values, | |
| image_metas=image_metas, | |
| attention_mask=attention_mask, | |
| ) | |
| outputs = self.language_model.generate( | |
| **inputs, | |
| **generate_kwargs, | |
| ) | |
| return outputs | |
| def _get_input_dtype(self): | |
| dtype = next(self.vision_model.parameters()).dtype | |
| return dtype | |