# lily-math-rag/modeling.py
from functools import partial
import logging
import re
from typing import Optional, Tuple, Union, List
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.layers.pos_embed import resample_abs_pos_embed
from timm.models.regnet import RegStage
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModelForCausalLM
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLVisionConfig,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
PatchEmbed,
Qwen2VLPreTrainedModel,
Qwen2VisionTransformerPretrainedModel,
Qwen2VLVisionBlock,
VisionRotaryEmbedding
)
from configuration import KananaVVisualProjectorConfig, KananaVConfig
logger = logging.getLogger("kanana-1.5-v")
def build_pos_embeds(
config: KananaVVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
):
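    """Create a learnable absolute positional-embedding parameter.

    Returns None when config.pos_emb is falsy; otherwise a parameter of shape
    (1, num_input_tokens, vision_hidden_size) initialized with a truncated normal.
    """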
# pos emb
if config.pos_emb:
        # ✨ Fix: fall back to a default size when num_input_tokens is not positive
        if num_input_tokens <= 0:
            num_input_tokens = config.pos_emb_size if hasattr(config, 'pos_emb_size') else 576
pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
else:
pos_emb = None
return pos_emb
def build_eos_tokens(config: KananaVVisualProjectorConfig, output_hidden_size: int):
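    """Create learnable EOS tokens of shape (1, num_eos_tokens, output_hidden_size), or None."""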
    # EOS tokens
num_eos_tokens = config.num_eos_tokens
if num_eos_tokens:
eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size))
nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range)
else:
eos_tokens = None
return eos_tokens
def build_prenorm(config: KananaVVisualProjectorConfig):
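    """Return a LayerNorm over config.encoder_hidden_size when config.prenorm is set, else None."""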
if getattr(config, "prenorm", False):
prenorm = LayerNorm(config.encoder_hidden_size)
else:
prenorm = None
return prenorm
def build_mlp(depth: int, hidden_size: int, output_hidden_size: int):
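    """Build a projection MLP with `depth` linear layers separated by SiLU activations."""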
layers = [nn.Linear(hidden_size, output_hidden_size)]
for _ in range(1, depth):
layers.append(nn.SiLU())
layers.append(nn.Linear(output_hidden_size, output_hidden_size))
return nn.Sequential(*layers)
class PatchMerge(nn.Module):
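    """Space-to-depth patch merging.

    Folds each merge_size x merge_size spatial block into the channel
    dimension, padding odd H/W with replicate padding first.
    """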
def __init__(self, merge_size):
super().__init__()
self.merge_size = merge_size
def forward(self, x, channel_last=False):
if channel_last:
x = rearrange(x, "B H W D -> B D H W")
_, D, H, W = x.shape
        # Pad so that odd spatial dims become divisible by merge_size
        pad_h = (self.merge_size - H % self.merge_size) % self.merge_size
        pad_w = (self.merge_size - W % self.merge_size) % self.merge_size
        if pad_h > 0 or pad_w > 0:
            print(f"🔍 PatchMerge - padding added: H={H}->{H+pad_h}, W={W}->{W+pad_w}")
            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
            H, W = H + pad_h, W + pad_w
merged_x = rearrange(
x, "B D (H h2) (W w2) -> B (D h2 w2) H W", h2=self.merge_size, w2=self.merge_size
)
return merged_x
class DynamicCAbstractor(nn.Module):
"""Dynamic C-Abstractor based on RegBlock"""
def __init__(self, config: KananaVVisualProjectorConfig, num_input_tokens: int):
super().__init__()
self.config = config
        # ✨ Fix: fall back to a default when num_input_tokens is not positive
        if num_input_tokens <= 0:
            num_input_tokens = config.pos_emb_size if hasattr(config, 'pos_emb_size') else 576
        self.num_input_tokens = num_input_tokens
        # ✨ Added: give missing config attributes sensible defaults
        self.merge_size = getattr(config, 'merge_size', 2)
        self.pos_emb_size = getattr(config, 'pos_emb_size', 576)
        # ✨ Optimization: initialize all layers in bfloat16
self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)
if self.pos_emb is not None:
self.pos_emb.data = self.pos_emb.data.to(torch.bfloat16)
self.eos_tokens = build_eos_tokens(config, config.output_hidden_size)
if self.eos_tokens is not None:
self.eos_tokens.data = self.eos_tokens.data.to(torch.bfloat16)
self.prenorm = build_prenorm(config)
if self.prenorm is not None:
self.prenorm = self.prenorm.to(torch.bfloat16)
        # ✨ Fix: build_net sets up self.net and self.readout
self.build_net()
        # ✨ Optimization: convert the net layers to bfloat16
if hasattr(self, 'net'):
if isinstance(self.net, nn.ModuleList):
for layer in self.net:
layer = layer.to(torch.bfloat16)
for module in layer.modules():
if hasattr(module, 'weight'):
module.weight.data = module.weight.data.to(torch.bfloat16)
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data = module.bias.data.to(torch.bfloat16)
else:
                # self.net is a single module (depth == 0: only the PatchMerge sampler)
self.net = self.net.to(torch.bfloat16)
for module in self.net.modules():
if hasattr(module, 'weight'):
module.weight.data = module.weight.data.to(torch.bfloat16)
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data = module.bias.data.to(torch.bfloat16)
        # ✨ Optimization: convert the readout layer to bfloat16
if hasattr(self, 'readout'):
self.readout = self.readout.to(torch.bfloat16)
for module in self.readout.modules():
if hasattr(module, 'weight'):
module.weight.data = module.weight.data.to(torch.bfloat16)
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data = module.bias.data.to(torch.bfloat16)
def build_net(self):
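        """Assemble the abstractor: RegStage -> PatchMerge -> RegStage plus an MLP readout.

        When depth == 0 the net degenerates to the PatchMerge sampler alone.
        """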
encoder_hidden_size = self.config.encoder_hidden_size
hidden_size = self.config.hidden_size
output_hidden_size = self.config.output_hidden_size
depth = self.config.depth
mlp_depth = self.config.mlp_depth
RegBlock = partial(
RegStage,
stride=1,
dilation=1,
act_layer=nn.SiLU,
norm_layer=LayerNorm2d,
)
s1 = RegBlock(
depth,
encoder_hidden_size,
hidden_size,
)
sampler = PatchMerge(merge_size=self.merge_size)
s2 = RegBlock(
depth,
self.merge_size**2 * hidden_size,
hidden_size,
)
if depth:
self.net = nn.ModuleList([s1, sampler, s2])
self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
else:
self.net = sampler
self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size)
def forward(self, flattened_visual_embeds, grid_thw, **unused_kwargs):
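        """Abstract flattened visual tokens image by image.

        `flattened_visual_embeds` holds the concatenated tokens of all images;
        `grid_thw` gives each image's (T, H, W) token grid, which is used to
        split, reshape, and add interpolated positional embeddings before the
        conv abstractor runs.
        """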
n_token_loc = torch.prod(grid_thw, dim=1)
split_visual_embeds = torch.split(flattened_visual_embeds, n_token_loc.tolist())
flattened_visual_embeds = []
for _visual_embeds, _grid_thw in zip(split_visual_embeds, grid_thw):
T, H, W = _grid_thw
assert T == 1, "T must be 1. Video is not supported yet."
reshaped_visual_embeds = rearrange(
_visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
)
# remove temporal dim
reshaped_visual_embeds = reshaped_visual_embeds[:, 0]
if self.prenorm is not None:
reshaped_visual_embeds = self.prenorm(reshaped_visual_embeds)
if self.pos_emb is not None:
# interpolate pos emb and add to visual embeds
print(f"🔍 abstractor - pos_emb 형태: {self.pos_emb.shape}")
print(f"🔍 abstractor - reshaped_visual_embeds 형태: {reshaped_visual_embeds.shape}")
_local_pos_emb = resample_abs_pos_embed(
posemb=self.pos_emb,
old_size=tuple([int(self.pos_emb_size**0.5)] * 2),
new_size=(H, W),
num_prefix_tokens=0,
)
_local_pos_emb = rearrange(
_local_pos_emb,
"1 (h w) d -> 1 h w d",
h=H,
w=W,
)
print(f"🔍 abstractor - _local_pos_emb 형태: {_local_pos_emb.shape}")
# 차원이 맞지 않는 경우 처리
if reshaped_visual_embeds.shape[-1] != _local_pos_emb.shape[-1]:
print(f"🔍 abstractor - 차원 불일치 감지, pos_emb 건너뛰기")
# pos_emb를 건너뛰고 visual_embeds만 사용
else:
reshaped_visual_embeds = reshaped_visual_embeds + _local_pos_emb
reshaped_visual_embeds = self._forward(
reshaped_visual_embeds,
input_size=(H, W),
)
flattened_visual_embeds.append(reshaped_visual_embeds)
reshaped_visual_embeds = torch.cat(flattened_visual_embeds, dim=0)
output = BaseModelOutput(last_hidden_state=reshaped_visual_embeds)
return output
def _forward(self, x, input_size):
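        """Run the RegStage -> PatchMerge -> RegStage stack on a single image grid.

        Takes a (1, H, W, D) feature map, applies the conv abstractor and the
        MLP readout, and returns a flattened (H'*W', output_hidden_size) tensor.
        """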
h, w = input_size
x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w)
        # If the input channel count does not match the net, adapt it below.
        # Look up the expected channel count from the first RegStage block.
try:
if hasattr(self.net[0], 'conv'):
expected_channels = self.net[0].conv.in_channels
elif hasattr(self.net[0], 'blocks') and len(self.net[0].blocks) > 0:
expected_channels = self.net[0].blocks[0].conv1.in_channels
else:
                # Fall back to a default channel count
                expected_channels = 1280
        except Exception:
            expected_channels = 1280
actual_channels = x.shape[1]
if actual_channels != expected_channels:
            # Adjust the channel count with a linear projection
            if not hasattr(self, 'channel_adapter'):
                # Create the channel_adapter in bfloat16 on the fly
                self.channel_adapter = nn.Linear(actual_channels, expected_channels, dtype=torch.bfloat16).to(x.device)
            x = x.permute(0, 2, 3, 1)  # (1, d, h, w) -> (1, h, w, d)
            # Convert the input to bfloat16 (once)
            if x.dtype != torch.bfloat16:
                x = x.to(torch.bfloat16)
            x = self.channel_adapter(x)  # adjust channel count
            x = x.permute(0, 3, 1, 2)  # (1, h, w, d) -> (1, d, h, w)
        # ✨ Optimization: the layers were already initialized in bfloat16
x = self.net[0](x)
x = self.net[1](x)
x = self.net[2](x)
x = rearrange(x, "1 d h w -> (h w) d")
        # ✨ Optimization: the readout was already initialized in bfloat16
x = self.readout(x)
return x
class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
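    """Qwen2-VL vision transformer variant that skips the built-in patch merger.

    Only the patch embedding, rotary position embedding, and transformer blocks
    are constructed; per-block hidden states can be returned for feature tapping.
    """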
config_class = Qwen2VLVisionConfig
_no_split_modules = ["Qwen2VLVisionBlock"]
def __init__(self, config) -> None:
Qwen2VLPreTrainedModel.__init__(self, config)
self.spatial_merge_size = config.spatial_merge_size
self.gradient_checkpointing = False
self.patch_embed = PatchEmbed(
patch_size=config.patch_size,
temporal_patch_size=config.temporal_patch_size,
in_channels=config.in_channels,
embed_dim=config.embed_dim,
)
head_dim = config.embed_dim // config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
assert return_dict, "Only return_dict=True is supported."
encoder_states = () if output_hidden_states else None
hidden_states = self.patch_embed(pixel_values)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = emb.cos(), emb.sin()
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(
blk.__call__,
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
use_reentrant=False,
)
else:
layer_outputs = blk(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
def get_num_tokens(self):
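        """Sentinel: the token count is dynamic, so the projector falls back to its default."""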
return -1
class KananaVPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = KananaVConfig
base_model_prefix = "kanana-1.5-v"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = False
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"language_model.encoder.embed_tokens.weight",
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_no_split_modules = [
"CustomQwen2VLVE",
"DynamicCAbstractor",
"LlamaForCausalLM",
"Parameter",
]
def _init_weights(self, module):
"""Initialize the weights"""
if (
isinstance(module, nn.Conv2d)
or isinstance(module, nn.Embedding)
or isinstance(module, nn.Linear)
):
module.weight.data.normal_(mean=0.0, std=0.02)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Parameter):
raise ValueError()
class KananaVForConditionalGeneration(KananaVPreTrainedModel):
config_class = KananaVConfig
def __init__(self, config: KananaVConfig):
super().__init__(config)
logger.info("Build vision model ...")
self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
logger.info("Build projector ...")
self.abstractor = DynamicCAbstractor(config.projector_config,
num_input_tokens=self.vision_model.get_num_tokens())
logger.info("Build language model ...")
self.language_model = LlamaForCausalLM._from_config(config=config.text_config)
self.post_init()
def forward_vision(self, pixel_values: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None):
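        """Encode images with the vision model and return hidden states from the configured feature layer.

        `pixel_values` may be a list of per-image tensors (multi-image) or a
        single tensor; in the single-tensor path the grid_thw is taken from
        `image_metas["vision_grid_thw"]`.
        """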
        # ✨ Key fix: handle pixel_values given as either a list or a tensor
        if isinstance(pixel_values, list):
            # Multiple images: process each image and collect the results
visual_features_list = []
for i, pv in enumerate(pixel_values):
single_image_metas = {k: v[i] for k, v in image_metas.items()}
                # Fix grid_thw handling
                vision_grid_thw = single_image_metas["vision_grid_thw"]
                if isinstance(vision_grid_thw, (list, tuple)):
                    # Convert the tuple to a list before building the tensor
                    grid_thw = torch.tensor([list(vision_grid_thw)]).to(pv.device)
                else:
                    grid_thw = torch.tensor([vision_grid_thw]).to(pv.device)
                # ✨ Optimization: unnecessary dtype conversions removed
v_outputs = self.vision_model(
pixel_values=pv.unsqueeze(0),
grid_thw=grid_thw,
return_dict=True, output_hidden_states=True
)
layer_index = self.config.projector_config.feature_layer_index
visual_features_list.append(self._get_visual_feature_at(v_outputs.hidden_states, layer_index))
            return visual_features_list  # return as a list
else:
            # Single image - may already be a split feature tensor
            # If grid_thw is a list, use the first element
            grid_thw = image_metas["vision_grid_thw"]
            if isinstance(grid_thw, list):
                grid_thw = grid_thw[0]
            # Convert grid_thw to a tensor
            if not isinstance(grid_thw, torch.Tensor):
                if isinstance(grid_thw, (list, tuple)):
                    # Convert the tuple to a list before building the tensor
                    grid_thw = torch.tensor([list(grid_thw)])
                else:
                    grid_thw = torch.tensor([grid_thw])
            # Move grid_thw to the input device
            if hasattr(pixel_values, 'device'):
                grid_thw = grid_thw.to(pixel_values.device)
            # If pixel_values is a 2-D feature tensor, run it through the vision model
            if len(pixel_values.shape) == 2:
                # Process the 2-D tensor the same way as the multi-image path
                # ✨ Optimization: unnecessary dtype conversions removed
                v_outputs = self.vision_model(
                    pixel_values=pixel_values,
                    grid_thw=grid_thw,
                    return_dict=True, output_hidden_states=True
                )
                layer_index = self.config.projector_config.feature_layer_index
                return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)
            # ✨ Optimization: unnecessary dtype conversions removed
v_outputs = self.vision_model(
pixel_values=pixel_values,
grid_thw=grid_thw,
return_dict=True, output_hidden_states=True
)
layer_index = self.config.projector_config.feature_layer_index
return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)
def forward_projector(self, visual_features: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None):
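        """Project visual features through the DynamicCAbstractor.

        Accepts a list of per-image feature tensors or a single tensor; for a
        single 2-D tensor the grid_thw from `image_metas` is validated against
        the actual token count and corrected (with padding/truncation) if needed.
        """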
print(f"🔍 forward_projector - visual_features 형태: {visual_features.shape if hasattr(visual_features, 'shape') else type(visual_features)}")
# ✨ 핵심 수정: visual_features가 리스트일 경우 처리
if isinstance(visual_features, list):
print(f"🔍 forward_projector - 리스트 형태 처리")
visual_embeds_list = []
for i, vf in enumerate(visual_features):
single_image_metas = {k: v[i] for k, v in image_metas.items()}
                # Fix grid_thw handling
                vision_grid_thw = single_image_metas["vision_grid_thw"]
                if isinstance(vision_grid_thw, (list, tuple)):
                    # Convert the tuple to a list before building the tensor
                    grid_thw = torch.tensor([list(vision_grid_thw)]).to(vf.device)
                else:
                    grid_thw = torch.tensor([vision_grid_thw]).to(vf.device)
                print(f"🔍 forward_projector - processing image {i}")
                print(f"🔍 forward_projector - image {i} feature shape: {vf.shape}")
                print(f"🔍 forward_projector - image {i} grid_thw: {grid_thw}")
                visual_embeds = self.abstractor(vf, grid_thw=grid_thw)["last_hidden_state"]
                print(f"🔍 forward_projector - image {i} visual_embeds shape: {visual_embeds.shape}")
                visual_embeds_list.append(visual_embeds)
            return torch.cat(visual_embeds_list, dim=0)  # concatenate into a single tensor
else:
            # Single image
            print(f"🔍 forward_projector - processing a single tensor")
            # visual_features may already be a processed feature tensor
            if len(visual_features.shape) == 2:
                print(f"🔍 forward_projector - detected an already-processed feature tensor")
                print(f"🔍 forward_projector - feature tensor shape: {visual_features.shape}")
                # If grid_thw is a list, use the first element
                grid_thw = image_metas["vision_grid_thw"]
                if isinstance(grid_thw, list):
                    grid_thw = grid_thw[0]
                # Convert grid_thw to a tensor
                if not isinstance(grid_thw, torch.Tensor):
                    if isinstance(grid_thw, (list, tuple)):
                        # Convert the tuple to a list before building the tensor
                        grid_thw = torch.tensor([list(grid_thw)])
                    else:
                        grid_thw = torch.tensor([grid_thw])
                # Move grid_thw to the input device
                if hasattr(visual_features, 'device'):
                    grid_thw = grid_thw.to(visual_features.device)
                print(f"🔍 forward_projector - grid_thw: {grid_thw}")
                print(f"🔍 forward_projector - token count implied by grid_thw: {torch.prod(grid_thw, dim=1)}")
                print(f"🔍 forward_projector - actual token count in feature tensor: {visual_features.shape[0]}")
                # Correct grid_thw if it does not match the actual token count
actual_tokens = visual_features.shape[0]
if torch.prod(grid_thw, dim=1).item() != actual_tokens:
print(f"🔍 forward_projector - grid_thw 수정 필요")
# 실제 토큰 수에 맞는 grid_thw 계산
# 이미지의 실제 비율을 고려하여 계산
T = 1
# 이미지 메타데이터에서 실제 크기 정보 사용
if 'hw_orig_resolution' in image_metas:
orig_h, orig_w = image_metas['hw_orig_resolution']
if isinstance(orig_h, list):
orig_h = orig_h[0] if isinstance(orig_h[0], (int, float)) else orig_h[0][0]
if isinstance(orig_w, list):
orig_w = orig_w[0] if isinstance(orig_w[0], (int, float)) else orig_w[0][0]
                        # Adjust to the token count while keeping the real aspect ratio
aspect_ratio = orig_w / orig_h
H = int((actual_tokens / aspect_ratio) ** 0.5)
W = int(actual_tokens / H)
                        # Adjust until the token count matches
while H * W != actual_tokens and H > 1 and W > 1:
if H * W > actual_tokens:
H -= 1
W = int(actual_tokens / H)
else:
W += 1
H = int(actual_tokens / W)
else:
                        # Fall back to a near-square grid
H = int(actual_tokens ** 0.5)
W = actual_tokens // H
if actual_tokens % H != 0:
W += 1
corrected_grid_thw = torch.tensor([[T, H, W]])
print(f"🔍 forward_projector - 수정된 grid_thw: {corrected_grid_thw}")
print(f"🔍 forward_projector - 수정된 토큰 수: {torch.prod(corrected_grid_thw, dim=1)}")
# 토큰 수가 맞지 않는 경우 패딩 또는 자르기
expected_tokens = torch.prod(corrected_grid_thw, dim=1).item()
if expected_tokens > actual_tokens:
# 패딩
padding_size = expected_tokens - actual_tokens
padding = torch.zeros(padding_size, visual_features.shape[1], device=visual_features.device)
visual_features = torch.cat([visual_features, padding], dim=0)
print(f"🔍 forward_projector - 패딩 추가: {padding_size}개 토큰")
elif expected_tokens < actual_tokens:
# 자르기
visual_features = visual_features[:expected_tokens]
print(f"🔍 forward_projector - 토큰 자르기: {expected_tokens}개로")
grid_thw = corrected_grid_thw
                # Feed the feature tensor directly to the abstractor
                visual_embeds = self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"]
                print(f"🔍 forward_projector - abstractor output shape: {visual_embeds.shape}")
return visual_embeds
else:
                # Regular vision-model output
grid_thw = image_metas["vision_grid_thw"]
return self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"]
def forward_and_project_vision(self, pixel_values, image_metas: Optional[dict] = None):
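        """Run the vision encoder and then the projector on `pixel_values`."""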
visual_features = self.forward_vision(pixel_values, image_metas=image_metas)
visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
return visual_embeds
def _get_visual_feature_at(self, v_output, layer_index):
if isinstance(layer_index, list):
visual_features = torch.stack(v_output, dim=1)[:, layer_index] # [B, n_scales, L, dim]
else:
visual_features = v_output[layer_index] # [B, L, dim]
return visual_features
def embed_text_tokens(self, input_ids):
"""Embed input_ids into text_embeds, ignoring media tokens (negative values)."""
input_ids = input_ids.clone()
input_ids[input_ids < 0] = 0
text_embeds = self.language_model.get_input_embeddings()(input_ids)
if hasattr(self.language_model, "transformer") and hasattr(
self.language_model.transformer, "word_embeddings_layernorm"
):
text_embeds = self.language_model.transformer.word_embeddings_layernorm(text_embeds)
return text_embeds
def prepare_mm_inputs(
self,
        input_ids: torch.LongTensor,
pixel_values: Optional[list[torch.FloatTensor]] = None,
image_metas: Optional[dict] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""Prepare multimodal inputs from input_ids and pixel_values."""
if pixel_values is not None:
            # If pixel_values is a list, convert each element to the model dtype
if isinstance(pixel_values, list):
pixel_values = [pv.to(self._get_input_dtype()) for pv in pixel_values]
else:
pixel_values = pixel_values.to(self._get_input_dtype())
if attention_mask is None:
attention_mask = input_ids.new_ones(*input_ids.shape)
# Get Text Embeddings
text_embeds = self.embed_text_tokens(input_ids)
flattened_text_embeds = rearrange(text_embeds, "b l d -> (b l) d")
flattened_input_ids = rearrange(input_ids, "b l -> (b l)")
# Get Visual Embeddings
if pixel_values is not None:
print(f"🔍 prepare_mm_inputs - pixel_values 타입: {type(pixel_values)}")
if hasattr(pixel_values, 'shape'):
print(f"🔍 prepare_mm_inputs - pixel_values 형태: {pixel_values.shape}")
if isinstance(pixel_values, list):
print(f"🔍 prepare_mm_inputs - pixel_values 길이: {len(pixel_values)}")
# 다중 이미지 처리: 각 이미지를 개별적으로 처리
if isinstance(pixel_values, list) and len(pixel_values) > 1:
print(f"🔍 prepare_mm_inputs - 다중 이미지 처리 시작")
visual_embeds_list = []
for i, single_pixel_values in enumerate(pixel_values):
print(f"🔍 prepare_mm_inputs - 이미지 {i} 처리 중")
# 각 이미지에 대한 개별 image_metas 생성
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
single_image_metas[key] = value_list[i]
else:
single_image_metas[key] = value_list
                    # Process the individual image
single_visual_embeds = self.forward_and_project_vision(
single_pixel_values.unsqueeze(0), single_image_metas
)
visual_embeds_list.append(single_visual_embeds)
                # Concatenate the visual embeds of all images
                flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
                print(f"🔍 prepare_mm_inputs - multi-image processing done, concatenated embeds shape: {flattened_visual_embeds.shape}")
else:
                # Single-image handling (original path)
                print(f"🔍 prepare_mm_inputs - single-image processing")
                # pixel_values may already be a processed feature tensor (multiple images concatenated)
                if hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 2:
                    print(f"🔍 prepare_mm_inputs - processed feature tensor detected, trying to split into multiple images")
                    # Determine the number of images from image_metas
                    num_images = 0
                    if isinstance(image_metas, dict) and "image_token_thw" in image_metas:
                        num_images = len(image_metas["image_token_thw"])
                    print(f"🔍 prepare_mm_inputs - detected number of images: {num_images}")
                    if num_images > 1:
                        print(f"🔍 prepare_mm_inputs - splitting into {num_images} images")
                        visual_embeds_list = []
                        # Compute the actual token count for each image
                        current_idx = 0
                        for i in range(num_images):
                            print(f"🔍 prepare_mm_inputs - processing image {i}")
                            # Build per-image image_metas
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
single_image_metas[key] = value_list[i]
else:
single_image_metas[key] = value_list
                            # Compute the actual token count from image_token_thw
                            if "image_token_thw" in single_image_metas:
                                token_thw = single_image_metas["image_token_thw"]
                                tokens_per_image = token_thw[0] * token_thw[1] * token_thw[2]
                                print(f"🔍 prepare_mm_inputs - image {i} actual token count: {tokens_per_image}")
                            else:
                                # Fall back to an even split across images
                                tokens_per_image = pixel_values.shape[0] // num_images
                                print(f"🔍 prepare_mm_inputs - image {i} default token count: {tokens_per_image}")
                            # Slice this image's portion out of pixel_values
                            start_idx = current_idx
                            end_idx = current_idx + tokens_per_image
                            single_pixel_values = pixel_values[start_idx:end_idx]
                            print(f"🔍 prepare_mm_inputs - image {i} feature shape: {single_pixel_values.shape}")
                            # Process the individual image
single_visual_embeds = self.forward_and_project_vision(
single_pixel_values, single_image_metas
)
visual_embeds_list.append(single_visual_embeds)
current_idx += tokens_per_image
                        # Concatenate the visual embeds of all images
                        flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
                        print(f"🔍 prepare_mm_inputs - multi-image processing done, concatenated embeds shape: {flattened_visual_embeds.shape}")
                    else:
                        # Single-image handling
                        print(f"🔍 prepare_mm_inputs - treating as a single image")
flattened_visual_embeds = self.forward_and_project_vision(
pixel_values, image_metas
)
                # If pixel_values is a batched tensor, split it into individual images
                elif hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 4 and pixel_values.shape[0] > 1:
                    print(f"🔍 prepare_mm_inputs - batched tensor detected, splitting into individual images")
                    visual_embeds_list = []
                    for i in range(pixel_values.shape[0]):
                        print(f"🔍 prepare_mm_inputs - processing batch image {i}")
                        # Build per-image image_metas
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
single_image_metas[key] = value_list[i]
else:
single_image_metas[key] = value_list
                        # Process the individual image (pixel_values is a tensor in this branch)
                        single_pixel_values = pixel_values[i:i+1]
single_visual_embeds = self.forward_and_project_vision(
single_pixel_values, single_image_metas
)
visual_embeds_list.append(single_visual_embeds)
                    # Concatenate the visual embeds of all images
                    flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
                    print(f"🔍 prepare_mm_inputs - multi-image processing done, concatenated embeds shape: {flattened_visual_embeds.shape}")
                    # Print the embeds shape of each image
                    for i, embeds in enumerate(visual_embeds_list):
                        print(f"🔍 prepare_mm_inputs - image {i} embeds shape: {embeds.shape}")
                else:
                    # Single-image handling
                    # If image_metas holds multi-image info, use only the first image's entries
if isinstance(image_metas, dict):
single_image_metas = {}
for key, value_list in image_metas.items():
if isinstance(value_list, list):
                                single_image_metas[key] = value_list[0]  # use the first image's entry
                            else:
                                single_image_metas[key] = value_list
                        print(f"🔍 prepare_mm_inputs - single-image processing, using the first image's metadata")
                    else:
                        single_image_metas = image_metas
                    # For single-image processing, check whether pixel_values is a list
                    if isinstance(pixel_values, list):
                        single_pixel_values = pixel_values[0]  # use only the first image
else:
single_pixel_values = pixel_values
flattened_visual_embeds = self.forward_and_project_vision(
single_pixel_values, single_image_metas
)
            # Cast visual_embeds to the same dtype as text_embeds
            flattened_visual_embeds = flattened_visual_embeds.to(flattened_text_embeds.dtype)
            # Check the visual-embed count against the number of -1 (media) tokens and adjust
            num_visual_tokens = flattened_visual_embeds.shape[0]
            num_neg_one_tokens = (flattened_input_ids == -1).sum().item()
            print(f"🔍 prepare_mm_inputs - number of visual embeds: {num_visual_tokens}")
            print(f"🔍 prepare_mm_inputs - number of -1 tokens: {num_neg_one_tokens}")
            if num_visual_tokens != num_neg_one_tokens:
                print(f"🔍 prepare_mm_inputs - token count mismatch, adjusting")
                if num_visual_tokens > num_neg_one_tokens:
                    # Too many visual embeds: truncate
                    flattened_visual_embeds = flattened_visual_embeds[:num_neg_one_tokens]
                    print(f"🔍 prepare_mm_inputs - truncating visual embeds: {num_visual_tokens} -> {num_neg_one_tokens}")
                else:
                    # Too few visual embeds: repeat them to fill the slots
                    repeat_times = num_neg_one_tokens // num_visual_tokens
                    remainder = num_neg_one_tokens % num_visual_tokens
                    if repeat_times > 0:
                        # Repeat the visual embeds
                        repeated_embeds = flattened_visual_embeds.repeat(repeat_times, 1)
                        if remainder > 0:
                            # Append the remaining portion
                            remainder_embeds = flattened_visual_embeds[:remainder]
                            repeated_embeds = torch.cat([repeated_embeds, remainder_embeds], dim=0)
                        flattened_visual_embeds = repeated_embeds
                    else:
                        # If there are far too few visual embeds, repeat the first token
                        first_token = flattened_visual_embeds[0:1].repeat(num_neg_one_tokens, 1)
                        flattened_visual_embeds = first_token
                    print(f"🔍 prepare_mm_inputs - repeating visual embeds: {num_visual_tokens} -> {num_neg_one_tokens}")
flattened_text_embeds[flattened_input_ids == -1] = flattened_visual_embeds
input_embeds = rearrange(
flattened_text_embeds, "(b l) d -> b l d", b=input_ids.shape[0]
)
return_inputs = {
"inputs_embeds": input_embeds,
"attention_mask": attention_mask,
}
return return_inputs
def forward(
self,
        pixel_values: list[torch.FloatTensor],
        image_metas: dict,
        input_ids: torch.LongTensor,
seq_length: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = self.prepare_mm_inputs(
input_ids=input_ids,
pixel_values=pixel_values,
image_metas=image_metas,
attention_mask=attention_mask,
)
outputs = self.language_model(
**inputs,
labels=labels,
position_ids=None,
return_dict=return_dict,
output_attentions=self.config.output_attentions,
)
return outputs
@torch.no_grad()
def generate(
self,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_metas: Optional[dict] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
seq_length: Optional[torch.LongTensor] = None,
**generate_kwargs,
) -> torch.LongTensor:
"""
Overrides `generate` function to be able to use the model as a conditional generator.
Args:
pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
Input images to be processed.
input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices
        Returns:
            torch.LongTensor: Generated token ids (or a `GenerateOutput` when
            requested via `generate_kwargs`).
        """
if input_ids is None:
return self.language_model.generate(attention_mask=attention_mask, **generate_kwargs)
if pixel_values is None:
return self.language_model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
if (
image_metas is not None
and image_metas.get("vision_grid_thw") is not None
and isinstance(image_metas.get("vision_grid_thw"), torch.Tensor)
):
image_metas["vision_grid_thw"] = image_metas["vision_grid_thw"].to(input_ids.device)
inputs = self.prepare_mm_inputs(
input_ids=input_ids,
pixel_values=pixel_values,
image_metas=image_metas,
attention_mask=attention_mask,
)
outputs = self.language_model.generate(
**inputs,
**generate_kwargs,
)
return outputs
def _get_input_dtype(self):
dtype = next(self.vision_model.parameters()).dtype
return dtype
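# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model code).
# The checkpoint path and the exact preprocessing that produces `pixel_values`,
# `image_metas["vision_grid_thw"]`, and the -1 media placeholders in
# `input_ids` are assumptions; they come from the accompanying processor, not
# from this file.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     from configuration import KananaVConfig
#
#     config = KananaVConfig.from_pretrained("path/to/checkpoint")  # hypothetical path
#     model = KananaVForConditionalGeneration.from_pretrained(
#         "path/to/checkpoint", config=config, torch_dtype=torch.bfloat16
#     ).eval()
#     # `input_ids` must contain -1 at every image-token position; `pixel_values`
#     # and `vision_grid_thw` are produced by the image preprocessor.
#     generated = model.generate(
#         pixel_values=pixel_values,
#         image_metas={"vision_grid_thw": vision_grid_thw},
#         input_ids=input_ids,
#         max_new_tokens=64,
#     )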