# lily-math-rag / modeling.py
from functools import partial
import logging
import re
from typing import Optional, Tuple, Union, List
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.layers.pos_embed import resample_abs_pos_embed
from timm.models.regnet import RegStage
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModelForCausalLM
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLVisionConfig,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
PatchEmbed,
Qwen2VLPreTrainedModel,
Qwen2VisionTransformerPretrainedModel,
Qwen2VLVisionBlock,
VisionRotaryEmbedding
)
from configuration import KananaVVisualProjectorConfig, KananaVConfig
logger = logging.getLogger("kanana-1.5-v")
def build_pos_embeds(
config: KananaVVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
):
# pos emb
if config.pos_emb:
        # Fix: fall back to a default token count when num_input_tokens is not positive
        if num_input_tokens <= 0:
            num_input_tokens = getattr(config, "pos_emb_size", 576)
pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
else:
pos_emb = None
return pos_emb
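# Illustrative note (sizes are hypothetical, not from this file): with the default
# pos_emb_size of 576 (a 24x24 grid) and, for example, a 1280-dim encoder, this returns a
# learnable tensor of shape (1, 576, 1280); the abstractor later resamples it to the
# actual (H, W) patch grid of each image.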
def build_eos_tokens(config: KananaVVisualProjectorConfig, output_hidden_size: int):
# think tokens
num_eos_tokens = config.num_eos_tokens
if num_eos_tokens:
eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size))
nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range)
else:
eos_tokens = None
return eos_tokens
def build_prenorm(config: KananaVVisualProjectorConfig):
if getattr(config, "prenorm", False):
prenorm = LayerNorm(config.encoder_hidden_size)
else:
prenorm = None
return prenorm
def build_mlp(depth: int, hidden_size: int, output_hidden_size: int):
layers = [nn.Linear(hidden_size, output_hidden_size)]
for _ in range(1, depth):
layers.append(nn.SiLU())
layers.append(nn.Linear(output_hidden_size, output_hidden_size))
return nn.Sequential(*layers)
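# Illustrative note (hypothetical sizes): build_mlp(depth=2, hidden_size=1024,
# output_hidden_size=4096) yields Linear(1024 -> 4096) -> SiLU -> Linear(4096 -> 4096);
# the real sizes come from the projector config.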
class PatchMerge(nn.Module):
def __init__(self, merge_size):
super().__init__()
self.merge_size = merge_size
def forward(self, x, channel_last=False):
if channel_last:
x = rearrange(x, "B H W D -> B D H W")
_, D, H, W = x.shape
        # Pad so that H and W are divisible by merge_size (handles odd spatial dims)
        pad_h = (self.merge_size - H % self.merge_size) % self.merge_size
        pad_w = (self.merge_size - W % self.merge_size) % self.merge_size
        if pad_h > 0 or pad_w > 0:
            print(f"🔍 PatchMerge - padding added: H={H}->{H+pad_h}, W={W}->{W+pad_w}")
            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
            H, W = H + pad_h, W + pad_w
merged_x = rearrange(
x, "B D (H h2) (W w2) -> B (D h2 w2) H W", h2=self.merge_size, w2=self.merge_size
)
return merged_x
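# Illustrative shape sketch (hypothetical sizes): with merge_size=2, an input of shape
# (1, 1024, 24, 24) becomes (1, 1024 * 4, 12, 12); each 2x2 spatial neighborhood is folded
# into the channel dimension, reducing the number of spatial tokens by a factor of 4.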
class DynamicCAbstractor(nn.Module):
"""Dynamic C-Abstractor based on RegBlock"""
def __init__(self, config: KananaVVisualProjectorConfig, num_input_tokens: int):
super().__init__()
self.config = config
        # Fix: fall back to a default token count when num_input_tokens is not positive
        if num_input_tokens <= 0:
            num_input_tokens = getattr(config, "pos_emb_size", 576)
self.num_input_tokens = num_input_tokens
        # Provide defaults for config attributes that may be missing
        self.merge_size = getattr(config, 'merge_size', 2)
        self.pos_emb_size = getattr(config, 'pos_emb_size', 576)
        # Optimization: keep all projector parameters in bfloat16
self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)
if self.pos_emb is not None:
self.pos_emb.data = self.pos_emb.data.to(torch.bfloat16)
self.eos_tokens = build_eos_tokens(config, config.output_hidden_size)
if self.eos_tokens is not None:
self.eos_tokens.data = self.eos_tokens.data.to(torch.bfloat16)
self.prenorm = build_prenorm(config)
if self.prenorm is not None:
self.prenorm = self.prenorm.to(torch.bfloat16)
        # build_net() defines self.net and self.readout
        self.build_net()
        # Cast the net layers to bfloat16
        if hasattr(self, 'net'):
            # nn.Module.to(dtype) converts parameters and buffers in place,
            # so a single call per module is sufficient.
            if isinstance(self.net, nn.ModuleList):
                for layer in self.net:
                    layer.to(torch.bfloat16)
            else:
                # self.net is a single module (the PatchMerge sampler when depth == 0)
                self.net = self.net.to(torch.bfloat16)
        # Cast the readout MLP to bfloat16 as well
        if hasattr(self, 'readout'):
            self.readout = self.readout.to(torch.bfloat16)
def build_net(self):
encoder_hidden_size = self.config.encoder_hidden_size
hidden_size = self.config.hidden_size
output_hidden_size = self.config.output_hidden_size
depth = self.config.depth
mlp_depth = self.config.mlp_depth
RegBlock = partial(
RegStage,
stride=1,
dilation=1,
act_layer=nn.SiLU,
norm_layer=LayerNorm2d,
)
s1 = RegBlock(
depth,
encoder_hidden_size,
hidden_size,
)
sampler = PatchMerge(merge_size=self.merge_size)
s2 = RegBlock(
depth,
self.merge_size**2 * hidden_size,
hidden_size,
)
if depth:
self.net = nn.ModuleList([s1, sampler, s2])
self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
else:
self.net = sampler
self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size)
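    # Resulting projector pipeline: RegStage (s1) -> PatchMerge -> RegStage (s2) -> MLP
    # readout when depth > 0, or just PatchMerge -> MLP readout when depth == 0.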
def forward(self, flattened_visual_embeds, grid_thw, **unused_kwargs):
n_token_loc = torch.prod(grid_thw, dim=1)
split_visual_embeds = torch.split(flattened_visual_embeds, n_token_loc.tolist())
flattened_visual_embeds = []
for _visual_embeds, _grid_thw in zip(split_visual_embeds, grid_thw):
T, H, W = _grid_thw
assert T == 1, "T must be 1. Video is not supported yet."
reshaped_visual_embeds = rearrange(
_visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
)
# remove temporal dim
reshaped_visual_embeds = reshaped_visual_embeds[:, 0]
if self.prenorm is not None:
reshaped_visual_embeds = self.prenorm(reshaped_visual_embeds)
if self.pos_emb is not None:
# interpolate pos emb and add to visual embeds
print(f"🔍 abstractor - pos_emb 형태: {self.pos_emb.shape}")
print(f"🔍 abstractor - reshaped_visual_embeds 형태: {reshaped_visual_embeds.shape}")
_local_pos_emb = resample_abs_pos_embed(
posemb=self.pos_emb,
old_size=tuple([int(self.pos_emb_size**0.5)] * 2),
new_size=(H, W),
num_prefix_tokens=0,
)
_local_pos_emb = rearrange(
_local_pos_emb,
"1 (h w) d -> 1 h w d",
h=H,
w=W,
)
print(f"🔍 abstractor - _local_pos_emb 형태: {_local_pos_emb.shape}")
# 차원이 맞지 않는 경우 처리
if reshaped_visual_embeds.shape[-1] != _local_pos_emb.shape[-1]:
print(f"🔍 abstractor - 차원 불일치 감지, pos_emb 건너뛰기")
# pos_emb를 건너뛰고 visual_embeds만 사용
else:
reshaped_visual_embeds = reshaped_visual_embeds + _local_pos_emb
reshaped_visual_embeds = self._forward(
reshaped_visual_embeds,
input_size=(H, W),
)
flattened_visual_embeds.append(reshaped_visual_embeds)
reshaped_visual_embeds = torch.cat(flattened_visual_embeds, dim=0)
output = BaseModelOutput(last_hidden_state=reshaped_visual_embeds)
return output
def _forward(self, x, input_size):
h, w = input_size
x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w)
        # Determine the expected input channel count from the first RegStage block,
        # falling back to a default if the structure is not as expected.
        try:
            if hasattr(self.net[0], 'conv'):
                expected_channels = self.net[0].conv.in_channels
            elif hasattr(self.net[0], 'blocks') and len(self.net[0].blocks) > 0:
                expected_channels = self.net[0].blocks[0].conv1.in_channels
            else:
                # Fall back to a default channel count
                expected_channels = 1280
        except Exception:
            expected_channels = 1280
actual_channels = x.shape[1]
if actual_channels != expected_channels:
            # Adapt the channel count with a linear projection
            if not hasattr(self, 'channel_adapter'):
                # Create the channel adapter lazily, directly in bfloat16
                self.channel_adapter = nn.Linear(actual_channels, expected_channels, dtype=torch.bfloat16).to(x.device)
            x = x.permute(0, 2, 3, 1)  # (1, d, h, w) -> (1, h, w, d)
            # Convert the input to bfloat16 once
            if x.dtype != torch.bfloat16:
                x = x.to(torch.bfloat16)
            x = self.channel_adapter(x)  # adjust the channel count
            x = x.permute(0, 3, 1, 2)  # (1, h, w, d) -> (1, d, h, w)
        # The net layers were already initialized in bfloat16
x = self.net[0](x)
x = self.net[1](x)
x = self.net[2](x)
x = rearrange(x, "1 d h w -> (h w) d")
        # The readout was already initialized in bfloat16
x = self.readout(x)
return x
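    # Shape flow through _forward (hypothetical sizes, merge_size=2, a 24x24 grid):
    #   (1, 24, 24, D) -> (1, D, 24, 24) -> s1 -> (1, hidden, 24, 24)
    #   -> PatchMerge -> (1, 4*hidden, 12, 12) -> s2 -> (1, hidden, 12, 12)
    #   -> flatten -> (144, hidden) -> readout -> (144, output_hidden_size)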
class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
config_class = Qwen2VLVisionConfig
_no_split_modules = ["Qwen2VLVisionBlock"]
def __init__(self, config) -> None:
Qwen2VLPreTrainedModel.__init__(self, config)
self.spatial_merge_size = config.spatial_merge_size
self.gradient_checkpointing = False
self.patch_embed = PatchEmbed(
patch_size=config.patch_size,
temporal_patch_size=config.temporal_patch_size,
in_channels=config.in_channels,
embed_dim=config.embed_dim,
)
head_dim = config.embed_dim // config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
assert return_dict, "Only return_dict=True is supported."
encoder_states = () if output_hidden_states else None
hidden_states = self.patch_embed(pixel_values)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = emb.cos(), emb.sin()
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(
blk.__call__,
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
use_reentrant=False,
)
else:
layer_outputs = blk(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
def get_num_tokens(self):
return -1
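    # get_num_tokens returns -1 as a sentinel: the Qwen2-VL vision tower emits a variable
    # number of patch tokens per image, so the abstractor falls back to its configured
    # pos_emb_size (default 576) when it sees a non-positive token count.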
class KananaVPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = KananaVConfig
base_model_prefix = "kanana-1.5-v"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = False
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"language_model.encoder.embed_tokens.weight",
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_no_split_modules = [
"CustomQwen2VLVE",
"DynamicCAbstractor",
"LlamaForCausalLM",
"Parameter",
]
def _init_weights(self, module):
"""Initialize the weights"""
if (
isinstance(module, nn.Conv2d)
or isinstance(module, nn.Embedding)
or isinstance(module, nn.Linear)
):
module.weight.data.normal_(mean=0.0, std=0.02)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Parameter):
            raise ValueError("nn.Parameter is not expected to reach _init_weights")
class KananaVForConditionalGeneration(KananaVPreTrainedModel):
config_class = KananaVConfig
def __init__(self, config: KananaVConfig):
super().__init__(config)
logger.info("Build vision model ...")
self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
logger.info("Build projector ...")
self.abstractor = DynamicCAbstractor(config.projector_config,
num_input_tokens=self.vision_model.get_num_tokens())
logger.info("Build language model ...")
self.language_model = LlamaForCausalLM._from_config(config=config.text_config)
self.post_init()
def forward_vision(self, pixel_values: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None):
        # Key fix: handle both the list case and the tensor case for pixel_values
        if isinstance(pixel_values, list):
            # Multiple images: process each image and collect the results
visual_features_list = []
for i, pv in enumerate(pixel_values):
single_image_metas = {k: v[i] for k, v in image_metas.items()}
                # Normalize grid_thw into a tensor
                vision_grid_thw = single_image_metas["vision_grid_thw"]
                if isinstance(vision_grid_thw, (list, tuple)):
                    # Convert the tuple to a list before building the tensor
                    grid_thw = torch.tensor([list(vision_grid_thw)]).to(pv.device)
                else:
                    grid_thw = torch.tensor([vision_grid_thw]).to(pv.device)
                # Optimization: no extra dtype conversion is needed here
v_outputs = self.vision_model(
pixel_values=pv.unsqueeze(0),
grid_thw=grid_thw,
return_dict=True, output_hidden_states=True
)
layer_index = self.config.projector_config.feature_layer_index
visual_features_list.append(self._get_visual_feature_at(v_outputs.hidden_states, layer_index))
            return visual_features_list  # return a list of per-image features
        else:
            # Single image (possibly an already-extracted feature tensor)
            # If grid_thw is a list, use the first element
            grid_thw = image_metas["vision_grid_thw"]
            if isinstance(grid_thw, list):
                grid_thw = grid_thw[0]
            # Convert grid_thw to a tensor if needed
            if not isinstance(grid_thw, torch.Tensor):
                if isinstance(grid_thw, (list, tuple)):
                    # Convert the tuple to a list before building the tensor
                    grid_thw = torch.tensor([list(grid_thw)])
                else:
                    grid_thw = torch.tensor([grid_thw])
            # Move grid_thw to the same device as pixel_values
            if hasattr(pixel_values, 'device'):
                grid_thw = grid_thw.to(pixel_values.device)
            # If pixel_values is a flattened 2D patch tensor, run it through the
            # vision model, mirroring the multi-image branch above.
            if len(pixel_values.shape) == 2:
                v_outputs = self.vision_model(
                    pixel_values=pixel_values,
                    grid_thw=grid_thw,
                    return_dict=True, output_hidden_states=True
                )
                layer_index = self.config.projector_config.feature_layer_index
                return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)
            # Otherwise pixel_values is a regular image tensor; no extra dtype
            # conversion is needed before the vision-model call.
v_outputs = self.vision_model(
pixel_values=pixel_values,
grid_thw=grid_thw,
return_dict=True, output_hidden_states=True
)
layer_index = self.config.projector_config.feature_layer_index
return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)
def forward_projector(self, visual_features: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None):
print(f"🔍 forward_projector - visual_features 형태: {visual_features.shape if hasattr(visual_features, 'shape') else type(visual_features)}")
# ✨ 핵심 수정: visual_features가 리스트일 경우 처리
if isinstance(visual_features, list):
print(f"🔍 forward_projector - 리스트 형태 처리")
visual_embeds_list = []
for i, vf in enumerate(visual_features):
single_image_metas = {k: v[i] for k, v in image_metas.items()}
                # Normalize grid_thw into a tensor
                vision_grid_thw = single_image_metas["vision_grid_thw"]
                if isinstance(vision_grid_thw, (list, tuple)):
                    # Convert the tuple to a list before building the tensor
                    grid_thw = torch.tensor([list(vision_grid_thw)]).to(vf.device)
                else:
                    grid_thw = torch.tensor([vision_grid_thw]).to(vf.device)
                print(f"🔍 forward_projector - processing image {i}")
                print(f"🔍 forward_projector - image {i} feature shape: {vf.shape}")
                print(f"🔍 forward_projector - image {i} grid_thw: {grid_thw}")
                visual_embeds = self.abstractor(vf, grid_thw=grid_thw)["last_hidden_state"]
                print(f"🔍 forward_projector - image {i} visual_embeds shape: {visual_embeds.shape}")
visual_embeds_list.append(visual_embeds)
            return torch.cat(visual_embeds_list, dim=0)  # concatenate into a single tensor
        else:
            # Single image
            print(f"🔍 forward_projector - handling a single tensor")
            # visual_features may already be an extracted 2D feature tensor
            if len(visual_features.shape) == 2:
                print(f"🔍 forward_projector - detected an already-extracted feature tensor")
                print(f"🔍 forward_projector - feature tensor shape: {visual_features.shape}")
                # If grid_thw is a list, use the first element
grid_thw = image_metas["vision_grid_thw"]
if isinstance(grid_thw, list):
grid_thw = grid_thw[0]
                # Convert grid_thw to a tensor if needed
                if not isinstance(grid_thw, torch.Tensor):
                    if isinstance(grid_thw, (list, tuple)):
                        # Convert the tuple to a list before building the tensor
                        grid_thw = torch.tensor([list(grid_thw)])
                    else:
                        grid_thw = torch.tensor([grid_thw])
                # Move grid_thw to the same device as visual_features
                if hasattr(visual_features, 'device'):
                    grid_thw = grid_thw.to(visual_features.device)
                print(f"🔍 forward_projector - grid_thw: {grid_thw}")
                print(f"🔍 forward_projector - token count implied by grid_thw: {torch.prod(grid_thw, dim=1)}")
                print(f"🔍 forward_projector - actual token count in feature tensor: {visual_features.shape[0]}")
                # Fix grid_thw when it does not match the actual token count
actual_tokens = visual_features.shape[0]
if torch.prod(grid_thw, dim=1).item() != actual_tokens:
print(f"🔍 forward_projector - grid_thw 수정 필요")
# 실제 토큰 수에 맞는 grid_thw 계산
# 이미지의 실제 비율을 고려하여 계산
T = 1
# 이미지 메타데이터에서 실제 크기 정보 사용
if 'hw_orig_resolution' in image_metas:
orig_h, orig_w = image_metas['hw_orig_resolution']
if isinstance(orig_h, list):
orig_h = orig_h[0] if isinstance(orig_h[0], (int, float)) else orig_h[0][0]
if isinstance(orig_w, list):
orig_w = orig_w[0] if isinstance(orig_w[0], (int, float)) else orig_w[0][0]
                        # Keep the original aspect ratio while matching the token count
                        aspect_ratio = orig_w / orig_h
                        H = int((actual_tokens / aspect_ratio) ** 0.5)
                        W = int(actual_tokens / H)
                        # Adjust until H * W equals the token count
while H * W != actual_tokens and H > 1 and W > 1:
if H * W > actual_tokens:
H -= 1
W = int(actual_tokens / H)
else:
W += 1
H = int(actual_tokens / W)
                    else:
                        # Fall back to a roughly square grid
H = int(actual_tokens ** 0.5)
W = actual_tokens // H
if actual_tokens % H != 0:
W += 1
corrected_grid_thw = torch.tensor([[T, H, W]])
print(f"🔍 forward_projector - 수정된 grid_thw: {corrected_grid_thw}")
print(f"🔍 forward_projector - 수정된 토큰 수: {torch.prod(corrected_grid_thw, dim=1)}")
# 토큰 수가 맞지 않는 경우 패딩 또는 자르기
expected_tokens = torch.prod(corrected_grid_thw, dim=1).item()
                    if expected_tokens > actual_tokens:
                        # Pad with zeros; match the feature dtype so torch.cat does not fail
                        padding_size = expected_tokens - actual_tokens
                        padding = torch.zeros(padding_size, visual_features.shape[1],
                                              dtype=visual_features.dtype, device=visual_features.device)
                        visual_features = torch.cat([visual_features, padding], dim=0)
                        print(f"🔍 forward_projector - padded with {padding_size} tokens")
                    elif expected_tokens < actual_tokens:
                        # Truncate
                        visual_features = visual_features[:expected_tokens]
                        print(f"🔍 forward_projector - truncated to {expected_tokens} tokens")
grid_thw = corrected_grid_thw
                # Pass the feature tensor directly to the abstractor
                visual_embeds = self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"]
                print(f"🔍 forward_projector - abstractor output shape: {visual_embeds.shape}")
return visual_embeds
else:
                # Regular vision-model output
grid_thw = image_metas["vision_grid_thw"]
return self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"]
def forward_and_project_vision(self, pixel_values, image_metas: Optional[dict] = None):
visual_features = self.forward_vision(pixel_values, image_metas=image_metas)
visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
return visual_embeds
def _get_visual_feature_at(self, v_output, layer_index):
if isinstance(layer_index, list):
visual_features = torch.stack(v_output, dim=1)[:, layer_index] # [B, n_scales, L, dim]
else:
visual_features = v_output[layer_index] # [B, L, dim]
return visual_features
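    # Example (illustrative): feature_layer_index = -1 selects the final hidden state; a
    # list such as [-2, -1] would stack several layers into a multi-scale
    # [B, n_scales, L, dim] feature, as noted above.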
def embed_text_tokens(self, input_ids):
"""Embed input_ids into text_embeds, ignoring media tokens (negative values)."""
input_ids = input_ids.clone()
input_ids[input_ids < 0] = 0
text_embeds = self.language_model.get_input_embeddings()(input_ids)
if hasattr(self.language_model, "transformer") and hasattr(
self.language_model.transformer, "word_embeddings_layernorm"
):
text_embeds = self.language_model.transformer.word_embeddings_layernorm(text_embeds)
return text_embeds
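    # Note: negative input_ids (e.g. -1) are placeholders for image tokens. They are zeroed
    # out before the embedding lookup here and later overwritten with projected visual
    # embeddings in prepare_mm_inputs.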
def prepare_mm_inputs(
self,
input_ids: torch.FloatTensor,
pixel_values: Optional[list[torch.FloatTensor]] = None,
image_metas: Optional[dict] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""Prepare multimodal inputs from input_ids and pixel_values."""
if pixel_values is not None:
            # If pixel_values is a list, convert each element to the model's input dtype
if isinstance(pixel_values, list):
pixel_values = [pv.to(self._get_input_dtype()) for pv in pixel_values]
else:
pixel_values = pixel_values.to(self._get_input_dtype())
if attention_mask is None:
attention_mask = input_ids.new_ones(*input_ids.shape)
# Get Text Embeddings
text_embeds = self.embed_text_tokens(input_ids)
flattened_text_embeds = rearrange(text_embeds, "b l d -> (b l) d")
flattened_input_ids = rearrange(input_ids, "b l -> (b l)")
# Get Visual Embeddings
if pixel_values is not None:
print(f"🔍 prepare_mm_inputs - pixel_values 타입: {type(pixel_values)}")
if hasattr(pixel_values, 'shape'):
print(f"🔍 prepare_mm_inputs - pixel_values 형태: {pixel_values.shape}")
if isinstance(pixel_values, list):
print(f"🔍 prepare_mm_inputs - pixel_values 길이: {len(pixel_values)}")
# 다중 이미지 처리: 각 이미지를 개별적으로 처리
if isinstance(pixel_values, list) and len(pixel_values) > 1:
print(f"🔍 prepare_mm_inputs - 다중 이미지 처리 시작")
visual_embeds_list = []
for i, single_pixel_values in enumerate(pixel_values):
print(f"🔍 prepare_mm_inputs - 이미지 {i} 처리 중")
# 각 이미지에 대한 개별 image_metas 생성
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
single_image_metas[key] = value_list[i]
else:
single_image_metas[key] = value_list
                    # Process the individual image
single_visual_embeds = self.forward_and_project_vision(
single_pixel_values.unsqueeze(0), single_image_metas
)
visual_embeds_list.append(single_visual_embeds)
                # Concatenate the visual embeds of all images
                flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
                print(f"🔍 prepare_mm_inputs - multi-image processing done, concatenated embeds shape: {flattened_visual_embeds.shape}")
else:
                # Single-image path (original behavior)
                print(f"🔍 prepare_mm_inputs - single-image processing")
                # pixel_values may already be an extracted feature tensor (multiple images concatenated)
                if hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 2:
                    print(f"🔍 prepare_mm_inputs - detected a pre-extracted feature tensor, trying to split it per image")
                    # Read the number of images from image_metas
num_images = 0
if isinstance(image_metas, dict) and "image_token_thw" in image_metas:
num_images = len(image_metas["image_token_thw"])
print(f"🔍 prepare_mm_inputs - 감지된 이미지 개수: {num_images}")
if num_images > 1:
print(f"🔍 prepare_mm_inputs - {num_images}개 이미지로 분리 처리")
visual_embeds_list = []
# 각 이미지의 실제 토큰 수 계산
current_idx = 0
for i in range(num_images):
print(f"🔍 prepare_mm_inputs - 이미지 {i} 처리 중")
# 각 이미지에 대한 개별 image_metas 생성
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
single_image_metas[key] = value_list[i]
else:
single_image_metas[key] = value_list
                            # Compute this image's token count from image_token_thw
                            if "image_token_thw" in single_image_metas:
                                token_thw = single_image_metas["image_token_thw"]
                                tokens_per_image = token_thw[0] * token_thw[1] * token_thw[2]
                                print(f"🔍 prepare_mm_inputs - image {i} actual token count: {tokens_per_image}")
                            else:
                                # Fall back to an even split
                                tokens_per_image = pixel_values.shape[0] // num_images
                                print(f"🔍 prepare_mm_inputs - image {i} default token count: {tokens_per_image}")
                            # Slice out this image's portion of pixel_values
start_idx = current_idx
end_idx = current_idx + tokens_per_image
single_pixel_values = pixel_values[start_idx:end_idx]
print(f"🔍 prepare_mm_inputs - 이미지 {i} 특징 형태: {single_pixel_values.shape}")
# 개별 이미지 처리
single_visual_embeds = self.forward_and_project_vision(
single_pixel_values, single_image_metas
)
visual_embeds_list.append(single_visual_embeds)
current_idx += tokens_per_image
                        # Concatenate the visual embeds of all images
                        flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
                        print(f"🔍 prepare_mm_inputs - multi-image processing done, concatenated embeds shape: {flattened_visual_embeds.shape}")
                    else:
                        # Single-image processing
                        print(f"🔍 prepare_mm_inputs - treating as a single image")
flattened_visual_embeds = self.forward_and_project_vision(
pixel_values, image_metas
)
                # If pixel_values is batched, split it into individual images
                elif hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 4 and pixel_values.shape[0] > 1:
                    print(f"🔍 prepare_mm_inputs - batched input detected, splitting into individual images")
visual_embeds_list = []
for i in range(pixel_values.shape[0]):
print(f"🔍 prepare_mm_inputs - 배치 이미지 {i} 처리 중")
# 각 이미지에 대한 개별 image_metas 생성
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
single_image_metas[key] = value_list[i]
else:
single_image_metas[key] = value_list
                        # Slice out a single image (the same slicing works for list and tensor inputs)
                        single_pixel_values = pixel_values[i:i+1]
single_visual_embeds = self.forward_and_project_vision(
single_pixel_values, single_image_metas
)
visual_embeds_list.append(single_visual_embeds)
                    # Concatenate the visual embeds of all images
                    flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
                    print(f"🔍 prepare_mm_inputs - multi-image processing done, concatenated embeds shape: {flattened_visual_embeds.shape}")
                    # Report the embed shape of each image
                    for i, embeds in enumerate(visual_embeds_list):
                        print(f"🔍 prepare_mm_inputs - image {i} embeds shape: {embeds.shape}")
                else:
                    # Single-image processing
                    # If image_metas carries multi-image info, use only the first image's entries
if isinstance(image_metas, dict):
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
                                single_image_metas[key] = value_list[0]  # use the first image's info
else:
single_image_metas[key] = value_list
print(f"🔍 prepare_mm_inputs - 단일 이미지 처리, 첫 번째 이미지 정보 사용")
else:
single_image_metas = image_metas
                    # pixel_values may also be a list in the single-image case
                    if isinstance(pixel_values, list):
                        single_pixel_values = pixel_values[0]  # use only the first image
else:
single_pixel_values = pixel_values
flattened_visual_embeds = self.forward_and_project_vision(
single_pixel_values, single_image_metas
)
            # Match dtypes: cast the visual embeds to the text-embed dtype
            flattened_visual_embeds = flattened_visual_embeds.to(flattened_text_embeds.dtype)
            # Check and reconcile the number of visual embeds vs. the number of -1 tokens
num_visual_tokens = flattened_visual_embeds.shape[0]
num_neg_one_tokens = (flattened_input_ids == -1).sum().item()
print(f"🔍 prepare_mm_inputs - visual embeds 개수: {num_visual_tokens}")
print(f"🔍 prepare_mm_inputs - -1 토큰 개수: {num_neg_one_tokens}")
if num_visual_tokens != num_neg_one_tokens:
print(f"🔍 prepare_mm_inputs - 토큰 개수 불일치, 조정 필요")
if num_visual_tokens > num_neg_one_tokens:
# visual embeds가 많으면 자르기
flattened_visual_embeds = flattened_visual_embeds[:num_neg_one_tokens]
print(f"🔍 prepare_mm_inputs - visual embeds 자르기: {num_visual_tokens} -> {num_neg_one_tokens}")
else:
# visual embeds가 적으면 반복해서 사용
repeat_times = num_neg_one_tokens // num_visual_tokens
remainder = num_neg_one_tokens % num_visual_tokens
if repeat_times > 0:
# visual embeds를 반복
repeated_embeds = flattened_visual_embeds.repeat(repeat_times, 1)
if remainder > 0:
# 나머지 부분 추가
remainder_embeds = flattened_visual_embeds[:remainder]
repeated_embeds = torch.cat([repeated_embeds, remainder_embeds], dim=0)
flattened_visual_embeds = repeated_embeds
else:
# visual embeds가 너무 적으면 첫 번째 토큰을 반복
first_token = flattened_visual_embeds[0:1].repeat(num_neg_one_tokens, 1)
flattened_visual_embeds = first_token
print(f"🔍 prepare_mm_inputs - visual embeds 반복: {num_visual_tokens} -> {num_neg_one_tokens}")
flattened_text_embeds[flattened_input_ids == -1] = flattened_visual_embeds
input_embeds = rearrange(
flattened_text_embeds, "(b l) d -> b l d", b=input_ids.shape[0]
)
return_inputs = {
"inputs_embeds": input_embeds,
"attention_mask": attention_mask,
}
return return_inputs
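    # prepare_mm_inputs returns {"inputs_embeds", "attention_mask"}; the language model is
    # then called with embeddings directly, so input_ids are never passed downstream.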
def forward(
self,
pixel_values: list[torch.FloatTensor],
image_metas: dict[list],
input_ids: torch.FloatTensor,
seq_length: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = self.prepare_mm_inputs(
input_ids=input_ids,
pixel_values=pixel_values,
image_metas=image_metas,
attention_mask=attention_mask,
)
outputs = self.language_model(
**inputs,
labels=labels,
position_ids=None,
return_dict=return_dict,
output_attentions=self.config.output_attentions,
)
return outputs
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor = None,
image_metas: dict[list] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
seq_length: Optional[torch.LongTensor] = None,
**generate_kwargs,
) -> torch.LongTensor:
"""
Overrides `generate` function to be able to use the model as a conditional generator.
Args:
pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
Input images to be processed.
input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices
Returns:
captions (list): A list of strings of length batch_size * num_captions.
"""
if input_ids is None:
return self.language_model.generate(attention_mask=attention_mask, **generate_kwargs)
if pixel_values is None:
return self.language_model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
if (
image_metas is not None
and image_metas.get("vision_grid_thw") is not None
and isinstance(image_metas.get("vision_grid_thw"), torch.Tensor)
):
image_metas["vision_grid_thw"] = image_metas["vision_grid_thw"].to(input_ids.device)
inputs = self.prepare_mm_inputs(
input_ids=input_ids,
pixel_values=pixel_values,
image_metas=image_metas,
attention_mask=attention_mask,
)
outputs = self.language_model.generate(
**inputs,
**generate_kwargs,
)
return outputs
def _get_input_dtype(self):
dtype = next(self.vision_model.parameters()).dtype
return dtype
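# Minimal usage sketch (illustrative only; the checkpoint path, preprocessing, and tensor
# names below are assumptions, not part of this file):
#
#   config = KananaVConfig.from_pretrained("<checkpoint-dir>")
#   model = KananaVForConditionalGeneration(config).eval().to("cuda", dtype=torch.bfloat16)
#   out_ids = model.generate(
#       pixel_values=pixel_values,                  # per-image patch tensors from the processor
#       image_metas={"vision_grid_thw": grid_thw},  # (T, H, W) patch grid per image
#       input_ids=input_ids,                        # -1 marks image-token positions
#       attention_mask=attention_mask,
#       max_new_tokens=64,
#   )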