import transformers
from transformers import AutoConfig, PretrainedConfig
class HCXVisionConfig(PretrainedConfig):
    """Configuration class for the HCX vision-language model (``model_type="vlm"``).

    Bundles a text (decoder) sub-config and a vision (encoder) sub-config together
    with multimodal-projector and video-sampling options. Sub-configs may be passed
    directly, as plain dicts (resolved through ``CONFIG_MAPPING``), or loaded from a
    checkpoint path via ``AutoConfig.from_pretrained``.
    """

    model_type = "vlm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        text_model_name_or_path=None,
        vision_model_name_or_path=None,
        q_former_model_name_or_path=None,
        mm_projector_type="mlp",
        use_nth_layer=-2,
        img_start_id=100271,  # <|IMAGE_PAD|>
        video_start_id=100270,  # <|VIDEO_PAD|>
        freeze_encoder=False,
        freeze_decoder=False,
        freeze_mm_projector=False,
        anyres=False,
        unpad=False,
        max_num_grids=-1,
        num_queries_vis_abstractor=-1,
        video_num_queries_fast=None,
        video_num_queries_slow=None,
        video_first_last_frames_slows=None,
        video_max_num_frames=None,
        ignore_index=-100,
        proj_pos_emb=True,
        proj_prenorm=False,
        use_1x1_grid=False,
        possible_resolutions=None,  # fixed: was a mutable `[]` default shared across instances
        **kwargs,
    ):
        from transformers import CONFIG_MAPPING

        if kwargs.get("language_config", None) is not None:  # for bc
            text_config = CONFIG_MAPPING[kwargs["language_config"]["model_type"]](**kwargs["language_config"])
        elif text_config is None and text_model_name_or_path is not None:
            text_config = AutoConfig.from_pretrained(text_model_name_or_path, trust_remote_code=True)
        if vision_config is None and vision_model_name_or_path is not None:
            vision_config = AutoConfig.from_pretrained(vision_model_name_or_path, trust_remote_code=True)
        if isinstance(text_config, dict):
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        if isinstance(vision_config, dict):
            if vision_config["model_type"] == "qwen2_5_vl":
                vision_config["model_type"] = "qwen2_5_vl_visual"
                # Compare version components numerically: a plain string comparison
                # mis-orders versions (e.g. "4.9" > "4.52.4" and "4.100.0" < "4.52.4").
                version_parts = []
                for part in transformers.__version__.split(".")[:3]:
                    digits = "".join(ch for ch in part if ch.isdigit())
                    version_parts.append(int(digits or 0))
                assert tuple(version_parts) >= (4, 52, 4), "please upgrade transformers to 4.52.4 or higher"
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        self.text_config = text_config
        self.vision_config = vision_config
        if text_config is not None:
            # DeepSpeed ZeRO-3 sizes its memory automatically from the config's hidden_size.
            self.hidden_size = text_config.hidden_size if hasattr(text_config, "hidden_size") else text_config.n_embd

        # add VLM configs
        self.text_model_name_or_path = text_model_name_or_path
        self.vision_model_name_or_path = vision_model_name_or_path
        self.q_former_model_name_or_path = q_former_model_name_or_path
        self.mm_projector_type = mm_projector_type
        self.use_nth_layer = use_nth_layer
        self.freeze_encoder = freeze_encoder
        self.freeze_decoder = freeze_decoder
        self.freeze_mm_projector = freeze_mm_projector
        self.anyres = anyres
        self.unpad = unpad
        self.max_num_grids = max_num_grids
        self.num_queries_vis_abstractor = num_queries_vis_abstractor
        self.video_num_queries_fast = video_num_queries_fast
        self.video_num_queries_slow = video_num_queries_slow
        self.video_first_last_frames_slows = video_first_last_frames_slows
        self.video_max_num_frames = video_max_num_frames
        self.img_start_id = img_start_id
        self.image_token_id = img_start_id
        self.video_start_id = video_start_id
        self.video_token_id = video_start_id
        self.ignore_index = ignore_index
        self.proj_pos_emb = proj_pos_emb
        self.proj_prenorm = proj_prenorm
        self.use_1x1_grid = use_1x1_grid
        # None sentinel keeps the old `[]` default without the shared-mutable pitfall.
        self.possible_resolutions = possible_resolutions if possible_resolutions is not None else []
        super().__init__(**kwargs)
        if self.text_config is not None:  # needed for HCXVisionForSequenceClassification
            self.pad_token_id = self.text_config.pad_token_id
AutoConfig.register("vlm", HCXVisionConfig)

# Best-effort registration of the companion HyperCLOVAX text config: skipped when
# the sibling module is absent or the type is already registered.
try:
    from .configuration_hyperclovax import HyperCLOVAXConfig

    AutoConfig.register("hyperclovax", HyperCLOVAXConfig)
except Exception:  # was bare `except:` — don't swallow SystemExit/KeyboardInterrupt
    pass

# Best-effort registration of the Qwen2.5-VL vision tower under the alias used by
# HCXVisionConfig ("qwen2_5_vl_visual"); skipped on older transformers versions
# that lack these classes, or when the alias is already registered.
try:
    from transformers import CONFIG_MAPPING, MODEL_MAPPING
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
        Qwen2_5_VisionTransformerPretrainedModel,
        Qwen2_5_VLPatchMerger,
        Qwen2_5_VLVisionConfig,
    )

    MODEL_MAPPING.register(Qwen2_5_VLVisionConfig, Qwen2_5_VisionTransformerPretrainedModel)
    CONFIG_MAPPING.register("qwen2_5_vl_visual", Qwen2_5_VLVisionConfig)
except Exception:  # was bare `except:` — don't swallow SystemExit/KeyboardInterrupt
    pass