UNIVA-Jason committed on
Commit
a16cf3c
0 Parent(s):

first commit

.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,95 @@
+ ---
+ license: apache-2.0
+ ---
+
+ # Model Summary
+ **25EMBAI-VLM-FM** is a Vision-Language Foundation Model built by combining:
+ - **Vision Encoder:** ViT-H/14 (OpenCLIP)
+ - **Language Model:** Qwen-based LLM
+ - **Bridging Modules:** Resampler + Projector (image → LLM embedding space)
+
+ It takes an image, encodes it into patch tokens, compresses them into a fixed-length set of visual tokens, projects them into the language model’s hidden space, and then performs multimodal reasoning conditioned on a text prompt.
+
+ ## Architecture Flow
+ Image → ViT-H/14 → Resampler → Projector → Qwen LLM → Text Output
+
+ LLM input format: `[Batch, K_image_tokens + T_text_tokens, D_hidden]`
+
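+ In the LLM, the `K` visual tokens produced by the Resampler + Projector are simply concatenated in front of the text embeddings (no placeholder splicing). A minimal shape sketch, illustrative only; the tensor names are placeholders, while K = 196 and D_hidden = 2560 follow the shipped config:
+ ```python
+ import torch
+
+ B, K, T, D = 1, 196, 32, 2560          # batch, visual tokens, text tokens, LLM hidden size
+ visual_tokens = torch.randn(B, K, D)    # output of the Resampler + Projector
+ text_embeds = torch.randn(B, T, D)      # LLM input embeddings of the tokenized prompt
+
+ # Visual-first concatenation (mirrors VLMModel._concat_visual_first in model.py)
+ llm_inputs = torch.cat([visual_tokens, text_embeds], dim=1)
+ print(llm_inputs.shape)                 # torch.Size([1, 228, 2560]) == [Batch, K + T, D_hidden]
+ ```
+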
+ ## Training Summary
+ ### Pre-training (Stages 1 & 2)
+ - Hardware: 8 × H100 80GB
+ - Stage 1 (3.6 h): freeze ViT + LLM → train Resampler + Projector (see the sketch below)
+ - Stage 2 (5.4 h): unfreeze all → train end-to-end
+ - Data: ~2M image–caption pairs (BLIP3 style)
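+
+ A minimal sketch of the stage-wise freeze settings, assuming the `build_model()` helper shipped in `VLM_prototype/model.py`; data loading and the optimizer loop are omitted:
+ ```python
+ from model import build_model  # VLM_prototype/model.py
+
+ # Stage 1: freeze ViT + LLM, train only the Resampler + Projector
+ stage1_model = build_model(
+     vision_model="ViT-H-14-378-quickgelu",
+     vision_pretrained="dfn5b",
+     lm_model="Qwen/Qwen3-4B-Instruct-2507",
+     vision_freeze=True,
+     projector_freeze=False,
+     llm_freeze=True,
+ )
+
+ # Stage 2: unfreeze everything and train end-to-end
+ stage2_model = build_model(
+     vision_model="ViT-H-14-378-quickgelu",
+     vision_pretrained="dfn5b",
+     lm_model="Qwen/Qwen3-4B-Instruct-2507",
+     vision_freeze=False,
+     projector_freeze=False,
+     llm_freeze=False,
+ )
+ ```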
+
+ ### Instruction Fine-tuning
+ - ~2M images + ~200M text tokens
+ - ~20 multimodal tasks: VQA, OCR, captioning, commands
+ - max_length: 1024
+ - effective batch size: ~64
+
+ # Usage
+
+ ## Install
+ ```bash
+ pip install torch transformers pillow
+ ```
+
+ ## Inference Example
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+
+ model_path = "YOUR_HF_USERNAME/25EMBAI-VLM-FM"
+ dtype = torch.bfloat16
+
+ # Load model
+ model = AutoModel.from_pretrained(
+     model_path,
+     trust_remote_code=True,
+ ).to(device="cuda", dtype=dtype)
+ model.eval()
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ # Load image processor from model assets
+ image_processor = AutoImageProcessor.from_pretrained(
+     model_path,
+     trust_remote_code=True,
+ )
+
+ # Load image
+ img = Image.open("sample.png").convert("RGB")
+
+ # Transform image → pixel values
+ pixel = image_processor(img, return_tensors="pt")["pixel_values"].to(
+     dtype=dtype, device="cuda"
+ )
+
+ # Prompt
+ prompt = "Please describe this image."
+
+ # Multimodal generation
+ output = model.generate_text(
+     images=pixel,
+     prompt=prompt,
+     max_new_tokens=512,
+     do_sample=True,
+     top_p=0.9,
+     temperature=0.7,
+ )
+ print(output)
+ ```
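+
+ `generate_text` is a custom method defined in this repo's `model.py`, and it also works without an image. A minimal text-only sketch, reusing the `model` object loaded above:
+ ```python
+ text_only = model.generate_text(
+     prompt="Give a one-sentence definition of a vision-language model.",
+     max_new_tokens=64,
+ )
+ print(text_only)
+ ```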
+
+ # Limitations & Biases
+ This model is an early-stage prototype; it will be updated and reorganized in future releases.
+ Because it was trained on web-scale multimodal data:
+ - It may reflect social biases and stereotypes
+ - It may hallucinate, invent facts, or produce unverifiable content
+ - It may perform suboptimally on:
+   - Non-English languages
+   - Specialized and domain-specific tasks
+   - Safety-critical contexts
+
+ This model is not recommended for medical, legal, or safety-critical use without additional validation, guardrails, or fine-tuning.
+ Users should apply external filtering, grounding, and safety alignment before deployment.
VLM_prototype/.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
VLM_prototype/added_tokens.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
VLM_prototype/chat_template.jinja ADDED
@@ -0,0 +1,61 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- for message in messages %}
18
+ {%- if message.content is string %}
19
+ {%- set content = message.content %}
20
+ {%- else %}
21
+ {%- set content = '' %}
22
+ {%- endif %}
23
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
24
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
25
+ {%- elif message.role == "assistant" %}
26
+ {{- '<|im_start|>' + message.role + '\n' + content }}
27
+ {%- if message.tool_calls %}
28
+ {%- for tool_call in message.tool_calls %}
29
+ {%- if (loop.first and content) or (not loop.first) %}
30
+ {{- '\n' }}
31
+ {%- endif %}
32
+ {%- if tool_call.function %}
33
+ {%- set tool_call = tool_call.function %}
34
+ {%- endif %}
35
+ {{- '<tool_call>\n{"name": "' }}
36
+ {{- tool_call.name }}
37
+ {{- '", "arguments": ' }}
38
+ {%- if tool_call.arguments is string %}
39
+ {{- tool_call.arguments }}
40
+ {%- else %}
41
+ {{- tool_call.arguments | tojson }}
42
+ {%- endif %}
43
+ {{- '}\n</tool_call>' }}
44
+ {%- endfor %}
45
+ {%- endif %}
46
+ {{- '<|im_end|>\n' }}
47
+ {%- elif message.role == "tool" %}
48
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
49
+ {{- '<|im_start|>user' }}
50
+ {%- endif %}
51
+ {{- '\n<tool_response>\n' }}
52
+ {{- content }}
53
+ {{- '\n</tool_response>' }}
54
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
55
+ {{- '<|im_end|>\n' }}
56
+ {%- endif %}
57
+ {%- endif %}
58
+ {%- endfor %}
59
+ {%- if add_generation_prompt %}
60
+ {{- '<|im_start|>assistant\n' }}
61
+ {%- endif %}
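For reference, a minimal sketch of what this chat template renders for a plain system + user exchange, assuming the tokenizer files in this folder and a transformers version that picks up chat_template.jinja:
```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("VLM_prototype")  # path to this folder
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe the attached image."},
]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe the attached image.<|im_end|>
# <|im_start|>assistant
```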
VLM_prototype/config.json ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "_class_name": "VLMModel",
3
+ "architectures": [
4
+ "VLMModel"
5
+ ],
6
+ "model_type": "vision-language",
7
+ "version": 1,
8
+ "vision_model": "ViT-H-14-378-quickgelu",
9
+ "vision_pretrained": "dfn5b",
10
+ "vision_width": 1280,
11
+ "lm_model": "Qwen/Qwen3-4B-Instruct-2507",
12
+ "lm_hidden_size": 2560,
13
+ "n_vis_tokens": 196,
14
+ "auto_map": {
15
+ "AutoConfig": "model.VLMConfig",
16
+ "AutoModel": "model.VLMModel"
17
+ },
18
+ "pad_token_id": 151643,
19
+ "eos_token_id": 151645,
20
+ "model_args": {
21
+ "vision_model": "ViT-H-14-378-quickgelu",
22
+ "vision_pretrained": "dfn5b",
23
+ "lm_model": "Qwen/Qwen3-4B-Instruct-2507",
24
+ "vision_freeze": true,
25
+ "projector_freeze": true,
26
+ "llm_freeze": true,
27
+ "n_vis_tokens": 196,
28
+ "checkpoint_path": null,
29
+ "use_chat_template": true
30
+ },
31
+ "lm_config": {
32
+ "vocab_size": 151936,
33
+ "max_position_embeddings": 262144,
34
+ "hidden_size": 2560,
35
+ "intermediate_size": 9728,
36
+ "num_hidden_layers": 36,
37
+ "num_attention_heads": 32,
38
+ "use_sliding_window": false,
39
+ "sliding_window": null,
40
+ "max_window_layers": 36,
41
+ "num_key_value_heads": 8,
42
+ "head_dim": 128,
43
+ "hidden_act": "silu",
44
+ "initializer_range": 0.02,
45
+ "rms_norm_eps": 1e-06,
46
+ "use_cache": false,
47
+ "rope_theta": 5000000,
48
+ "rope_scaling": null,
49
+ "attention_bias": false,
50
+ "attention_dropout": 0.0,
51
+ "layer_types": [
52
+ "full_attention",
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention",
56
+ "full_attention",
57
+ "full_attention",
58
+ "full_attention",
59
+ "full_attention",
60
+ "full_attention",
61
+ "full_attention",
62
+ "full_attention",
63
+ "full_attention",
64
+ "full_attention",
65
+ "full_attention",
66
+ "full_attention",
67
+ "full_attention",
68
+ "full_attention",
69
+ "full_attention",
70
+ "full_attention",
71
+ "full_attention",
72
+ "full_attention",
73
+ "full_attention",
74
+ "full_attention",
75
+ "full_attention",
76
+ "full_attention",
77
+ "full_attention",
78
+ "full_attention",
79
+ "full_attention",
80
+ "full_attention",
81
+ "full_attention",
82
+ "full_attention",
83
+ "full_attention",
84
+ "full_attention",
85
+ "full_attention",
86
+ "full_attention",
87
+ "full_attention"
88
+ ],
89
+ "return_dict": true,
90
+ "output_hidden_states": false,
91
+ "torchscript": false,
92
+ "dtype": "float32",
93
+ "pruned_heads": {},
94
+ "tie_word_embeddings": true,
95
+ "chunk_size_feed_forward": 0,
96
+ "is_encoder_decoder": false,
97
+ "is_decoder": false,
98
+ "cross_attention_hidden_size": null,
99
+ "add_cross_attention": false,
100
+ "tie_encoder_decoder": false,
101
+ "architectures": [
102
+ "Qwen3ForCausalLM"
103
+ ],
104
+ "finetuning_task": null,
105
+ "id2label": {
106
+ "0": "LABEL_0",
107
+ "1": "LABEL_1"
108
+ },
109
+ "label2id": {
110
+ "LABEL_0": 0,
111
+ "LABEL_1": 1
112
+ },
113
+ "task_specific_params": null,
114
+ "problem_type": null,
115
+ "tokenizer_class": null,
116
+ "prefix": null,
117
+ "bos_token_id": 151643,
118
+ "pad_token_id": null,
119
+ "eos_token_id": 151645,
120
+ "sep_token_id": null,
121
+ "decoder_start_token_id": null,
122
+ "max_length": 20,
123
+ "min_length": 0,
124
+ "do_sample": false,
125
+ "early_stopping": false,
126
+ "num_beams": 1,
127
+ "temperature": 1.0,
128
+ "top_k": 50,
129
+ "top_p": 1.0,
130
+ "typical_p": 1.0,
131
+ "repetition_penalty": 1.0,
132
+ "length_penalty": 1.0,
133
+ "no_repeat_ngram_size": 0,
134
+ "encoder_no_repeat_ngram_size": 0,
135
+ "bad_words_ids": null,
136
+ "num_return_sequences": 1,
137
+ "output_scores": false,
138
+ "return_dict_in_generate": false,
139
+ "forced_bos_token_id": null,
140
+ "forced_eos_token_id": null,
141
+ "remove_invalid_values": false,
142
+ "exponential_decay_length_penalty": null,
143
+ "suppress_tokens": null,
144
+ "begin_suppress_tokens": null,
145
+ "num_beam_groups": 1,
146
+ "diversity_penalty": 0.0,
147
+ "_name_or_path": "Qwen/Qwen3-4B-Instruct-2507",
148
+ "transformers_version": "4.57.1",
149
+ "model_type": "qwen3",
150
+ "tf_legacy_loss": false,
151
+ "use_bfloat16": false,
152
+ "output_attentions": false
153
+ },
154
+ "resampler_config": {
155
+ "K": 196,
156
+ "n_layers": 2,
157
+ "use_pos": true,
158
+ "q_grid": 14,
159
+ "adaptive_kv_pos": true,
160
+ "use_q_self_attn": true
161
+ },
162
+ "projector_config": {
163
+ "type": "Projector",
164
+ "in_features": 1280,
165
+ "out_features": 2560,
166
+ "num_layers": 2
167
+ }
168
+ }
VLM_prototype/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
VLM_prototype/model.py ADDED
@@ -0,0 +1,792 @@
1
+ # model.py
2
+ from __future__ import annotations
3
+ from dataclasses import dataclass, asdict
4
+ from typing import Any, Dict, Optional, List
5
+ import json, os
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+
12
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
13
+ from transformers.image_processing_utils import BaseImageProcessor
14
+
15
+ import open_clip
16
+ from PIL import Image
17
+ from safetensors.torch import save_file as _save_sf, load_file as _load_sf
18
+
19
+ from utils import *
20
+ from resampler import VisualResampler
21
+
22
+ class VLMConfig(PretrainedConfig):
23
+ model_type = "vision-language"
24
+
25
+ def __init__(
26
+ self,
27
+ vision_model: str = "ViT-H-14-378-quickgelu",
28
+ vision_pretrained: str = "dfn5b",
29
+ lm_model: str = "Qwen/Qwen3-4B-Instruct-2507",
30
+ n_vis_tokens: int = 196,
31
+ use_chat_template: bool = True,
32
+ dtype: str = torch.bfloat16,
33
+ **kwargs,
34
+ ):
35
+ super().__init__(**kwargs)
36
+ self.vision_model = vision_model
37
+ self.vision_pretrained = vision_pretrained
38
+ self.lm_model = lm_model
39
+ self.n_vis_tokens = n_vis_tokens
40
+ self.use_chat_template = use_chat_template
41
+
42
+ # ------------------------------- Args -------------------------------
43
+
44
+ @dataclass
45
+ class ModelArgs:
46
+ vision_model: str = "ViT-H-14-378-quickgelu"
47
+ vision_pretrained: str = "dfn5b"
48
+ # Default to a Qwen3-family Instruct model (override from the CLI if desired)
49
+ lm_model: str = "Qwen/Qwen3-4B-Instruct-2507"
50
+ vision_freeze: bool = False
51
+ projector_freeze: bool = False
52
+ llm_freeze: bool = False
53
+ n_vis_tokens: int = 196
54
+ checkpoint_path: Optional[str] = None
55
+ use_chat_template: bool = True
56
+
57
+ # ------------------------------- Vision wrappers -------------------------------
58
+
59
+ class OpenCLIPVisionOnly(nn.Module):
60
+ def __init__(self, clip_model):
61
+ super().__init__()
62
+ self.visual = clip_model.visual
63
+ self._feat_dim = self._infer_feat_dim()
64
+
65
+ def _infer_feat_dim(self) -> int:
66
+ v = self.visual
67
+ candidates = [
68
+ getattr(v, "width", None),
69
+ getattr(getattr(v, "ln_post", None), "normalized_shape", None),
70
+ getattr(getattr(v, "conv1", None), "out_channels", None),
71
+ getattr(v, "embed_dim", None),
72
+ ]
73
+ for cand in candidates:
74
+ if cand is None:
75
+ continue
76
+ if isinstance(cand, (list, tuple)):
77
+ return int(cand[0])
78
+ return int(cand)
79
+ # Safe fallback
80
+ return 768
81
+
82
+ @property
83
+ def feat_dim(self) -> int:
84
+ return self._feat_dim
85
+
86
+ def tokens_or_global(self, pixel_values: torch.Tensor) -> torch.Tensor:
87
+ v = self.visual
88
+ # 1) Patch embedding
89
+ x = v.conv1(pixel_values) # [B, C, H/ps, W/ps]
90
+ x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1) # [B, HW, C]
91
+
92
+ # 2) Prepend the CLS token
93
+ cls = v.class_embedding.to(x.dtype)
94
+ cls = cls.unsqueeze(0).expand(x.size(0), 1, -1) # [B,1,C]
95
+ x = torch.cat([cls, x], dim=1) # [B, 1+HW, C]
96
+
97
+ # 3) Positional embedding (interpolated to the input resolution)
98
+ pe = getattr(v, "positional_embedding", None)
99
+ if pe is not None:
100
+ if pe.dim() == 2:
101
+ pe = pe.unsqueeze(0) # [1, N, C]
102
+ if pe.shape[1] != x.shape[1]:
103
+ cls_pe, patch_pe = pe[:, :1, :], pe[:, 1:, :]
104
+ s0 = int((patch_pe.shape[1]) ** 0.5)
105
+ s1 = int((x.shape[1] - 1) ** 0.5)
106
+ patch_pe = patch_pe.reshape(1, s0, s0, -1).permute(0, 3, 1, 2)
107
+ patch_pe = F.interpolate(patch_pe, size=(s1, s1), mode="bicubic", align_corners=False)
108
+ patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, s1 * s1, -1)
109
+ pe = torch.cat([cls_pe, patch_pe], dim=1)
110
+ x = x + pe.to(dtype=x.dtype, device=x.device)
111
+
112
+ x = v.ln_pre(x)
113
+ for blk in v.transformer.resblocks:
114
+ x = blk(x)
115
+ x = v.ln_post(x) # [B, 1+T, C]
116
+
117
+ # Return patch tokens, excluding CLS
118
+ return x[:, 1:, :] # [B, T, C]
119
+
120
+ # ------------------------------- Adapters -------------------------------
121
+
122
+ class Projector(nn.Module):
123
+ def __init__(self, d_in: int, d_out: int):
124
+ super().__init__()
125
+ self.net = nn.Sequential(
126
+ nn.LayerNorm(d_in),
127
+ nn.Linear(d_in, d_out),
128
+ nn.GELU(),
129
+ nn.Linear(d_out, d_out),
130
+ nn.LayerNorm(d_out),
131
+ )
132
+ def forward(self, x):
133
+ return self.net(x)
134
+
135
+ # ------------------------------- VLM -------------------------------
136
+
137
+ class VLMModel(nn.Module):
138
+ def __init__(self, margs: ModelArgs, device: Optional[torch.device] = None):
139
+ super().__init__()
140
+ self.margs = margs
141
+ self.device_ = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
142
+
143
+ # --- Vision ---
144
+ clip, _, _ = open_clip.create_model_and_transforms(
145
+ margs.vision_model, pretrained=margs.vision_pretrained, device=self.device_
146
+ )
147
+ self.vision = OpenCLIPVisionOnly(clip).to(self.device_)
148
+ self.vision_width = int(self.vision.feat_dim)
149
+
150
+ # Metadata for export
151
+ self.vision_model_name = margs.vision_model
152
+ self.vision_pretrained_tag = margs.vision_pretrained
153
+
154
+ if margs.vision_freeze:
155
+ for p in self.vision.parameters():
156
+ p.requires_grad_(False)
157
+
158
+ # --- Tokens pipeline
159
+ lm_cfg = AutoConfig.from_pretrained(margs.lm_model, trust_remote_code=True)
160
+ lm_hidden = lm_cfg.hidden_size
161
+
162
+ n_vis_tokens = margs.n_vis_tokens
163
+ self.resampler = VisualResampler(
164
+ C=self.vision_width,
165
+ K=n_vis_tokens,
166
+ n_heads=8,
167
+ n_layers=2,
168
+ kv_dim=self.vision_width,
169
+ use_pos=True,
170
+ q_grid=int(n_vis_tokens ** 0.5) if int(n_vis_tokens ** 0.5) ** 2 == n_vis_tokens else None,
171
+ adaptive_kv_pos=True,
172
+ dropout=0.0,
173
+ use_q_self_attn=True
174
+ ).to(self.device_)
175
+
176
+ self.projector = Projector(d_in=self.vision_width, d_out=lm_hidden).to(self.device_)
177
+ if margs.projector_freeze:
178
+ for p in self.projector.parameters():
179
+ p.requires_grad_(False)
180
+
181
+ # --- LM (Qwen3) ---
182
+ self.lm = AutoModelForCausalLM.from_pretrained(
183
+ margs.lm_model, trust_remote_code=True
184
+ )
185
+ self.tokenizer = AutoTokenizer.from_pretrained(
186
+ margs.lm_model, use_fast=True, trust_remote_code=True
187
+ )
188
+
189
+ self.lm.to(self.device_)
190
+ if self.tokenizer.pad_token is None:
191
+ self.tokenizer.pad_token = self.tokenizer.eos_token
192
+ self.lm_hidden = self.lm.config.hidden_size
193
+
194
+ if hasattr(self.lm, "config"):
195
+ self.lm.config.use_cache = False
196
+
197
+ if margs.llm_freeze:
198
+ for p in self.lm.parameters():
199
+ p.requires_grad_(False)
200
+
201
+ if margs.checkpoint_path:
202
+ target_path, kind = _resolve_ckpt_target_from_path(margs.checkpoint_path)
203
+ if not os.path.exists(target_path):
204
+ raise FileNotFoundError(f"[ckpt] not found: {target_path} (kind={kind})")
205
+ rank0_print(f"[ckpt] target={target_path} ({kind})")
206
+
207
+ # Supports a single file, a sharded directory, or a root index
208
+ state = _hf_load_state_dict_any(target_path, map_location=self.device_)
209
+
210
+ # Unwrap a {'state_dict': ...} wrapper
211
+ if isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict):
212
+ state = state["state_dict"]
213
+
214
+ # Strip the 'module.' prefix from DDP checkpoints
215
+ if any(isinstance(k, str) and k.startswith("module.") for k in state.keys()):
216
+ state = {k.replace("module.", "", 1): v for k, v in state.items()}
217
+
218
+ # Load into the current model (self) (bug fix: not self.model)
219
+ missing, unexpected = self.load_state_dict(state, strict=False)
220
+ rank0_print(f"[load_state_dict] missing:{len(missing)} unexpected:{len(unexpected)}")
221
+ if missing:
222
+ rank0_print(" - missing (first 10): " + ", ".join(missing[:10]))
223
+ if unexpected:
224
+ rank0_print(" - unexpected (first 10): " + ", ".join(unexpected[:10]))
225
+
226
+
227
+ @property
228
+ def device(self):
229
+ return self.device_
230
+
231
+ @property
232
+ def tok_emb(self):
233
+ return self.lm.get_input_embeddings()
234
+
235
+ # ---- Vision encode (prefer patch tokens) ----
236
+ def encode_patches(self, pixel_values: torch.Tensor) -> torch.Tensor:
237
+ """
238
+ Returns: patch tokens [B, T, C] (CLS excluded). Falls back to a [B, 1, C] global token if unavailable.
239
+ """
240
+ return self.vision.tokens_or_global(pixel_values)
241
+
242
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
243
+ if hasattr(self.lm, "gradient_checkpointing_enable"):
244
+ if gradient_checkpointing_kwargs is not None:
245
+ self.lm.gradient_checkpointing_enable(
246
+ gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
247
+ )
248
+ else:
249
+ self.lm.gradient_checkpointing_enable()
250
+ if hasattr(self.lm, "config"):
251
+ self.lm.config.use_cache = False
252
+ if hasattr(self.lm, "enable_input_require_grads"):
253
+ try:
254
+ self.lm.enable_input_require_grads()
255
+ except Exception:
256
+ pass
257
+
258
+ def gradient_checkpointing_disable(self):
259
+ if hasattr(self.lm, "gradient_checkpointing_disable"):
260
+ self.lm.gradient_checkpointing_disable()
261
+
262
+ # ---- Visual-first concat (no splice, no markers) ----
263
+ def _concat_visual_first(
264
+ self,
265
+ text_embeds: torch.Tensor, # [B, T, D]
266
+ attention_mask: torch.Tensor, # [B, T]
267
+ proj_feats: Optional[torch.Tensor], # [B, K, D] or None
268
+ ):
269
+ if proj_feats is None:
270
+ return text_embeds, attention_mask
271
+ B, T, D = text_embeds.shape
272
+ K = proj_feats.size(1)
273
+ new_embeds = torch.cat([proj_feats, text_embeds], dim=1) # [B, K+T, D]
274
+ pad_ones = torch.ones(B, K, dtype=attention_mask.dtype, device=attention_mask.device)
275
+ new_attn = torch.cat([pad_ones, attention_mask], dim=1) # [B, K+T]
276
+ return new_embeds, new_attn
277
+
278
+ # ---- Forward ----
279
+ def forward(
280
+ self,
281
+ images: Optional[torch.Tensor] = None,
282
+ input_ids: Optional[torch.Tensor] = None,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ labels: Optional[torch.Tensor] = None,
285
+ ) -> Dict[str, Any]:
286
+ # If input_ids is missing, reconstruct it from labels
287
+ if input_ids is None:
288
+ if labels is None:
289
+ raise ValueError("Either input_ids or labels must be provided.")
290
+ pad_id = self.tokenizer.pad_token_id
291
+ input_ids = labels.clone()
292
+ input_ids[input_ids == -100] = pad_id
293
+ if attention_mask is None:
294
+ attention_mask = (input_ids != pad_id).long()
295
+
296
+ vis_tokens_lm = None
297
+ if images is not None:
298
+ patches = self.encode_patches(images) # [B,Tv,Cv]
299
+ patches_K = self.resampler(patches) # [B,K,Cv]
300
+ vis_tokens_lm = self.projector(patches_K) # [B,K,Dlm]
301
+ lm_dtype = next(self.lm.parameters()).dtype
302
+ if vis_tokens_lm.dtype != lm_dtype:
303
+ vis_tokens_lm = vis_tokens_lm.to(lm_dtype)
304
+
305
+ # Prepare text embeddings and attention mask
306
+ text_emb = self.tok_emb(input_ids) # [B,T,D]
307
+ attn_in = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
308
+
309
+ # Splicing removed → always concatenate vision tokens in front
310
+ inputs_embeds, attn = self._concat_visual_first(
311
+ text_embeds=text_emb,
312
+ attention_mask=attn_in,
313
+ proj_feats=vis_tokens_lm, # 이미지 없으면 None
314
+ )
315
+
316
+ # Rebuild the labels
317
+ new_labels = None
318
+ if labels is not None:
319
+ if vis_tokens_lm is None:
320
+ new_labels = labels
321
+ else:
322
+ B, K, _ = vis_tokens_lm.shape
323
+ pad_mask = torch.full((B, K), -100, dtype=labels.dtype, device=labels.device)
324
+ new_labels = torch.cat([pad_mask, labels], dim=1) # [B, K+T]
325
+
326
+ out = self.lm(
327
+ inputs_embeds=inputs_embeds if vis_tokens_lm is not None else None,
328
+ attention_mask=attn if vis_tokens_lm is not None else attention_mask,
329
+ input_ids=None if vis_tokens_lm is not None else input_ids,
330
+ labels=new_labels if vis_tokens_lm is not None else labels,
331
+ output_hidden_states=False,
332
+ use_cache=False,
333
+ )
334
+ return {"loss": out.loss}
335
+
336
+ # ---- Inference helper (Qwen chat template) ----
337
+ @torch.no_grad()
338
+ def generate_text(
339
+ self,
340
+ images: Optional[torch.Tensor] = None,
341
+ prompt: Optional[str] = None,
342
+ max_new_tokens: int = 128,
343
+ do_sample: bool = True,
344
+ top_p: float = 0.9,
345
+ temperature: float = 0.7,
346
+ **gen_kw,
347
+ ) -> str:
348
+ self.eval()
349
+ device = self.device
350
+
351
+ system_prompt = "You are a helpful assistant."
352
+ user_text = (prompt or "").strip()
353
+
354
+ if getattr(self.margs, "use_chat_template", True):
355
+ messages = []
356
+ if system_prompt.strip():
357
+ messages.append({"role": "system", "content": system_prompt.strip()})
358
+ messages.append({"role": "user", "content": user_text})
359
+
360
+ input_ids = self.tokenizer.apply_chat_template(
361
+ messages,
362
+ tokenize=True,
363
+ add_generation_prompt=True,
364
+ return_tensors="pt",
365
+ )
366
+ attention_mask = torch.ones_like(input_ids)
367
+ else:
368
+ # (fallback) plain string prompt
369
+ text = f"System: {system_prompt}\nUser: {user_text}\nAssistant: "
370
+ enc = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)
371
+ input_ids = enc.input_ids
372
+ attention_mask = enc.attention_mask
373
+
374
+ input_ids = input_ids.to(device)
375
+ attention_mask = attention_mask.to(device)
376
+
377
+ if images is not None:
378
+ # Vision embeddings
379
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
380
+ vision_feats = self.encode_patches(images.to(device, non_blocking=True))
381
+ patches_K = self.resampler(vision_feats) # [1,K,Cv]
382
+ projected_feats = self.projector(patches_K) # [1,K,Dlm]
383
+ lm_dtype = next(self.lm.parameters()).dtype
384
+ if projected_feats.dtype != lm_dtype:
385
+ projected_feats = projected_feats.to(lm_dtype)
386
+
387
+ # Text embeddings
388
+ text_embeds = self.lm.get_input_embeddings()(input_ids) # [1,T,D]
389
+ inputs_embeds, attn = self._concat_visual_first(
390
+ text_embeds=text_embeds,
391
+ attention_mask=attention_mask,
392
+ proj_feats=projected_feats,
393
+ )
394
+ gen_out = self.lm.generate(
395
+ inputs_embeds=inputs_embeds,
396
+ attention_mask=attn,
397
+ max_new_tokens=max_new_tokens,
398
+ do_sample=gen_kw.get("do_sample", do_sample),
399
+ temperature=gen_kw.get("temperature", temperature),
400
+ top_p=gen_kw.get("top_p", top_p),
401
+ eos_token_id=self.tokenizer.eos_token_id,
402
+ pad_token_id=self.tokenizer.pad_token_id,
403
+ use_cache=True
404
+ )
405
+ seq = gen_out[0]
406
+ decoded = self.tokenizer.decode(seq, skip_special_tokens=True)
407
+ return decoded.strip()
408
+ else:
409
+ gen_out = self.lm.generate(
410
+ input_ids=input_ids,
411
+ attention_mask=attention_mask,
412
+ max_new_tokens=max_new_tokens,
413
+ do_sample=gen_kw.get("do_sample", do_sample),
414
+ temperature=gen_kw.get("temperature", temperature),
415
+ top_p=gen_kw.get("top_p", top_p),
416
+ eos_token_id=self.tokenizer.eos_token_id,
417
+ pad_token_id=self.tokenizer.pad_token_id,
418
+ use_cache=True
419
+ )
420
+ seq = gen_out[0]
421
+ decoded = self.tokenizer.decode(seq[input_ids.size(1):], skip_special_tokens=True)
422
+ return decoded.strip()
423
+
424
+ @staticmethod
425
+ def _break_shared_tensors(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
426
+ by_ptr: Dict[tuple, list] = {}
427
+ for k, v in sd.items():
428
+ if not isinstance(v, torch.Tensor):
429
+ continue
430
+
431
+ if v.device.type != "meta":
432
+ # Prefer the new API; fall back to the legacy storage() if unavailable
433
+ if hasattr(v, "untyped_storage"):
434
+ storage = v.untyped_storage()
435
+ else:
436
+ storage = v.storage()
437
+ ptr = storage.data_ptr()
438
+ else:
439
+ # Meta tensors have no storage, so just use data_ptr
440
+ ptr = v.data_ptr()
441
+
442
+ by_ptr.setdefault((ptr, v.dtype, tuple(v.size())), []).append(k)
443
+
444
+ sd = dict(sd) # shallow copy
445
+ for (_ptr, _dtype, _shape), keys in by_ptr.items():
446
+ if len(keys) <= 1:
447
+ continue
448
+ master = keys[0]
449
+ for k in keys[1:]:
450
+ sd[k] = sd[k].clone()
451
+ return sd
452
+
453
+ @staticmethod
454
+ def _retie_lm_head_if_possible(model: "VLMModel"):
455
+ try:
456
+ emb = model.lm.get_input_embeddings().weight
457
+ except Exception:
458
+ emb = None
459
+ candidates = [
460
+ "lm.lm_head.weight",
461
+ "lm.get_output_embeddings.weight",
462
+ ]
463
+ for name in candidates:
464
+ head = model
465
+ try:
466
+ for part in name.split("."):
467
+ head = getattr(head, part)
468
+ if isinstance(head, torch.nn.Parameter) and emb is not None and head.shape == emb.shape:
469
+ with torch.no_grad():
470
+ head.set_(emb)
471
+ break
472
+ except Exception:
473
+ continue
474
+
475
+ # ---------------- config I/O ----------------
476
+ def _export_vlm_config(self) -> Dict[str, Any]:
477
+ """
478
+ Generates a HuggingFace-compatible config.json.
479
+
480
+ - model_type: "vision-language" (custom type)
481
+ - architectures: ["VLMModel"] (used by AutoModel)
482
+ - auto_map:
483
+ * AutoConfig -> model.VLMConfig
484
+ * AutoModel -> model.VLMModel
485
+ """
486
+ # 1) Basic VLM metadata
487
+ cfg: Dict[str, Any] = {
488
+ "_class_name": self.__class__.__name__, # "VLMModel"
489
+ "architectures": [self.__class__.__name__],
490
+ "model_type": "vision-language",
491
+ "version": 1,
492
+
493
+ # --- Simple metadata ---
494
+ "vision_model": getattr(self, "vision_model_name", None) or "ViT-B-16-quickgelu",
495
+ "vision_pretrained": getattr(self, "vision_pretrained_tag", None) or "metaclip_400m",
496
+ "vision_width": int(getattr(self, "vision_width", 0)),
497
+
498
+ "lm_model": getattr(self.lm, "name_or_path", None) or self.margs.lm_model,
499
+ "lm_hidden_size": int(getattr(self, "lm_hidden", 0)),
500
+
501
+ "n_vis_tokens": getattr(self.resampler, "K", None) or self.margs.n_vis_tokens
502
+ }
503
+
504
+ # 1-1) Mapping info for remote code
505
+ cfg["auto_map"] = {
506
+ # Import VLMConfig / VLMModel from the model.py module
507
+ "AutoConfig": "model.VLMConfig",
508
+ "AutoModel": "model.VLMModel",
509
+ }
510
+
511
+ # 1-2) Tokenizer-related metadata (if available)
512
+ try:
513
+ tok = getattr(self, "tokenizer", None)
514
+ if tok is not None:
515
+ if tok.pad_token_id is not None:
516
+ cfg["pad_token_id"] = int(tok.pad_token_id)
517
+ if tok.eos_token_id is not None:
518
+ cfg["eos_token_id"] = int(tok.eos_token_id)
519
+ if tok.bos_token_id is not None:
520
+ cfg["bos_token_id"] = int(tok.bos_token_id)
521
+ except Exception:
522
+ pass
523
+
524
+ # 2) Save the original ModelArgs
525
+ margs = getattr(self, "margs", None)
526
+ if margs is not None:
527
+ try:
528
+ cfg["model_args"] = asdict(margs)
529
+ except Exception:
530
+ try:
531
+ cfg["model_args"] = dict(margs.__dict__)
532
+ except Exception:
533
+ pass
534
+
535
+ # 3) Also save the full LLM (HF) config
536
+ if hasattr(self.lm, "config") and self.lm.config is not None:
537
+ try:
538
+ cfg["lm_config"] = self.lm.config.to_dict()
539
+ except Exception:
540
+ pass
541
+
542
+ # 4) Resampler hyperparameters
543
+ res = getattr(self, "resampler", None)
544
+ if res is not None:
545
+ res_cfg = {}
546
+ for attr in [
547
+ "K", "n_heads", "n_layers", "kv_dim",
548
+ "use_pos", "q_grid", "adaptive_kv_pos",
549
+ "dropout", "use_q_self_attn",
550
+ ]:
551
+ if hasattr(res, attr):
552
+ v = getattr(res, attr)
553
+ if isinstance(v, torch.Tensor):
554
+ v = v.item() if v.numel() == 1 else list(v.shape)
555
+ res_cfg[attr] = v
556
+ cfg["resampler_config"] = res_cfg
557
+
558
+ # 5) Projector structure summary
559
+ proj = getattr(self, "projector", None)
560
+ if proj is not None:
561
+ proj_cfg = {
562
+ "type": proj.__class__.__name__,
563
+ }
564
+ if hasattr(proj, "net") and isinstance(proj.net, nn.Sequential):
565
+ in_dim = None
566
+ out_dim = None
567
+ for m in proj.net:
568
+ if isinstance(m, nn.Linear):
569
+ if in_dim is None:
570
+ in_dim = m.in_features
571
+ out_dim = m.out_features
572
+ if in_dim is not None:
573
+ proj_cfg["in_features"] = int(in_dim)
574
+ if out_dim is not None:
575
+ proj_cfg["out_features"] = int(out_dim)
576
+ proj_cfg["num_layers"] = sum(
577
+ 1 for m in proj.net if isinstance(m, nn.Linear)
578
+ )
579
+ cfg["projector_config"] = proj_cfg
580
+
581
+ return cfg
582
+
583
+ def save_pretrained(
584
+ self,
585
+ save_directory: str,
586
+ safe: bool = True,
587
+ **kwargs,  # ignore extra kwargs passed by HF
588
+ ) -> None:
589
+ os.makedirs(save_directory, exist_ok=True)
590
+
591
+ # 1) Extract the state_dict and break shared storage
592
+ sd = self.state_dict()
593
+ sd = self._break_shared_tensors(sd)
594
+
595
+ # 2) Save (safetensors recommended)
596
+ if safe:
597
+ _save_sf(sd, os.path.join(save_directory, "model.safetensors"))
598
+ else:
599
+ torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))
600
+
601
+ # 3) Config / tokenizer
602
+ with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
603
+ json.dump(self._export_vlm_config(), f, ensure_ascii=False, indent=2)
604
+ try:
605
+ if getattr(self, "tokenizer", None) is not None:
606
+ self.tokenizer.save_pretrained(save_directory)
607
+ except Exception:
608
+ pass
609
+
610
+ @classmethod
611
+ def from_pretrained(
612
+ cls,
613
+ pretrained_model_name_or_path: str,
614
+ *model_args,
615
+ config: Optional[Any] = None,
616
+ **kwargs,
617
+ ):
618
+ """
619
+ Entry point called by AutoModel.from_pretrained(..., trust_remote_code=True).
620
+ pretrained_model_name_or_path: path to a directory saved with save_pretrained
621
+ """
622
+ load_directory = pretrained_model_name_or_path
623
+
624
+ # 1) Load config.json
625
+ cfg_path = os.path.join(load_directory, "config.json")
626
+ if not os.path.exists(cfg_path):
627
+ raise FileNotFoundError(f"config.json not found in {load_directory}")
628
+ with open(cfg_path, "r", encoding="utf-8") as f:
629
+ cfg = json.load(f)
630
+
631
+ # 2) Restore ModelArgs (prefer model_args from the config if present)
632
+ model_args_cfg = cfg.get("model_args", {})
633
+ margs = ModelArgs(
634
+ vision_model = model_args_cfg.get("vision_model", cfg.get("vision_model", "ViT-B-16-quickgelu")),
635
+ vision_pretrained= model_args_cfg.get("vision_pretrained",cfg.get("vision_pretrained", "metaclip_400m")),
636
+ lm_model = model_args_cfg.get("lm_model", cfg.get("lm_model", "state-spaces/mamba-370m-hf")),
637
+ n_vis_tokens = int(model_args_cfg.get("n_vis_tokens", cfg.get("n_vis_tokens", 196))),
638
+ use_chat_template= model_args_cfg.get("use_chat_template", True),
639
+ )
640
+
641
+ model = cls(margs)
642
+
643
+ # 5) Load weights (load to CPU first, then .to(device))
644
+ wt_sf = os.path.join(load_directory, "model.safetensors")
645
+ wt_pt = os.path.join(load_directory, "pytorch_model.bin")
646
+ if os.path.exists(wt_sf):
647
+ sd = _load_sf(wt_sf, device="cpu")
648
+ elif os.path.exists(wt_pt):
649
+ sd = torch.load(wt_pt, map_location="cpu")
650
+ else:
651
+ raise FileNotFoundError(
652
+ f"No weight file found in {load_directory} "
653
+ f"(expected model.safetensors or pytorch_model.bin)"
654
+ )
655
+
656
+ # 6) Strip the DDP prefix
657
+ if any(k.startswith("module.") for k in sd):
658
+ sd = {k.replace("module.", "", 1): v for k, v in sd.items()}
659
+
660
+ missing, unexpected = model.load_state_dict(sd, strict=False)
661
+ print(f"[from_pretrained] missing={len(missing)} unexpected={len(unexpected)}")
662
+ if missing:
663
+ print(" - missing (first 10):", missing[:10])
664
+ if unexpected:
665
+ print(" - unexpected (first 10):", unexpected[:10])
666
+
667
+ # 7) Re-tie weights after loading (saves memory)
668
+ try:
669
+ cls._retie_lm_head_if_possible(model)
670
+ except Exception:
671
+ pass
672
+
673
+ return model
674
+
675
+ @classmethod
676
+ def register_for_auto_class(cls, auto_class=None):
677
+ """
678
+ Method called by transformers.AutoModel.from_pretrained(..., trust_remote_code=True)
679
+ to register this custom model with the AutoModel registry.
680
+ """
681
+ return
682
+
683
+ class OpenCLIPImageProcessor(BaseImageProcessor):
684
+ model_input_names = ["pixel_values"]
685
+ image_processor_type = "open_clip"
686
+
687
+ def __init__(
688
+ self,
689
+ vision_model: str = "ViT-H-14-378-quickgelu",
690
+ vision_pretrained: str = "dfn5b",
691
+ is_train: bool = False,
692
+ **kwargs,
693
+ ):
694
+ super().__init__(**kwargs)
695
+
696
+ self.vision_model = vision_model
697
+ self.vision_pretrained = vision_pretrained
698
+ self.is_train = is_train
699
+
700
+ # Remote-code mapping for HF AutoImageProcessor
701
+ self.auto_map = {"AutoImageProcessor": "model.OpenCLIPImageProcessor"}
702
+
703
+ # Create the actual torchvision transform lazily (on first call)
704
+ self._transform = None
705
+
706
+ @classmethod
707
+ def register_for_auto_class(cls, auto_class=None):
708
+ # Hook that may be called by AutoImageProcessor.from_pretrained(..., trust_remote_code=True)
709
+ # We don't need a separate registry here, so this is a no-op
710
+ return
711
+
712
+ def _ensure_transform(self):
713
+ if self._transform is not None:
714
+ return
715
+
716
+ clip_model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
717
+ self.vision_model,
718
+ pretrained=self.vision_pretrained,
719
+ device="cpu", # 전처리용이니 CPU로 충분
720
+ )
721
+ del clip_model  # the model itself is not needed
722
+
723
+ self._transform = preprocess_train if self.is_train else preprocess_val
724
+
725
+ def to_dict(self) -> Dict[str, Any]:
726
+ # Called when serializing BaseImageProcessor → dict
727
+ config = super().to_dict()
728
+ config.update(
729
+ {
730
+ "vision_model": self.vision_model,
731
+ "vision_pretrained": self.vision_pretrained,
732
+ "is_train": self.is_train,
733
+ "auto_map": self.auto_map,
734
+ "image_processor_type": self.image_processor_type,
735
+ }
736
+ )
737
+ return config
738
+
739
+ def __call__(
740
+ self,
741
+ images,
742
+ return_tensors: Optional[str] = "pt",
743
+ **kwargs,
744
+ ) -> Dict[str, Any]:
745
+ """
746
+ images: a PIL.Image or a list of them
747
+ return: {"pixel_values": Tensor[B, 3, H, W]}
748
+ """
749
+ self._ensure_transform()
750
+
751
+ if not isinstance(images, (list, tuple)):
752
+ images = [images]
753
+
754
+ proc = []
755
+ for img in images:
756
+ if not isinstance(img, Image.Image):
757
+ raise TypeError(f"Expected PIL.Image, but got {type(img)}")
758
+ proc.append(self._transform(img))
759
+ pixel_values = torch.stack(proc, dim=0) # [B,3,H,W]
760
+
761
+ # In HF style, returning "pt" tensors is usually sufficient
762
+ return {"pixel_values": pixel_values}
763
+
764
+ # ------------------------------- Builder -------------------------------
765
+
766
+ def build_model(
767
+ vision_model: str,
768
+ vision_pretrained: str,
769
+ lm_model: str,
770
+ vision_freeze: bool,
771
+ projector_freeze: bool,
772
+ llm_freeze: bool,
773
+ device: Optional[torch.device] = None,
774
+ checkpoint_path: Optional[str] = None,
775
+ n_vis_tokens: int = 196,
776
+ use_chat_template: bool = True,
777
+ ) -> VLMModel:
778
+ margs = ModelArgs(
779
+ vision_model=vision_model,
780
+ vision_pretrained=vision_pretrained,
781
+ lm_model=lm_model,
782
+ checkpoint_path=checkpoint_path,
783
+ vision_freeze=vision_freeze,
784
+ projector_freeze=projector_freeze,
785
+ llm_freeze=llm_freeze,
786
+ n_vis_tokens=n_vis_tokens,
787
+ use_chat_template=use_chat_template,
788
+ )
789
+ model = VLMModel(margs, device)
790
+ model.to(device or model.device)
791
+ return model
792
+
VLM_prototype/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db36e5ec2007edfa7df1427237f8097ada6772f26535179e44c7d7b6d6fe8e1c
3
+ size 20426846152
VLM_prototype/preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_transform": null,
3
+ "auto_map": {
4
+ "AutoImageProcessor": "model.OpenCLIPImageProcessor"
5
+ },
6
+ "image_processor_type": "open_clip",
7
+ "is_train": false,
8
+ "vision_model": "ViT-H-14-378-quickgelu",
9
+ "vision_pretrained": "dfn5b"
10
+ }
VLM_prototype/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
VLM_prototype/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
VLM_prototype/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 1010000,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
VLM_prototype/vocab.json ADDED
The diff for this file is too large to render. See raw diff