# HyperCLOVAX-SEED-Omni-8B / patch_vuvlm.py
# Uploaded by PenPaperKeyCode — commit "Init" (3169f6c)
import contextlib
import gc
import inspect
import json
import os
import time
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from liger_kernel.transformers import (
LigerCrossEntropyLoss,
LigerFusedLinearCrossEntropyLoss,
)
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
from hcxvlm.models.ulysses.sp_utils import (
gather_outputs_and_unpad,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_rank,
get_ulysses_sequence_parallel_world_size,
slice_input_tensor,
)
from .configuration_vlm import HCXVisionConfig
from .modeling_vlm import HCXVisionForCausalLM, get_rank
# Placeholder special tokens that mark where multimodal content is spliced
# into the token stream. The "*_unit_0_id" entries are presumably the first
# tokens of the discrete vision/audio codebook ranges in the tokenizer
# vocabulary — TODO confirm against the tokenizer's added-tokens list.
extra_special_tokens = {
    "image_token": "<|IMAGE_PAD|>",
    "discrete_image_token": "<|DISCRETE_IMAGE_PAD|>",
    "discrete_image_unit_0_id": "<|vision00000|>",
    "video_token": "<|VIDEO_PAD|>",
    "video_audio_token": "<|VIDEO_AUDIO_PAD|>",
    "audio_token": "<|AUDIO_PAD|>",
    "discrete_audio_token": "<|DISCRETE_AUDIO_PAD|>",
    "discrete_audio_unit_0_id": "<|audio0000|>",
}
def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""):
    """Load ``state_dict`` into ``model_to_load``, recursing module-by-module.

    Mirrors the HuggingFace sharded-loading helper: legacy TF-style
    ``gamma``/``beta`` parameter names are renamed to ``weight``/``bias``, and
    under DeepSpeed ZeRO-3 the partitioned parameters are gathered on rank 0
    before the copy is performed.

    :param model_to_load: target ``nn.Module``.
    :param state_dict: checkpoint state dict (possibly a single shard).
    :param strict: forwarded to ``Module._load_from_state_dict``.
    :param start_prefix: key prefix at which the recursive walk starts.
    :return: list of error-message strings collected during loading.
    """
    # Rename legacy parameter names (gamma/beta -> weight/bias).
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
    # Shallow-copy so the recursion below never mutates the caller's dict;
    # preserve ``_metadata``, which _load_from_state_dict consults per-prefix.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata
    error_msgs = []

    def load(module: nn.Module, state_dict, prefix=""):
        # Per-module metadata is keyed by the prefix without its trailing dot.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs)
        # Only attempt a load when this shard actually holds keys for the prefix.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # Under ZeRO-3 parameters are partitioned across ranks: gather
                # the ones this shard provides, and let rank 0 alone perform
                # the copy (GatheredParameters broadcasts the result back).
                named_parameters = dict(
                    module.named_parameters(prefix=prefix[:-1], recurse=False)
                )
                params_to_gather = [
                    named_parameters[k]
                    for k in state_dict.keys()
                    if k in named_parameters
                ]
                if len(params_to_gather) > 0:
                    with deepspeed.zero.GatheredParameters(
                        params_to_gather, modifier_rank=0
                    ):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)
        # Recurse into children regardless, so deeper prefixes are covered.
        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    del state_dict  # release the shard copy as early as possible
    return error_msgs
def load_sharded_checkpoint(
    model,
    folder,
    pick_prefix="",
    replace_prefix_list=None,
    replace_prefix_dict=None,
    print_info=True,
):
    """Load a (possibly sharded) checkpoint from ``folder`` into ``model``.

    Supports ``pytorch_model*.bin`` and ``*.safetensors`` shards, with or
    without a ``*.index.json`` shard index. Checkpoint keys can be filtered
    and renamed before loading.

    :param model: target ``nn.Module``.
    :param folder: checkpoint directory; ``None`` is a no-op returning ``{}``.
    :param pick_prefix: keep only keys starting with this prefix (stripped off).
    :param replace_prefix_list: prefixes to strip from matching keys.
    :param replace_prefix_dict: ``prefix -> replacement`` mapping for keys.
    :param print_info: print a missing/unexpected-key report on rank 0.
    :return: dict with ``missing_keys`` and ``unexpected_keys`` lists.
    """
    if folder is None:
        return {}
    # BUGFIX: default to fresh containers here instead of mutable defaults in
    # the signature (shared across calls).
    if replace_prefix_list is None:
        replace_prefix_list = []
    if replace_prefix_dict is None:
        replace_prefix_dict = {}

    def _remap(key):
        # Apply the same renaming to both index keys and tensor keys.
        for rep_prefix in replace_prefix_list:
            if key.startswith(rep_prefix):
                key = key[len(rep_prefix):]
        for rep_prefix in replace_prefix_dict:
            if key.startswith(rep_prefix):
                key = key.replace(rep_prefix, replace_prefix_dict[rep_prefix])
        return key

    files = os.listdir(folder)
    pytorch_bin_files = [
        file
        for file in files
        if file.startswith("pytorch_model") and file.endswith(".bin")
    ]
    safetensor_files = [file for file in files if file.endswith(".safetensors")]
    shard_index_file = [file for file in files if file.endswith(".index.json")]
    index_present = len(shard_index_file) > 0
    index_file = os.path.join(folder, shard_index_file[0]) if index_present else None
    is_safetensor = len(safetensor_files) > 0
    model_keys = model.state_dict().keys()
    if is_safetensor:
        from safetensors.torch import load_file

        load_function = load_file
        shard_files = safetensor_files
    else:
        load_function = partial(torch.load, map_location="cpu")
        shard_files = pytorch_bin_files
    # BUGFIX: initialize before the loop so an empty folder (or missing index)
    # cannot raise NameError below.
    loaded_keys = []
    if index_present:
        with open(index_file, "r", encoding="utf-8") as f:
            index = json.load(f)
        loaded_keys = index["weight_map"].keys()
        if pick_prefix:
            loaded_keys = [
                k[len(pick_prefix):] for k in loaded_keys if k.startswith(pick_prefix)
            ]
        loaded_keys = [_remap(k) for k in loaded_keys]
    for i, shard_file in enumerate(shard_files):
        state_dict = load_function(os.path.join(folder, shard_file))
        if pick_prefix:
            state_dict = {
                k[len(pick_prefix):]: v
                for k, v in state_dict.items()
                if k.startswith(pick_prefix)
            }
        state_dict = {_remap(k): v for k, v in state_dict.items()}
        if is_deepspeed_zero3_enabled():
            rank = torch.distributed.get_rank()
            print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}")
            load_state_dict_into_model(model, state_dict, strict=False)
        elif is_fsdp_enabled():
            # Under FSDP only the local rank-0 process materializes the load.
            if is_local_dist_rank_0():
                model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=False)
        if not index_present:
            # BUGFIX: accumulate keys over *all* shards. The original kept only
            # the last shard's keys, inflating missing_keys for multi-shard
            # checkpoints that ship without an index file.
            loaded_keys = list(loaded_keys) + list(state_dict.keys())
        del state_dict
        gc.collect()
    # Set membership makes the report O(n) instead of O(n^2).
    loaded_key_set = set(loaded_keys)
    missing_keys = [key for key in model_keys if key not in loaded_key_set]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    # Check print_info first so the rank lookup is skipped when silenced.
    if print_info and get_rank() == 0:
        print(f"[info] missing_keys: {missing_keys}")
        print(f"[info] unexpected_keys: {unexpected_keys}")
    return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
class HCXVisionForCausalLM_VU(HCXVisionForCausalLM):
    """HCX vision-language model with VU training extensions.

    Adds Liger-kernel fused losses, alternative per-sample loss reductions and
    Ulysses sequence-parallel support on top of ``HCXVisionForCausalLM``.
    """

    def __init__(self, config, **kwargs):
        # Pop runtime-only options *before* super().__init__ so they never
        # reach the base constructor / config machinery.
        self.use_liger = kwargs.pop("use_liger", True)
        self.use_fused_ce = kwargs.pop("use_fused_ce", True)
        self.use_meansum_loss = kwargs.pop("use_meansum_loss", True)
        self.use_turnmeansum_loss = kwargs.pop("use_turnmeansum_loss", False)
        self.use_sqrtsum_loss = kwargs.pop("use_sqrtsum_loss", False)
        use_sum_loss = bool(kwargs.pop("use_sum_loss", False))
        self.sequence_parallel_size = kwargs.pop("sequence_parallel_size", 1)
        self.sp_manager = kwargs.pop("sp_manager", None)
        self.train_video = kwargs.pop("train_video", False)
        # At most one of the per-sample reductions may be active at a time.
        assert (
            int(self.use_meansum_loss)
            + int(self.use_turnmeansum_loss)
            + int(self.use_sqrtsum_loss)
        ) <= 1, "use_meansum_loss, use_turnmeansum_loss, use_sqrtsum_loss 중 둘 이상을 동시에 True로 설정할 수 없습니다."
        # The per-sample reductions need per-token losses ("none"); otherwise
        # fall back to a plain sum/mean reduction.
        if self.use_meansum_loss or self.use_turnmeansum_loss or self.use_sqrtsum_loss:
            self.reduction = "none"
        elif use_sum_loss:
            self.reduction = "sum"
        else:
            self.reduction = "mean"
        super().__init__(config, **kwargs)
        if config.text_config.model_type == "hyperclovax" and self.use_liger:
            self.language_model._get_apply_liger_kernel_converter()(
                model=self.language_model
            )
            print("[info] use liger kernel for hcx 24b")
        if config.freeze_encoder:
            for param in self.vision_model.parameters():
                param.requires_grad = False
            # BUGFIX: assert that *no* parameter is trainable. The original
            # `all(...) == False` only asserted that not every parameter was
            # trainable, and raised spuriously for a parameter-less encoder.
            assert not any(
                param.requires_grad for param in self.vision_model.parameters()
            )
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
text_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
discrete_vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
discrete_audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
q_former_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
without_llm: bool = False,
*model_args,
**kwargs,
):
"""
:param pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for LLM(text_model_name_or_path) e.g. /path/to/model/
:param vision_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VisionModule(HyperClova-VisionModule) e.g. /path/to/vision/module/
:param q_former_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VLM e.g. /path/to/vlm/checkpoint/
:param without_llm: Bool: False: init/load llm weight from pre-trained True: init/load llm weight from dummy file
:param model_args:
:param kwargs:
:return:
"""
assert pretrained_model_name_or_path is not None or (
text_model_name_or_path is not None
and vision_model_name_or_path is not None
)
cache_dirpath = kwargs.pop("cache_dirpath", None)
if cache_dirpath is None:
cache_dirpath = "~/.cache"
runtime_only_keys = {
"use_liger",
"use_fused_ce",
"use_meansum_loss",
"use_turnmeansum_loss",
"use_sqrtsum_loss",
"use_sum_loss",
"sequence_parallel_size",
"sp_manager",
"train_video",
}
runtime_kwargs = {}
for k in list(runtime_only_keys):
if k in kwargs:
runtime_kwargs[k] = kwargs.pop(k)
kwargs["vision_model_name_or_path"] = vision_model_name_or_path
kwargs["discrete_vision_model_name_or_path"] = (
discrete_vision_model_name_or_path
)
kwargs["audio_model_name_or_path"] = audio_model_name_or_path
kwargs["discrete_audio_model_name_or_path"] = discrete_audio_model_name_or_path
save_only_vision = (
kwargs.pop("save_only_vision") if "save_only_vision" in kwargs else False
)
save_only_qformer = (
kwargs.pop("save_only_qformer") if "save_only_qformer" in kwargs else False
)
save_shard_size = (
kwargs.pop("save_shard_size") if "save_shard_size" in kwargs else "5GB"
)
def _purge_runtime_from_config(cfg):
for rk in runtime_only_keys:
if hasattr(cfg, rk):
delattr(cfg, rk)
template_path = "hcxvlm/dataset/chat_template.jinja"
with open(template_path, "r", encoding="utf-8") as f:
chat_template_str = f.read()
if without_llm:
assert pretrained_model_name_or_path is not None and os.path.exists(
pretrained_model_name_or_path
)
dummy_config = HCXVisionConfig.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
*model_args,
**kwargs,
)
_purge_runtime_from_config(dummy_config)
dummy_config.text_config.num_hidden_layers = 0
dummy_config.text_config.num_attention_heads = 1
if isinstance(
dummy_config.vision_model_name_or_path, str
) and os.path.exists(dummy_config.vision_model_name_or_path):
vision_model_name_or_path = dummy_config.vision_model_name_or_path
assert isinstance(vision_model_name_or_path, str) and os.path.exists(
vision_model_name_or_path
), f"# [error] invalid vision_model_name_or_path: {vision_model_name_or_path}"
dummy_config.vision_model_name_or_path = vision_model_name_or_path
dummy_config.vision_config._name_or_path = vision_model_name_or_path
dummy_config.vision_config.vison_pretrained_name_or_path = (
vision_model_name_or_path
)
model = super().from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
without_llm=True,
config=dummy_config,
*model_args,
**{**kwargs, **runtime_kwargs},
)
model.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path
)
model.tokenizer.chat_template = chat_template_str
model.transformer = None
else:
if pretrained_model_name_or_path is not None and (
audio_model_name_or_path is not None
or discrete_audio_model_name_or_path is not None
or discrete_vision_model_name_or_path is not None
):
assert (
audio_model_name_or_path is not None
and discrete_audio_model_name_or_path is not None
and discrete_vision_model_name_or_path is not None
)
print(f"[DEBUG] image stage2 끝난 시점에서 audio 를 stage3 로 붙일때.")
pt_config = HCXVisionConfig.from_pretrained(
pretrained_model_name_or_path
)
_purge_runtime_from_config(pt_config)
config_dict = pt_config.to_dict()
config_dict["audio_model_name_or_path"] = audio_model_name_or_path
config_dict["discrete_audio_model_name_or_path"] = (
discrete_audio_model_name_or_path
)
config_dict["discrete_vision_model_name_or_path"] = (
discrete_vision_model_name_or_path
)
config = HCXVisionConfig.from_dict(config_dict)
print(f"config: {config}")
model = super().from_pretrained(
pretrained_model_name_or_path,
without_llm=False,
config=config,
_fast_init=False,
*model_args,
**kwargs,
)
model.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path
)
model.tokenizer.chat_template = chat_template_str
elif isinstance(q_former_model_name_or_path, str):
config = HCXVisionConfig.from_dict(
{"text_model_name_or_path": text_model_name_or_path, **kwargs}
)
_purge_runtime_from_config(config)
model = super().from_pretrained(
q_former_model_name_or_path,
without_llm=False,
config=config,
_fast_init=False,
*model_args,
**{**kwargs, **runtime_kwargs},
)
model.tokenizer = AutoTokenizer.from_pretrained(
q_former_model_name_or_path
)
model.tokenizer.chat_template = chat_template_str
elif pretrained_model_name_or_path is not None:
config = HCXVisionConfig.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
_purge_runtime_from_config(config)
model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
**runtime_kwargs,
)
model.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path
)
model.tokenizer.chat_template = chat_template_str
else:
config = HCXVisionConfig.from_dict(
{"text_model_name_or_path": text_model_name_or_path, **kwargs}
)
_purge_runtime_from_config(config)
model = HCXVisionForCausalLM_VU(
config, *model_args, **{**kwargs, **runtime_kwargs}
)
model.tokenizer = AutoTokenizer.from_pretrained(text_model_name_or_path)
model.tokenizer.chat_template = chat_template_str
model.mm_projector.apply(model._init_weights)
img_start_id = model.tokenizer.encode(
extra_special_tokens["image_token"], add_special_tokens=False
)
assert (
len(img_start_id) == 1
), f'{extra_special_tokens["image_token"]} was not encoded into a single special token. Encoding result: {img_start_id}'
model.config.img_start_id = img_start_id[0]
model.config.image_token_id = img_start_id[0]
video_start_id = model.tokenizer.encode(
extra_special_tokens["video_token"], add_special_tokens=False
)
assert (
len(video_start_id) == 1
), f"video_token was not encoded into a single special token. Encoding result: {video_start_id}"
model.config.video_start_id = video_start_id[0]
model.config.video_token_id = video_start_id[0]
video_audio_start_id = model.tokenizer.encode(
extra_special_tokens["video_audio_token"], add_special_tokens=False
)
assert (
len(video_audio_start_id) == 1
), f"video_audio_token was not encoded into a single special token. Encoding result: {video_audio_start_id}"
model.config.video_audio_start_id = video_audio_start_id[0]
model.config.video_audio_token_id = video_audio_start_id[0]
if (
audio_model_name_or_path is not None
or discrete_audio_model_name_or_path is not None
or discrete_vision_model_name_or_path is not None
):
audio_start_id = model.tokenizer.encode(
extra_special_tokens["audio_token"], add_special_tokens=False
)
assert (
len(audio_start_id) == 1
), f"audio_token was not encoded into a single special token. Encoding result: {audio_start_id}"
model.config.audio_start_id = audio_start_id[0]
model.config.audio_token_id = audio_start_id[0]
discrete_audio_start_id = model.tokenizer.encode(
extra_special_tokens["discrete_audio_token"], add_special_tokens=False
)
assert (
len(discrete_audio_start_id) == 1
), f"discrete_audio_token was not encoded into a single special token. Encoding result: {discrete_audio_start_id}"
model.config.discrete_audio_start_id = discrete_audio_start_id[0]
model.config.discrete_audio_token_id = discrete_audio_start_id[0]
discrete_audio_unit_0_id = model.tokenizer.encode(
extra_special_tokens["discrete_audio_unit_0_id"],
add_special_tokens=False,
)
assert (
len(discrete_audio_unit_0_id) == 1
), f'{extra_special_tokens["discrete_audio_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_audio_unit_0_id}'
model.config.discrete_audio_unit_0_id = discrete_audio_unit_0_id[0]
discrete_image_start_id = model.tokenizer.encode(
extra_special_tokens["discrete_image_token"], add_special_tokens=False
)
assert (
len(discrete_image_start_id) == 1
), f'{extra_special_tokens["discrete_image_token"]} was not encoded into a single special token. Encoding result: {discrete_image_start_id}'
model.config.discrete_image_start_id = discrete_image_start_id[0]
model.config.discrete_image_token_id = discrete_image_start_id[0]
discrete_image_unit_0_id = model.tokenizer.encode(
extra_special_tokens["discrete_image_unit_0_id"],
add_special_tokens=False,
)
assert (
len(discrete_image_unit_0_id) == 1
), f'{extra_special_tokens["discrete_image_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_image_unit_0_id}'
model.config.discrete_image_unit_0_id = discrete_image_unit_0_id[0]
model.save_only_vision = save_only_vision
model.save_only_qformer = save_only_qformer
model.save_shard_size = save_shard_size
if pretrained_model_name_or_path is None or (
pretrained_model_name_or_path is not None
and audio_model_name_or_path is not None
):
vision_model_name_or_path = kwargs.get("vision_model_name_or_path", None)
if vision_model_name_or_path is not None:
load_sharded_checkpoint(model.vision_model, vision_model_name_or_path)
if get_rank() == 0:
print("[info] vision model loading complete")
discrete_vision_model_name_or_path = kwargs.get(
"discrete_vision_model_name_or_path", None
)
if discrete_vision_model_name_or_path is not None:
model.discrete_vision_model.load_state_dict(
torch.load(
discrete_vision_model_name_or_path,
map_location=model.device,
weights_only=False,
)["model"]["sd"],
strict=True,
)
if get_rank() == 0:
print("[info] discrete vision model loading complete")
audio_model_name_or_path = kwargs.get("audio_model_name_or_path", None)
if audio_model_name_or_path is not None:
load_sharded_checkpoint(model.audio_model, audio_model_name_or_path)
if get_rank() == 0:
print("[info] audio model loading complete")
discrete_audio_model_name_or_path = kwargs.get(
"discrete_audio_model_name_or_path", None
)
if discrete_audio_model_name_or_path is not None:
model.discrete_audio_model.load_state_dict(
torch.load(
discrete_audio_model_name_or_path,
map_location=model.device,
weights_only=False,
),
strict=True,
)
if get_rank() == 0:
print("[info] discrete audio model loading complete")
if text_model_name_or_path is not None:
load_sharded_checkpoint(model.language_model, text_model_name_or_path)
if get_rank() == 0:
print("[info] text model loading complete")
if isinstance(q_former_model_name_or_path, str):
assert Path(
q_former_model_name_or_path
).exists(), f"# [error] given q_former_name_or_path not exist: {q_former_model_name_or_path}"
load_result = load_sharded_checkpoint(
model,
q_former_model_name_or_path,
replace_prefix_dict={
"vision_model.image_encoder.model.vision_tower": "vision_model",
"model": "language_model.model",
"lm_head.weight": "language_model.lm_head.weight",
},
print_info=False,
)
if get_rank() == 0:
missing_keys_summary = dict()
for key in load_result["missing_keys"]:
if key.split(".")[0] in missing_keys_summary:
missing_keys_summary[key.split(".")[0]] += 1
else:
missing_keys_summary[key.split(".")[0]] = 1
print(f"[info] missing_keys summary : {missing_keys_summary}")
print("[info] q_former model loading complete")
config: HCXVisionConfig = model.config
if config.model_type != "vlm":
model.config.model_type = "vlm"
return model
def _pad_sequence_for_sp(
self,
inputs_embeds: torch.Tensor,
labels: Optional[torch.Tensor],
sp_world_size: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Ensure sequence length is divisible by the SP group size by padding on the sequence dimension.
Returns the possibly padded (inputs_embeds, labels).
"""
batch_size, seqlen, hidden_size = inputs_embeds.shape
remainder = seqlen % sp_world_size
if remainder != 0:
print(
f"[info] Padding sequence dimension to make it divisible by {sp_world_size}"
)
pad_len = sp_world_size - remainder
pad_embeds = torch.zeros(
(batch_size, pad_len, hidden_size),
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1)
if labels is not None:
ignore_index = getattr(self.config, "ignore_index", -100)
pad_labels = torch.full(
(batch_size, pad_len),
fill_value=ignore_index,
dtype=labels.dtype,
device=labels.device,
)
labels = torch.cat([labels, pad_labels], dim=1)
return inputs_embeds, labels
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
discrete_pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
image_sizes: Optional[List[List[List[int]]]] = None,
mm_query_lengths: Optional[List[List[int]]] = None,
non_mm_query_lengths: Optional[List[List[int]]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
video_audio_values: Optional[torch.FloatTensor] = None,
video_audio_masks: Optional[torch.FloatTensor] = None,
audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None,
audio_masks: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer.
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data.
:param pixel_values: List of List of 4D tensor (torch.float32)
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
:param past_key_values: None
:param inputs_embeds: None
:param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1)+ num visual tokens)] visual token 들은 모두 IGNORE_INDEX
:param use_cache: None
:param output_attentions: Optional[bool] : get attention weights of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함)
:param output_hidden_states: Optional[bool] : get hidden states of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함)
:param return_dict: Optional[bool] : True - return dict, Fasle - return tensor
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
In cases where a sample contains no images, a single dummy image is included.
:param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
In cases where a sample does not contain any images, an empty list is included.
:param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample.
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
:first_last_frames_slows: A List of List that contains the only first and last frames slow mode for each sample in a batch.
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video.
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder.
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:return:
"""
if self.sp_manager is not None and self.train_video:
sp_group = get_ulysses_sequence_parallel_group()
if sp_group is not None:
sp_rank = get_ulysses_sequence_parallel_rank(sp_group)
sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group)
if sp_rank == 0:
payload = {
"input_ids": input_ids,
"labels": labels,
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
"pixel_values_videos": pixel_values_videos,
"video_grid_thw": video_grid_thw,
"video_audio_values": video_audio_values,
"video_audio_masks": video_audio_masks,
}
else:
payload = {
"input_ids": None,
"labels": None,
"pixel_values": None,
"image_grid_thw": None,
"pixel_values_videos": None,
"video_grid_thw": None,
"video_audio_values": None,
"video_audio_masks": None,
}
obj_list = [payload]
src_global_rank = dist.get_global_rank(sp_group, 0)
dist.broadcast_object_list(
obj_list, src=src_global_rank, group=sp_group
)
payload = obj_list[0]
if sp_rank != 0:
device = input_ids.device
input_ids = payload["input_ids"]
if isinstance(input_ids, torch.Tensor):
input_ids = input_ids.to(device)
labels = payload["labels"]
if isinstance(labels, torch.Tensor):
labels = labels.to(device)
image_grid_thw = payload["image_grid_thw"]
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.to(device)
pixel_values_videos = payload["pixel_values_videos"]
if isinstance(pixel_values_videos, torch.Tensor):
pixel_values_videos = pixel_values_videos.to(device)
video_grid_thw = payload["video_grid_thw"]
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.to(device)
video_audio_values = payload["video_audio_values"]
if isinstance(video_audio_values, torch.Tensor):
video_audio_values = video_audio_values.to(device)
video_audio_masks = payload["video_audio_masks"]
if isinstance(video_audio_masks, torch.Tensor):
video_audio_masks = video_audio_masks.to(device)
pixel_values = payload["pixel_values"]
if isinstance(pixel_values, torch.Tensor):
pixel_values = pixel_values.to(device)
attention_mask = None
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.vision_config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.vision_config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None and past_key_values is None:
inputs_embeds, labels = self.model.extract_inputs_embeds(
input_ids=input_ids,
labels=labels,
pixel_values=pixel_values,
discrete_pixel_values=discrete_pixel_values,
past_key_values=past_key_values,
image_sizes=image_sizes,
mm_query_lengths=mm_query_lengths,
non_mm_query_lengths=non_mm_query_lengths,
img_start_ids_list=img_start_ids_list,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
first_last_frames_slows=first_last_frames_slows,
is_videos=is_videos,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
video_audio_values=video_audio_values,
video_audio_masks=video_audio_masks,
audio_values=audio_values,
discrete_audio_values=discrete_audio_values,
discrete_audio_value_num_per_sample=discrete_audio_value_num_per_sample,
audio_masks=audio_masks,
)
if labels is not None and labels.size(1) > 32768:
print(
f"[RANK {rank} debug] ❌ labels.size(1) > 32768. labels.size(): {labels.size()}"
)
if inputs_embeds is not None:
input_ids = None
import os
rank = int(os.environ.get("RANK", -1))
if inputs_embeds is not None:
expected_hidden_size = self.config.text_config.hidden_size
if inputs_embeds.shape[-1] != expected_hidden_size:
print(f"[RANK {rank}] ❌ inputs_embeds dimension mismatch!")
print(
f" Expected: {expected_hidden_size}, Got: {inputs_embeds.shape[-1]}"
)
if labels is not None:
vocab_size = self.get_input_embeddings().num_embeddings
valid_labels = labels[labels != -100]
if len(valid_labels) > 0:
if (valid_labels >= vocab_size).any() or (valid_labels < 0).any():
print(f"[RANK {rank}] ❌ CRITICAL: labels out of vocab range!")
print(
f" labels min/max: {valid_labels.min().item()}/{valid_labels.max().item()}"
)
print(f" vocab_size: {vocab_size}")
print(
f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}"
)
if attention_mask is not None and inputs_embeds is not None:
if attention_mask.shape[1] != inputs_embeds.shape[1]:
print(f"[RANK {rank}] ❌ attention_mask shape mismatch!")
print(
f" attention_mask: {attention_mask.shape}, inputs_embeds: {inputs_embeds.shape}"
)
if position_ids is not None:
max_position = position_ids.max().item()
if hasattr(self.language_model.config, "max_position_embeddings"):
max_allowed = self.language_model.config.max_position_embeddings
if max_position >= max_allowed:
print(f"[RANK {rank}] ❌ position_ids out of range!")
print(f" max_position: {max_position}, max_allowed: {max_allowed}")
if self.sp_manager is not None:
batch_size, seqlen, hidden_size = inputs_embeds.shape
sp_group = get_ulysses_sequence_parallel_group()
sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group)
inputs_embeds, labels = self._pad_sequence_for_sp(
inputs_embeds, labels, sp_world_size
)
if position_ids is None:
position_ids = torch.arange(
seqlen, device=inputs_embeds.device, dtype=torch.long
)
position_ids = (
position_ids.unsqueeze(0).expand(batch_size, -1).contiguous()
)
inputs_embeds = slice_input_tensor(
inputs_embeds, 1, padding=False, group=sp_group
)
labels = slice_input_tensor(labels, 1, padding=False, group=sp_group)
use_cache = False
outputs = self.language_model.base_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = hidden_states * self.config.text_config.logits_scaling
loss = None
logits = None
if labels is not None:
if self.use_liger and self.use_fused_ce:
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
hidden_states = hidden_states[..., :-1, :].contiguous()
hidden_states = hidden_states.view(
-1, self.language_model.config.hidden_size
).to(self.language_model.lm_head.weight.dtype)
import os
rank = int(os.environ.get("RANK", -1))
vocab_size = self.language_model.lm_head.weight.shape[0]
valid_labels = shift_labels[shift_labels != -100]
if len(valid_labels) > 0 and (
(valid_labels >= vocab_size).any() or (valid_labels < 0).any()
):
print(
f"[RANK {rank}] ❌ CRITICAL: shift_labels out of vocab range!"
)
print(
f" min/max: {valid_labels.min().item()}/{valid_labels.max().item()}, vocab: {vocab_size}"
)
print(
f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}"
)
lce = LigerFusedLinearCrossEntropyLoss(reduction=self.reduction)
try:
loss = lce(
self.language_model.lm_head.weight, hidden_states, shift_labels
)
except RuntimeError as e:
print(
f"[RANK {rank}] ❌ FATAL: LigerFusedLinearCrossEntropyLoss failed!"
)
print(f" Error: {e}")
print(
f" hidden_states: shape={hidden_states.shape}, dtype={hidden_states.dtype}"
)
print(
f" shift_labels: shape={shift_labels.shape}, unique_values={torch.unique(shift_labels).tolist()[:20]}"
)
print(
f" lm_head.weight: shape={self.language_model.lm_head.weight.shape}"
)
raise
elif self.use_liger:
logits = self.language_model.lm_head(hidden_states)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = LigerCrossEntropyLoss(reduction=self.reduction)
shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
else:
logits = self.language_model.lm_head(hidden_states)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction=self.reduction)
shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if self.sp_manager is not None:
loss = gather_outputs_and_unpad(
loss, gather_dim=0, unpad_dim=0, padding_size=0, group=sp_group
)
if self.use_meansum_loss:
loss = loss.view(labels.size(0), -1).mean(dim=1).sum()
elif self.use_sqrtsum_loss:
per_token = loss.view(labels.size(0), -1)
per_sample_mean = per_token.mean(dim=1)
with torch.no_grad():
labels_2d = labels.view(labels.size(0), -1)
ignore_index = getattr(self.config, "ignore_index", -100)
valid_mask = labels_2d.ne(ignore_index)
valid_count = valid_mask.sum(dim=1).clamp(min=1).float()
raw_w = valid_count.sqrt()
w_mean = raw_w.mean().clamp(min=1e-6)
norm_w = raw_w / w_mean
loss = (per_sample_mean * norm_w).sum()
elif self.use_turnmeansum_loss:
with torch.no_grad():
mask = shift_labels.view(labels.size(0), -1).ne(
self.config.ignore_index
)
prev_mask = mask.roll(shifts=1, dims=1)
prev_mask[:, 0] = False
turn_starts = mask & (~prev_mask)
turn_count = turn_starts.sum(dim=1).clamp(min=1).float()
loss = (loss.view(labels.size(0), -1).mean(dim=1) * turn_count).sum()
if self.sp_manager is not None:
loss = loss / self.sp_manager.device_mesh.shape[1]
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
*args,
**kwargs,
):
state_dict = (
kwargs["state_dict"]
if kwargs.get("state_dict", None)
else self.state_dict()
)
partial_state_dict = self.get_pretrained_state_dict(
state_dict,
)
kwargs["state_dict"] = partial_state_dict
kwargs["safe_serialization"] = self.is_safetensor_save
kwargs.setdefault("max_shard_size", self.save_shard_size)
super().save_pretrained(save_directory, *args, **kwargs)
if self.is_qwen_visual:
self.config.architectures = ["HCXVisionV2ForCausalLM"]
else:
self.config.architectures = ["HCXVisionForCausalLM"]
self.config.auto_map["AutoModelForCausalLM"] = (
"modeling_vlm.HCXVisionForCausalLM"
)
self.config.auto_map["AutoModelForSequenceClassification"] = (
"modeling_vlm.HCXVisionForSequenceClassification"
)
self.config.save_pretrained(save_directory)
def get_pretrained_state_dict(self, state_dict):
vision_key = "vision_model."
llm_keys = ["language_model."]
head_key = "lm_head."
for key in list(state_dict.keys()):
if self.save_only_vision:
for llm_key in llm_keys:
if llm_key in key:
state_dict.pop(key)
if key.startswith(head_key):
state_dict.pop(key)
elif self.save_only_qformer:
if f"{vision_key}" in key:
state_dict.pop(key)
return state_dict