Refactor code to transformers convention

#3
.gitattributes CHANGED
@@ -1,35 +1,137 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Common settings that generally should always be used with your language specific settings
2
+
3
+ # Auto detect text files and perform LF normalization
4
+ * text=auto
5
+
6
+ #
7
+ # The above will handle all files NOT found below
8
+ #
9
+
10
+ # Documents
11
+ *.bibtex text diff=bibtex
12
+ *.doc diff=astextplain
13
+ *.DOC diff=astextplain
14
+ *.docx diff=astextplain
15
+ *.DOCX diff=astextplain
16
+ *.dot diff=astextplain
17
+ *.DOT diff=astextplain
18
+ *.pdf diff=astextplain
19
+ *.PDF diff=astextplain
20
+ *.rtf diff=astextplain
21
+ *.RTF diff=astextplain
22
+ *.md text diff=markdown
23
+ *.mdx text diff=markdown
24
+ *.tex text diff=tex
25
+ *.adoc text
26
+ *.textile text
27
+ *.mustache text
28
+ *.csv text eol=crlf
29
+ *.tab text
30
+ *.tsv text
31
+ *.txt text
32
+ *.sql text
33
+ *.epub diff=astextplain
34
+
35
+ # Graphics
36
+ *.png binary
37
+ *.jpg binary
38
+ *.jpeg binary
39
+ *.gif binary
40
+ *.tif binary
41
+ *.tiff binary
42
+ *.ico binary
43
+ # SVG treated as text by default.
44
+ *.svg text
45
+ # If you want to treat it as binary,
46
+ # use the following line instead.
47
+ # *.svg binary
48
+ *.eps binary
49
+
50
+ # Scripts
51
+ *.bash text eol=lf
52
+ *.fish text eol=lf
53
+ *.ksh text eol=lf
54
+ *.sh text eol=lf
55
+ *.zsh text eol=lf
56
+ # These are explicitly windows files and should use crlf
57
+ *.bat text eol=crlf
58
+ *.cmd text eol=crlf
59
+ *.ps1 text eol=crlf
60
+
61
+ # Serialisation
62
+ *.json text
63
+ *.toml text
64
+ *.xml text
65
+ *.yaml text
66
+ *.yml text
67
+
68
+ # Archives
69
+ *.7z binary
70
+ *.bz binary
71
+ *.bz2 binary
72
+ *.bzip2 binary
73
+ *.gz binary
74
+ *.lz binary
75
+ *.lzma binary
76
+ *.rar binary
77
+ *.tar binary
78
+ *.taz binary
79
+ *.tbz binary
80
+ *.tbz2 binary
81
+ *.tgz binary
82
+ *.tlz binary
83
+ *.txz binary
84
+ *.xz binary
85
+ *.Z binary
86
+ *.zip binary
87
+ *.zst binary
88
+
89
+ # Text files where line endings should be preserved
90
+ *.patch -text
91
+
92
+ #
93
+ # Exclude files from exporting
94
+ #
95
+
96
+ .gitattributes export-ignore
97
+ .gitignore export-ignore
98
+ .gitkeep export-ignore
99
+
100
+
101
+
102
+ # Basic .gitattributes for a python repo.
103
+
104
+ # Source files
105
+ # ============
106
+ *.pxd text diff=python
107
+ *.py text diff=python
108
+ *.py3 text diff=python
109
+ *.pyw text diff=python
110
+ *.pyx text diff=python
111
+ *.pyz text diff=python
112
+ *.pyi text diff=python
113
+
114
+ # Binary files
115
+ # ============
116
+ *.db binary
117
+ *.p binary
118
+ *.pkl binary
119
+ *.pickle binary
120
+ *.pyc binary export-ignore
121
+ *.pyo binary export-ignore
122
+ *.pyd binary
123
+
124
+ # Jupyter notebook
125
+ *.ipynb text eol=lf
126
+
127
+ # Note: .db, .p, and .pkl files are associated
128
+ # with the python modules ``pickle``, ``dbm.*``,
129
+ # ``shelve``, ``marshal``, ``anydbm``, & ``bsddb``
130
+ # (among others).
131
+
132
+
133
+
134
  *.safetensors filter=lfs diff=lfs merge=lfs -text
135
+ /llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
136
+ /llm/vocab.json filter=lfs diff=lfs merge=lfs -text
137
+ /tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+
177
+
178
+ .vscode/
chat_template.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<image>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
3
+ }
config.json CHANGED
@@ -2,13 +2,14 @@
2
  "_attn_implementation_autoset": true,
3
  "_name_or_path": "NVILA-Lite-2B-hf-preview",
4
  "architectures": [
5
- "VILAForCausalLM"
6
  ],
7
  "auto_map": {
8
  "AutoConfig": "configuration_vila.VILAConfig",
9
- "AutoModel": "modeling_vila.VILAForCausalLM",
10
- "AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
11
- "AutoProcessor": "auto_processor.VILAProcessor"
 
12
  },
13
  "chat_template": null,
14
  "drop_path_rate": 0.0,
 
2
  "_attn_implementation_autoset": true,
3
  "_name_or_path": "NVILA-Lite-2B-hf-preview",
4
  "architectures": [
5
+ "VILAForConditionalGeneration"
6
  ],
7
  "auto_map": {
8
  "AutoConfig": "configuration_vila.VILAConfig",
9
+ "AutoModel": "modeling_vila_hf.VILAForConditionalGeneration",
10
+ "AutoModelForCausalLM": "modeling_vila_hf.VILAForConditionalGeneration",
11
+ "AutoModelForImageTextToText": "modeling_vila_hf.VILAForConditionalGeneration",
12
+ "AutoModelForVision2Seq": "modeling_vila_hf.VILAForConditionalGeneration"
13
  },
14
  "chat_template": null,
15
  "drop_path_rate": 0.0,
configuration_vila.py CHANGED
@@ -1,93 +1,34 @@
1
- import json
2
- import math
3
- import os
4
- import os.path as osp
5
- from copy import deepcopy
6
- from threading import Thread
7
- from typing import List, Optional
8
 
9
- import torch
10
- import torchvision
11
- from PIL import Image
12
- from transformers import (
13
- AutoProcessor,
14
- PretrainedConfig,
15
- PreTrainedModel,
16
- Qwen2Config,
17
- Qwen2ForCausalLM,
18
- Qwen2PreTrainedModel,
19
- TextIteratorStreamer,
20
- )
21
 
22
 
23
  class VILAConfig(PretrainedConfig):
24
- model_type = "vila"
25
- keys_to_ignore_at_inference = ["past_key_values"]
26
-
27
- def __init__(
28
- self,
29
- llm_cfg=None,
30
- vision_tower_cfg=None,
31
- mm_projector_cfg=None,
32
- architectures=None,
33
- resume_path=None,
34
- hidden_size=None,
35
- mm_hidden_size=None,
36
- image_aspect_ratio=None,
37
- num_video_frames=None,
38
- fps=None,
39
- mm_vision_select_layer=None,
40
- mm_vision_select_feature=None,
41
- mm_use_im_start_end=False,
42
- mm_use_im_patch_token=False,
43
- mm_projector_lr=None,
44
- vision_tower_lr=None,
45
- vision_resolution=None,
46
- interpolate_mode=None,
47
- s2=None,
48
- dynamic_s2=None,
49
- s2_scales=None,
50
- s2_max_split_size=None,
51
- s2_resize_output_to_scale_idx=0,
52
- min_tiles: Optional[int] = 1,
53
- max_tiles: Optional[int] = 12,
54
- num_time_tokens=None,
55
- time_token_format=None,
56
- image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
57
- video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
58
- **kwargs,
59
- ):
60
- super().__init__(**kwargs)
61
-
62
- self.architectures = architectures
63
- self.llm_cfg = llm_cfg
64
- self.vision_tower_cfg = vision_tower_cfg
65
- self.mm_projector_cfg = mm_projector_cfg
66
- self.resume_path = resume_path
67
-
68
- self.hidden_size = hidden_size
69
- self.mm_hidden_size = mm_hidden_size
70
- self.image_aspect_ratio = image_aspect_ratio
71
- self.num_video_frames = num_video_frames
72
- self.fps = fps
73
- self.mm_vision_select_layer = mm_vision_select_layer
74
- self.mm_vision_select_feature = mm_vision_select_feature
75
- self.mm_use_im_start_end = mm_use_im_start_end
76
- self.mm_use_im_patch_token = mm_use_im_patch_token
77
- self.mm_projector_lr = mm_projector_lr
78
- self.vision_tower_lr = vision_tower_lr
79
- self.vision_resolution = vision_resolution
80
- self.interpolate_mode = interpolate_mode
81
- self.s2 = s2
82
- self.dynamic_s2 = dynamic_s2
83
- self.s2_scales = s2_scales
84
- self.s2_max_split_size = s2_max_split_size
85
- self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
86
- self.min_tiles = min_tiles
87
- self.max_tiles = max_tiles
88
- self.num_time_tokens = num_time_tokens
89
- self.time_token_format = time_token_format
90
-
91
- self.image_encoder = image_encoder
92
- self.video_encoder = video_encoder
93
-
 
1
+ from typing import Any, Dict
 
 
 
 
 
 
2
 
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
5
+ from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 
 
 
 
 
 
 
 
 
6
 
7
 
8
class VILAConfig(PretrainedConfig):
    """Configuration of a VILA model.

    The sub-model configurations are stored as plain dictionaries
    (``llm_cfg``, ``mm_projector_cfg``, ``vision_tower_cfg``); the
    ``text_config`` and ``vision_config`` properties turn the relevant
    dictionaries into typed config objects on access.
    """

    # Overridden class attributes.
    model_type: str = "vila"
    is_composition: bool = True

    # Common attributes.
    vocab_size: int
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int

    # Other attributes.
    llm_cfg: Dict[str, Any]
    mm_projector_cfg: Dict[str, Any]
    vision_tower_cfg: Dict[str, Any]

    @property
    def text_config(self) -> Qwen2Config:
        """Returns the language model configuration built from ``llm_cfg``."""
        config = Qwen2Config.from_dict(self.llm_cfg)
        assert isinstance(config, Qwen2Config)
        return config

    @property
    def vision_config(self) -> SiglipVisionConfig:
        """Returns the vision tower configuration built from ``vision_tower_cfg``."""
        # The vision config must come from the vision tower dictionary; the
        # projector dictionary (``mm_projector_cfg``) describes a different module.
        config = SiglipVisionConfig.from_dict(self.vision_tower_cfg)
        assert isinstance(config, SiglipVisionConfig)
        return config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
llm/vocab.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_vila_hf.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Type, Union, cast, override
3
+
4
+ import transformers.modeling_utils as modeling_utils
5
+ from torch import FloatTensor, LongTensor, Tensor
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.generation.utils import GenerationMixin
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
11
+
12
+ from .configuration_vila import VILAConfig
13
+ from .modeling_vila import VILAForCausalLM
14
+
15
+ IMAGE_TOKEN_ID = 151649
16
+
17
+
18
class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Transformers-style wrapper around the sub-modules of a ``VILAForCausalLM``.

    The wrapper takes the multimodal projector, the language model and the
    vision tower from an already constructed ``VILAForCausalLM``. In
    ``forward`` the projected image features are written over the embeddings
    of the image placeholder tokens (``IMAGE_TOKEN_ID``) and the merged
    embeddings are passed to the language model.
    """

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "vila"
    is_parallelizable: bool = True
    main_input_name: str = "input_ids"

    config: PretrainedConfig

    # Sub-modules taken from the wrapped VILAForCausalLM.
    mm_projector: PreTrainedModel
    llm: Qwen2ForCausalLM
    vision_tower: PreTrainedModel

    def __init__(
        self,
        config: PretrainedConfig,
        model: VILAForCausalLM,
        *args,
        **kwargs,
    ):
        """Builds the wrapper from a config and an existing ``VILAForCausalLM``.

        Args:
            config: The model configuration.
            model: The inner model whose ``mm_projector``, ``llm`` and
                ``vision_tower`` are reused by this wrapper.
        """
        super().__init__(config, *args, **kwargs)

        self.mm_projector = cast(PreTrainedModel, model.mm_projector)
        self.llm = cast(Qwen2ForCausalLM, model.llm)
        self.vision_tower = cast(PreTrainedModel, model.vision_tower)

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[LongTensor] = None,
        inputs_embeds: Optional[FloatTensor] = None,
        pixel_values: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Runs the language model on text, optionally merged with image features.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given;
        ``pixel_values`` may only accompany ``input_ids``.

        Args:
            attention_mask: Forwarded to the language model.
            input_ids: Token ids; embedded (and merged with image features) here.
            inputs_embeds: Pre-computed embeddings, passed through unchanged.
            pixel_values: Image input for the vision tower.
            **kwargs: Forwarded to the language model (e.g. ``past_key_values``).

        Returns:
            The output of the language model.
        """
        # Vision info is only used for prefilling.
        if kwargs.get("past_key_values", None) is not None:
            pixel_values = None

        if inputs_embeds is None:
            assert input_ids is not None

            inputs_embeds = self._embed(input_ids, pixel_values)
        else:
            assert input_ids is None
            assert pixel_values is None

        outputs = self.llm.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

        return outputs

    @override
    @classmethod
    @modeling_utils.restore_default_torch_dtype
    def from_pretrained(
        cls: Type[modeling_utils.SpecificPreTrainedModelType],
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ) -> modeling_utils.SpecificPreTrainedModelType:
        """Loads the inner ``VILAForCausalLM`` and wraps it.

        The inner model is loaded first; its state dict is then handed to the
        parent ``from_pretrained`` together with the inner model itself, which
        this class' ``__init__`` expects as its second positional argument.
        """
        state_dict = kwargs.pop("state_dict", None)

        if pretrained_model_name_or_path is not None:
            config = VILAConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                # Forward the auth token so the config of a private or gated
                # repository can be fetched like the weights are below.
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        else:
            # NOTE(review): the state_dict validated here is not forwarded to
            # the inner load below and is overwritten afterwards — confirm
            # whether this branch is actually supported.
            assert (
                config is not None and state_dict is not None
            ), "Both config and state_dict must be provided if pretrained_model_name_or_path is None."

        inner_model = VILAForCausalLM.from_pretrained(
            pretrained_model_name_or_path,  # type: ignore
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        state_dict = inner_model.state_dict()

        # Prefix keys with "model.".
        # state_dict = {f"model.{k}": v for k, v in state_dict.items()}

        return super().from_pretrained(
            None,
            inner_model,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            state_dict=state_dict,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

    def _embed(
        self,
        input_ids: LongTensor,
        pixel_values: Optional[Tensor],
    ) -> FloatTensor:
        """Gets the embedding of the input ids and pixel values.

        Args:
            input_ids: The input ids.
            pixel_values: The pixel values, or None for text-only input.

        Returns:
            The text embedding, with the positions of ``IMAGE_TOKEN_ID``
            overwritten by the projected image features when pixel values
            are given.
        """

        text_embedding = self.llm.get_input_embeddings().__call__(input_ids)
        text_embedding = cast(FloatTensor, text_embedding)

        if pixel_values is None:
            return text_embedding

        image_features: Tensor = self.vision_tower.__call__(pixel_values)
        image_features = self.mm_projector.__call__(image_features)

        # Flatten (images, features per image, dim) into one row per image token.
        # reshape (unlike view) also accepts non-contiguous projector output.
        n_images, n_feature, dim_feature = image_features.shape
        image_features = image_features.reshape(n_images * n_feature, dim_feature)

        image_token_mask = input_ids == IMAGE_TOKEN_ID

        # Masked assignment requires source and destination to agree in dtype
        # and device, so align the image features with the text embedding.
        text_embedding[image_token_mask] = image_features.to(
            device=text_embedding.device,
            dtype=text_embedding.dtype,
        )

        return text_embedding
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "SiglipImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "resample": 3,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "height": 448,
21
+ "width": 448
22
+ }
23
+ }
processing_vila.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Unpack, cast
2
+
3
+ import numpy as np
4
+ import transformers.image_transforms as image_transforms
5
+ import transformers.image_utils as image_utils
6
+ from numpy.typing import NDArray
7
+ from PIL.Image import Image
8
+ from torch import Tensor
9
+ from transformers.feature_extraction_utils import BatchFeature
10
+ from transformers.image_processing_utils import BaseImageProcessor
11
+ from transformers.image_processing_utils_fast import BaseImageProcessorFast
12
+ from transformers.image_utils import ImageInput, VideoInput
13
+ from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
14
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
15
+ from transformers.tokenization_utils import PreTrainedTokenizer
16
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase, TextInput
17
+
18
+
19
class VILAProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments accepted by ``VILAProcessor.__call__``."""

    # No processor-level defaults are declared.
    _defaults = {}  # type: ignore
21
+
22
+
23
class VILAProcessorOutput(BatchFeature):
    """Typed view of the batch returned by ``VILAProcessor.__call__``."""

    # Tokenizer outputs.
    input_ids: List[List[int]] | NDArray[np.int64] | Tensor
    attention_mask: List[List[int]] | NDArray[np.int64] | Tensor
    # Image processor output; optional in the annotation because the processor
    # does not define pixel values when no images are passed.
    pixel_values: Optional[List[NDArray[np.float32]] | NDArray[np.float32] | Tensor]
27
+
28
+
29
class VILAProcessor(ProcessorMixin):
    """Processor that pairs an image processor with a tokenizer for VILA.

    Each input image is cut into tiles. In the text, every image placeholder
    (``image_token``) is first expanded to one placeholder per tile, each
    followed by a newline, and every resulting placeholder is then repeated
    ``image_pad_len`` times before tokenization.
    """

    attributes: List[str] = [
        "image_processor",
        "tokenizer",
    ]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    # Attributes.
    image_processor: BaseImageProcessor | BaseImageProcessorFast
    tokenizer: PreTrainedTokenizerBase

    # Configuration parameters.
    image_pad_len: int
    image_token: str
    max_tiles: int
    min_tiles: int

    def __init__(
        self,
        image_processor: BaseImageProcessor,
        tokenizer: PreTrainedTokenizer,
        *,
        image_pad_len: int,
        image_token: str,
        max_tiles: int,
        min_tiles: int,
        **kwargs,
    ):
        """Stores the sub-processors and the tiling/padding configuration.

        Args:
            image_processor: The image processor applied to the image tiles.
            tokenizer: The tokenizer applied to the padded text.
            image_pad_len: How many times each image placeholder is repeated.
            image_token: The image placeholder string in the text.
            max_tiles: Upper bound on the number of tiles per image.
            min_tiles: Lower bound on the number of tiles per image.
        """
        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

        self.image_pad_len = image_pad_len
        self.image_token = image_token
        self.max_tiles = max_tiles
        self.min_tiles = min_tiles

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[TextInput | List[TextInput]] = None,
        audio: None = None,
        videos: Optional[VideoInput] = None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> VILAProcessorOutput:
        """Tiles the images, pads the image placeholders and tokenizes the text.

        Args:
            images: The input images, or None / an empty list for text-only input.
            text: The input text or list of texts; required.
            audio: Unused.
            videos: Unused (video processing is not implemented yet).
            **kwargs: Processing keyword arguments.

        Returns:
            The tokenizer outputs merged with the image processor outputs.
        """
        # Validate arguments.
        assert text is not None and text != [], "text must be provided"
        assert not kwargs.get(
            "is_split_into_words", False
        ), "is_split_into_words=True is not supported"

        output_kwargs = self._merge_kwargs(
            VILAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Process images.
        # Only an empty *list* counts as "no images". Comparing with ``!= []``
        # is avoided because array/tensor inputs compare elementwise.
        no_images = images is None or (isinstance(images, list) and len(images) == 0)
        if not no_images:
            image_inputs, num_cropped_images = self._process_images(
                images=images,
                **output_kwargs["images_kwargs"],
            )
        else:
            # If no images are provided, do not define pixel_values.
            image_inputs = BatchFeature()
            num_cropped_images = []

        # TODO: video processing.

        # Process text.
        text = text if isinstance(text, list) else [text]

        text = self._pad_image_tokens_by_num_crops(
            text,
            num_cropped_images=num_cropped_images,
        )

        text = self._pad_image_tokens_by_num_embeddings(
            text,
        )

        text_inputs = self.tokenizer.__call__(
            text,
            **output_kwargs["text_kwargs"],
        )

        return VILAProcessorOutput(
            data={
                **text_inputs,
                **image_inputs,
            }
        )

    def _crop_image(
        self,
        image: Image,
    ) -> List[Image]:
        """Crops the image into multiple tiles.

        Args:
            image: The image to be cropped.

        Returns:
            The cropped images.
        """

        # TODO: Support more image processors.
        assert isinstance(self.image_processor, SiglipImageProcessor)

        # Tiles are square, so the image processor must target a square size.
        assert self.image_processor.size["height"] == self.image_processor.size["width"]
        cropped_size = self.image_processor.size["height"]

        cropped_images: List[Image] = dynamic_preprocess(
            image,
            min_num=self.min_tiles,
            max_num=self.max_tiles,
            image_size=cropped_size,
        )

        return cropped_images

    def _pad_image_tokens_by_num_crops(
        self,
        text: List[TextInput],
        *,
        num_cropped_images: List[int],
    ) -> List[TextInput]:
        """Pads each <image> to num_cropped_images of "<image>\\n".

        Image placeholders are matched to images in order of appearance
        across all text items.

        Args:
            text: The text to be padded.
            num_cropped_images: The number of cropped images for each image token.

        Returns:
            The padded text.
        """
        # Validate arguments.
        num_images = len(num_cropped_images)
        num_image_tokens = sum([item.count(self.image_token) for item in text])
        assert num_images == num_image_tokens, (
            f"Number of image tokens ({num_image_tokens}) in text does not match "
            f"the number of images ({num_images})."
        )

        assert all(
            num_crops > 0 for num_crops in num_cropped_images
        ), "All image padding lengths should be positive integers."

        # Pad image tokens.
        image_idx = 0
        padded_text: List[TextInput] = []

        for item in text:
            # Splitting on the placeholder yields the text around each
            # occurrence; the padded placeholders are re-inserted between them.
            segments = item.split(self.image_token)
            pieces = [segments[0]]

            for segment in segments[1:]:
                pieces.append((self.image_token + "\n") * num_cropped_images[image_idx])
                pieces.append(segment)
                image_idx += 1

            padded_text.append("".join(pieces))

        return padded_text

    def _pad_image_tokens_by_num_embeddings(
        self,
        text: List[TextInput],
    ) -> List[TextInput]:
        """Pads each <image> to image_pad_len times of "<image>".

        Args:
            text: The text to be padded.

        Returns:
            The padded text.
        """
        padded_placeholder = self.image_token * self.image_pad_len

        return [item.replace(self.image_token, padded_placeholder) for item in text]

    def _process_images(
        self,
        images: ImageInput,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> Tuple[BatchFeature, List[int]]:
        """Tiles every image and runs the image processor on all tiles.

        Args:
            images: The input images.
            **kwargs: Forwarded to the image processor.

        Returns:
            The image processor output for all tiles, and the number of
            tiles produced for each input image (in input order).
        """
        images_flatten = cast(
            List[Image] | List[NDArray] | List[Tensor],
            image_utils.make_flat_list_of_images(images),
        )

        cropped_images: List[Image] = []
        num_cropped_images: List[int] = []
        for image in images_flatten:
            pil_image: Image = image_transforms.to_pil_image(image)
            single_cropped_images = self._crop_image(pil_image)

            cropped_images.extend(single_cropped_images)
            num_cropped_images.append(len(single_cropped_images))

        image_inputs = self.image_processor(
            cropped_images,
            **kwargs,
        )

        return image_inputs, num_cropped_images
265
+
266
+
267
def dynamic_preprocess(
    image, min_num=1, max_num=12, image_size=384, use_thumbnail=True
):
    """Resizes an image onto a tile grid and cuts it into square tiles.

    Args:
        image: The image to tile; must provide ``size``, ``resize`` and ``crop``.
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Side length of each square tile.
        use_thumbnail: Whether to append a whole-image thumbnail when the
            grid has more than one tile.

    Returns:
        The tiles in row-major order, followed by the thumbnail if one is added.
    """
    src_width, src_height = image.size
    src_aspect = src_width / src_height

    # Candidate (columns, rows) grids whose tile count lies within
    # [min_num, max_num], ordered by tile count.
    candidate_grids = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if i * j <= max_num and i * j >= min_num
        },
        key=lambda grid: grid[0] * grid[1],
    )

    # Pick the grid whose aspect ratio is closest to the image's.
    best_grid = find_closest_aspect_ratio(
        src_aspect, candidate_grids, src_width, src_height, image_size
    )
    n_cols = best_grid[0]
    n_rows = best_grid[1]

    canvas_width = image_size * n_cols
    canvas_height = image_size * n_rows
    tile_count = n_cols * n_rows

    # Stretch the image onto the grid, then cut it tile by tile.
    canvas = image.resize((canvas_width, canvas_height))
    tiles_per_row = canvas_width // image_size

    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        tiles.append(
            canvas.crop(
                (
                    col * image_size,
                    row * image_size,
                    (col + 1) * image_size,
                    (row + 1) * image_size,
                )
            )
        )
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
311
+
312
+
313
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Selects the candidate grid whose aspect ratio is closest to the image's.

    Candidates are visited in the given order. A candidate that is strictly
    closer replaces the current choice. A candidate that is exactly as close
    as the current best replaces it only when the image area exceeds half of
    the candidate grid's area (``image_size**2 * columns * rows``).

    Args:
        aspect_ratio: Width divided by height of the image.
        target_ratios: Iterable of (columns, rows) candidates.
        width: Image width.
        height: Image height.
        image_size: Side length of a single tile.

    Returns:
        The selected (columns, rows) candidate, or (1, 1) if there are none.
    """
    source_area = width * height
    chosen = (1, 1)
    smallest_gap = float("inf")

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])

        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
            continue

        if gap == smallest_gap and (
            source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]
        ):
            chosen = candidate

    return chosen
processor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_vila.VILAProcessor"
4
+ },
5
+ "max_tiles": 12,
6
+ "min_tiles": 1,
7
+ "image_pad_len": 121,
8
+ "image_token": "<image>"
9
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fc37d325d718c91319f527fbe8258c03ac890aba2f252b85af89a625927908a
3
+ size 11419189
tokenizer_config.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "151646": {
29
+ "content": "[BOS]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "151647": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "151648": {
45
+ "content": "<vila/sentinel>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "151649": {
53
+ "content": "<image>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "151650": {
61
+ "content": "<vila/video>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ }
68
+ },
69
+ "additional_special_tokens": [
70
+ "<|im_start|>",
71
+ "<|im_end|>"
72
+ ],
73
+ "bos_token": "[BOS]",
74
+ "chat_template": "{% if messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages if message['content'] is not none %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
75
+ "clean_up_tokenization_spaces": false,
76
+ "eos_token": "<|im_end|>",
77
+ "errors": "replace",
78
+ "legacy": false,
79
+ "model_max_length": 4096,
80
+ "pad_token": "<|endoftext|>",
81
+ "padding_side": "left",
82
+ "split_special_tokens": false,
83
+ "tokenizer_class": "Qwen2Tokenizer",
84
+ "unk_token": null
85
+ }