beshkenadze committed
Commit 5d99344 · verified · 1 Parent(s): b30ea75

Add 8-bit MLX quant of moondream3-preview (mlx-vlm)

Files changed (17)
  1. LICENSE.md +27 -0
  2. README.md +42 -0
  3. config.json +30 -0
  4. config.py +102 -0
  5. hf_moondream.py +190 -0
  6. image_crops.py +231 -0
  7. layers.py +259 -0
  8. lora.py +437 -0
  9. model-00001-of-00002.safetensors +3 -0
  10. model-00002-of-00002.safetensors +3 -0
  11. model.safetensors.index.json +1059 -0
  12. moondream.py +1097 -0
  13. region.py +136 -0
  14. rope.py +47 -0
  15. text.py +223 -0
  16. utils.py +41 -0
  17. vision.py +147 -0
LICENSE.md ADDED
@@ -0,0 +1,27 @@
+ | License | Business Source License (BSL 1.1) |
+ | --- | --- |
+ | Licensor | M87 Labs, Inc. |
+ | Licensed Work | “Moondream 3 (Preview)” including Model Weights and any Derivatives (“Derivatives” include fine-tunes, merges, quantizations, weight deltas, and other weight-level modifications or conversions.) |
+ | Additional Use Grant | You may make production use of the Licensed Work, provided Your use does not include offering the Licensed Work to third parties on a hosted or embedded basis in order to compete with M87 Labs’s paid version(s) of the Licensed Work. For purposes of this license:<br><br>A “competitive offering” is a Product that is offered to third parties on a paid basis, including through paid support arrangements, that significantly overlaps with the capabilities of M87 Labs’s paid version(s) of the Licensed Work. If Your Product is not a competitive offering when You first make it generally available, it will not become a competitive offering later due to M87 Labs releasing a new version of the Licensed Work with additional capabilities. In addition, Products that are not provided on a paid basis are not competitive.<br><br>“Product” means software that is offered to end users to manage in their own environments or offered as a service on a hosted basis.<br><br>“Embedded” means including the source code or executable code from the Licensed Work in a competitive offering. “Embedded” also means packaging the competitive offering in such a way that the Licensed Work must be accessed or downloaded for the competitive offering to operate.<br><br>Hosting or using the Licensed Work(s) for internal purposes within an organization is not considered a competitive offering. M87 Labs considers your organization to include all of your affiliates under common control. |
+ | Change Date | Two years after the first public release of this version of the Licensed Work |
+ | Change License | Apache License, Version 2.0 |
+
+ For information about alternative licensing arrangements for the Licensed Work, please contact [contact@m87.ai](mailto:contact@m87.ai).
+
+ The text of the Business Source License 1.1 follows. License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved. “Business Source License” is a trademark of MariaDB Corporation Ab.
+
+ ## Terms
+
+ The Licensor hereby grants you the right to copy, modify, create derivative works, redistribute, and make non-production use of the Licensed Work. The Licensor may make an Additional Use Grant, above, permitting limited production use.
+
+ Effective on the Change Date, or the fourth anniversary of the first publicly available distribution of a specific version of the Licensed Work under this License, whichever comes first, the Licensor hereby grants you rights under the terms of the Change License, and the rights granted in the paragraph above terminate.
+
+ If your use of the Licensed Work does not comply with the requirements currently in effect as described in this License, you must purchase a commercial license from the Licensor, its affiliated entities, or authorized resellers, or you must refrain from using the Licensed Work.
+
+ All copies of the original and modified Licensed Work, and derivative works of the Licensed Work, are subject to this License. This License applies separately for each version of the Licensed Work and the Change Date may vary for each version of the Licensed Work released by Licensor.
+
+ You must conspicuously display this License on each original or modified copy of the Licensed Work. If you receive the Licensed Work in original or modified form from a third party, the terms and conditions set forth in this License apply to your use of that work.
+
+ Any use of the Licensed Work in violation of this License will automatically terminate your rights under this License for the current and all other versions of the Licensed Work.
+
+ This License does not grant you any right in any trademark or logo of Licensor or its affiliates (provided that you may use a trademark or logo of Licensor as expressly required by this License).TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND TITLE.
README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ library_name: mlx
+ pipeline_tag: image-text-to-text
+ license: other
+ license_name: bsl-1.1
+ license_link: LICENSE.md
+ base_model: moondream/moondream3-preview
+ tags:
+ - mlx
+ - moondream
+ - moondream3
+ - vision-language
+ - image-text-to-text
+ ---
+
+ # moondream3-preview-mlx-8bit
+
+ An **8-bit** MLX quantization of [moondream/moondream3-preview](https://huggingface.co/moondream/moondream3-preview)
+ for running on Apple Silicon with [mlx-vlm](https://github.com/Blaizzy/mlx-vlm).
+
+ | | |
+ |---|---|
+ | Quantization | affine, **8 bits**, group size 64 (vision tower included) |
+ | On-disk size | ~9.5 GB |
+ | Peak memory | ~11 GB |
+ | Tokenizer | loaded from [`moondream/starmie-v1`](https://huggingface.co/moondream/starmie-v1) at runtime (not bundled) |
+
+ ## Usage
+
+ ```bash
+ pip install mlx-vlm
+
+ python -m mlx_vlm.generate \
+   --model beshkenadze/moondream3-preview-mlx-8bit \
+   --image path/to/image.jpg \
+   --prompt "Describe this image." \
+   --max-tokens 128 --temperature 0.0
+ ```
+
+ ## License
+
+ moondream3 is released under the **Business Source License 1.1 (BSL 1.1)** — see [LICENSE.md](LICENSE.md). This quantization is a derivative redistribution under the same terms.
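The size figure in the README table can be sanity-checked with a little arithmetic. The sketch below assumes that affine quantization stores one 16-bit scale and one 16-bit bias per group of 64 weights; that storage layout is an assumption, not something this commit states, and both helper names are hypothetical.

```python
def effective_bits_per_weight(bits: int, group_size: int, meta_bits: int = 16) -> float:
    """Bits stored per weight, counting one scale and one bias per group.

    meta_bits=16 is an assumption (half-precision scale and bias).
    """
    return bits + 2 * meta_bits / group_size


def implied_param_count(size_bytes: float, bits_per_weight: float) -> float:
    """Parameters that would fit in size_bytes if every tensor were quantized."""
    return size_bytes * 8 / bits_per_weight


# 8 bits, group size 64, ~9.5 GB on disk: all three values come from the README table.
bpw = effective_bits_per_weight(bits=8, group_size=64)
params = implied_param_count(9.5e9, bpw)
print(f"{bpw} bits/weight, roughly {params / 1e9:.1f}B parameters if fully quantized")
```

Some tensors (norms, biases) are usually kept at full precision, so this is only a rough estimate of the parameter count, not an exact one.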
config.json ADDED
@@ -0,0 +1,30 @@
+ {
+     "architectures": [
+         "HfMoondream"
+     ],
+     "auto_map": {
+         "AutoConfig": "hf_moondream.HfConfig",
+         "AutoModelForCausalLM": "hf_moondream.HfMoondream"
+     },
+     "config": {
+         "skills": [
+             "query",
+             "caption",
+             "detect",
+             "point"
+         ]
+     },
+     "model_type": "moondream3",
+     "quantization": {
+         "group_size": 64,
+         "bits": 8,
+         "mode": "affine"
+     },
+     "quantization_config": {
+         "group_size": 64,
+         "bits": 8,
+         "mode": "affine"
+     },
+     "torch_dtype": "bfloat16",
+     "transformers_version": "4.51.1"
+ }
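The `quantization` block above names an affine scheme with 8 bits and a group size of 64. As a rough illustration of what that means, the sketch below implements a generic min/max affine quantizer in pure Python. It is not MLX's kernel: rounding and edge handling there may differ, and the function names are hypothetical.

```python
import math


def quantize_group(values, bits=8):
    """Map one group of floats to integer codes in [0, 2**bits - 1]."""
    lo, hi = min(values), max(values)
    levels = (1 << bits) - 1
    # A constant group has no range; any non-zero scale works in that case.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo  # lo plays the role of the per-group bias


def dequantize_group(codes, scale, bias):
    """Invert quantize_group up to rounding error: w ~ scale * code + bias."""
    return [scale * c + bias for c in codes]


weights = [math.sin(0.1 * i) for i in range(128)]  # toy data, two groups of 64
group_size = 64
max_error = 0.0
max_half_step = 0.0
all_codes = []
for start in range(0, len(weights), group_size):
    group = weights[start:start + group_size]
    codes, scale, bias = quantize_group(group, bits=8)
    restored = dequantize_group(codes, scale, bias)
    all_codes.extend(codes)
    max_error = max(max_error, max(abs(a - b) for a, b in zip(group, restored)))
    max_half_step = max(max_half_step, scale / 2)

print(f"max round-trip error {max_error:.6f}, half a step is {max_half_step:.6f}")
```

The round-trip error stays within half a quantization step per group, which is why smaller groups (a tighter min/max range) cost more metadata but lose less precision.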
config.py ADDED
@@ -0,0 +1,102 @@
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional
+
+
+ @dataclass(frozen=True)
+ class TextMoeConfig:
+     num_experts: int = 64
+     start_layer: int = 4
+     experts_per_token: int = 8
+     expert_inner_dim: int = 1024
+
+
+ @dataclass(frozen=True)
+ class TextConfig:
+     dim: int = 2048
+     ff_dim: int = 8192
+     n_layers: int = 24
+     vocab_size: int = 51200
+     max_context: int = 4096
+     n_heads: int = 32
+     n_kv_heads: int = 32
+     prefix_attn: int = 730
+     group_size: Optional[int] = None
+     moe: Optional[TextMoeConfig] = TextMoeConfig()
+
+
+ @dataclass(frozen=True)
+ class VisionConfig:
+     enc_dim: int = 1152
+     enc_patch_size: int = 14
+     enc_n_layers: int = 27
+     enc_ff_dim: int = 4304
+     enc_n_heads: int = 16
+     proj_out_dim: int = 2048
+     crop_size: int = 378
+     in_channels: int = 3
+     max_crops: int = 12
+     overlap_margin: int = 4
+     proj_inner_dim: int = 8192
+
+
+ @dataclass(frozen=True)
+ class RegionConfig:
+     dim: int = 2048
+     coord_feat_dim: int = 256
+     coord_out_dim: int = 1024
+     size_feat_dim: int = 512
+     size_out_dim: int = 2048
+     group_size: Optional[int] = None
+
+
+ @dataclass(frozen=True)
+ class TokenizerConfig:
+     bos_id: int = 0
+     eos_id: int = 0
+     answer_id: int = 3
+     thinking_id: int = 4
+     coord_id: int = 5
+     size_id: int = 6
+     start_ground_points_id: int = 7
+     end_ground_id: int = 9
+     templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
+         default_factory=lambda: {
+             "caption": {
+                 "short": [1, 32708, 2, 12492, 3],
+                 "normal": [1, 32708, 2, 6382, 3],
+                 "long": [1, 32708, 2, 4059, 3],
+             },
+             "query": {"prefix": [1, 15381, 2], "suffix": [3]},
+             "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
+             "point": {"prefix": [1, 2581, 2], "suffix": [3]},
+         }
+     )
+
+
+ @dataclass(frozen=True)
+ class MoondreamConfig:
+     text: TextConfig = TextConfig()
+     vision: VisionConfig = VisionConfig()
+     region: RegionConfig = RegionConfig()
+     tokenizer: TokenizerConfig = TokenizerConfig()
+
+     @classmethod
+     def from_dict(cls, config_dict: dict):
+         text_config = TextConfig(**config_dict.get("text", {}))
+         vision_config = VisionConfig(**config_dict.get("vision", {}))
+         region_config = RegionConfig(**config_dict.get("region", {}))
+         tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
+         return cls(
+             text=text_config,
+             vision=vision_config,
+             region=region_config,
+             tokenizer=tokenizer_config,
+         )
+
+     def to_dict(self):
+         return {
+             "text": self.text.__dict__,
+             "vision": self.vision.__dict__,
+             "region": self.region.__dict__,
+             "tokenizer": self.tokenizer.__dict__,
+         }
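One detail of `from_dict` and `to_dict` above is worth seeing in isolation: each section is built with `**kwargs` and serialized through `__dict__`, and neither step recurses into the nested `moe` field. The sketch below uses trimmed stand-in classes (`MoeCfg`, `TextCfg`, and `text_cfg_from_dict` are hypothetical names, not part of this commit) to show the effect and one possible way to round-trip the nested field. With the `config.json` in this commit no `text` section is supplied, so the defaults apply and the issue would not be triggered.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class MoeCfg:  # stand-in for TextMoeConfig, trimmed to two fields
    num_experts: int = 64
    experts_per_token: int = 8


@dataclass(frozen=True)
class TextCfg:  # stand-in for TextConfig
    dim: int = 2048
    moe: Optional[MoeCfg] = MoeCfg()


def text_cfg_from_dict(d: dict) -> TextCfg:
    """Like TextCfg(**d), but rebuilds the nested moe section first."""
    d = dict(d)  # avoid mutating the caller's dict
    moe = d.get("moe")
    if isinstance(moe, dict):
        d["moe"] = MoeCfg(**moe)
    return TextCfg(**d)


cfg = TextCfg()
shallow = cfg.__dict__   # nested MoeCfg stays an object (not JSON-serializable)
deep = asdict(cfg)       # nested MoeCfg becomes a plain dict
naive = TextCfg(**deep)  # moe is now a dict, not a MoeCfg
restored = text_cfg_from_dict(deep)

print(type(shallow["moe"]).__name__, type(naive.moe).__name__, restored == cfg)
```

`dataclasses.asdict` recurses into nested dataclasses, which is why it pairs naturally with a loader that rebuilds them.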
hf_moondream.py ADDED
@@ -0,0 +1,190 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ from typing import Union
+
+ from .config import MoondreamConfig
+ from .moondream import MoondreamModel
+
+ # Files sometimes don't get loaded without these...
+ from .image_crops import *
+ from .vision import *
+ from .text import *
+ from .region import *
+ from .utils import *
+
+
+ def extract_question(text):
+     prefix = "<image>\n\nQuestion: "
+     suffix = "\n\nAnswer:"
+
+     if text.startswith(prefix) and text.endswith(suffix):
+         return text[len(prefix) : -len(suffix)]
+     else:
+         return None
+
+
+ class HfConfig(PretrainedConfig):
+     _auto_class = "AutoConfig"
+     model_type = "moondream3"
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.config = {"skills": ["query", "caption", "detect", "point"]}
+
+
+ class HfMoondream(PreTrainedModel):
+     _auto_class = "AutoModelForCausalLM"
+     config_class = HfConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = MoondreamModel(
+             MoondreamConfig.from_dict(config.config), setup_caches=False
+         )
+         self._is_kv_cache_setup = False
+         self.post_init()
+
+     @classmethod
+     def from_pretrained(cls, *args, **kwargs):
+         output = super().from_pretrained(*args, **kwargs)
+         model = output[0] if isinstance(output, tuple) else output
+         model.model._refresh_runtime_buffers()
+         return output
+
+     def _setup_caches(self):
+         if not self._is_kv_cache_setup:
+             self.model._setup_caches()
+             self._is_kv_cache_setup = True
+
+     @property
+     def encode_image(self):
+         self._setup_caches()
+         return self.model.encode_image
+
+     @property
+     def query(self):
+         self._setup_caches()
+         return self.model.query
+
+     @property
+     def caption(self):
+         self._setup_caches()
+         return self.model.caption
+
+     @property
+     def detect(self):
+         self._setup_caches()
+         return self.model.detect
+
+     @property
+     def point(self):
+         self._setup_caches()
+         return self.model.point
+
+     @property
+     def detect_gaze(self):
+         self._setup_caches()
+         return self.model.detect_gaze
+
+     def answer_question(
+         self,
+         image_embeds,
+         question,
+         tokenizer=None,
+         chat_history="",
+         result_queue=None,
+         max_new_tokens=256,
+         **kwargs
+     ):
+         answer = self.query(image_embeds, question)["answer"].strip()
+
+         if result_queue is not None:
+             result_queue.put(answer)
+         return answer
+
+     def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
+         answers = []
+         for image, prompt in zip(images, prompts):
+             answers.append(self.query(image, prompt)["answer"].strip())
+         return answers
+
+     def _unsupported_exception(self):
+         raise NotImplementedError(
+             "This method is not supported in the latest version of moondream. "
+             "Consider upgrading to the updated API spec, or alternately pin "
+             "to 'revision=2024-08-26'."
+         )
+
+     def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
+         """
+         Function definition remains unchanged for backwards compatibility.
+         Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
+         """
+         prompt_extracted = extract_question(prompt)
+         if prompt_extracted is not None:
+             answer = self.model.query(
+                 image=image_embeds, question=prompt_extracted, stream=False
+             )["answer"]
+         else:
+             image_embeds = self.encode_image(image_embeds)
+             prompt_tokens = torch.tensor(
+                 [self.model.tokenizer.encode(prompt).ids],
+                 device=self.device,
+             )
+
+             def generator():
+                 for token in self.model._generate_answer(
+                     prompt_tokens,
+                     image_embeds.kv_cache,
+                     image_embeds.pos,
+                     max_new_tokens,
+                 ):
+                     yield token
+
+             answer = "".join(list(generator()))
+
+         return [answer]
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         """
+         Lazily wrap the raw parameter `self.model.text.wte` in a real
+         `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
+         **shares** the weight tensor—no copy is made.
+         """
+         if not hasattr(self, "_input_embeddings"):
+             self._input_embeddings = nn.Embedding.from_pretrained(
+                 self.model.text.wte,  # tensor created in text.py
+                 freeze=True,  # set to False if you need it trainable
+             )
+         return self._input_embeddings
+
+     def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
+         """
+         Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
+         embeddings and keeps everything tied to `self.model.text.wte`.
+         """
+         # 1. point the low-level parameter to the new weight matrix
+         self.model.text.wte = value.weight
+         # 2. keep a reference for get_input_embeddings()
+         self._input_embeddings = value
+
+     def input_embeds(
+         self,
+         input_ids: Union[torch.LongTensor, list, tuple],
+         *,
+         device: torch.device | None = None
+     ) -> torch.FloatTensor:
+         """
+         Back-compat wrapper that turns token IDs into embeddings.
+
+         Example:
+             ids = torch.tensor([[1, 2, 3]])
+             embeds = model.input_embeds(ids)  # (1, 3, hidden_dim)
+         """
+         if not torch.is_tensor(input_ids):
+             input_ids = torch.as_tensor(input_ids)
+         if device is not None:
+             input_ids = input_ids.to(device)
+
+         return self.get_input_embeddings()(input_ids)
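The `generate` method above branches on whether the prompt matches the legacy `<image>\n\nQuestion: ...\n\nAnswer:` template. The sketch below copies `extract_question` into a standalone form (with the prefix and suffix lifted into module constants) and adds a hypothetical helper, `build_legacy_prompt`, to show which prompts would take the `query` path and which would fall through to raw token generation.

```python
from typing import Optional

PREFIX = "<image>\n\nQuestion: "
SUFFIX = "\n\nAnswer:"


def extract_question(text: str) -> Optional[str]:
    """Return the question inside a legacy-format prompt, or None if it does not match."""
    if text.startswith(PREFIX) and text.endswith(SUFFIX):
        return text[len(PREFIX):-len(SUFFIX)]
    return None


def build_legacy_prompt(question: str) -> str:
    """Hypothetical helper: wrap a question in the legacy template."""
    return f"{PREFIX}{question}{SUFFIX}"


legacy = build_legacy_prompt("What is in the image?")
plain = "Describe this image."

for prompt in (legacy, plain):
    question = extract_question(prompt)
    route = "query()" if question is not None else "raw generation"
    print(f"{prompt!r} -> {route}")
```

Only an exact template match is unwrapped; anything else is tokenized and generated from as-is.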
image_crops.py ADDED
@@ -0,0 +1,231 @@
+ import math
+ import numpy as np
+ import torch
+
+ from typing import TypedDict
+
+ try:
+     import pyvips
+
+     HAS_VIPS = True
+ except:
+     from PIL import Image
+
+     HAS_VIPS = False
+
+
+ def select_tiling(
+     height: int, width: int, crop_size: int, max_crops: int
+ ) -> tuple[int, int]:
+     """
+     Determine the optimal number of tiles to cover an image with overlapping crops.
+     """
+     if height <= crop_size or width <= crop_size:
+         return (1, 1)
+
+     # Minimum required tiles in each dimension
+     min_h = math.ceil(height / crop_size)
+     min_w = math.ceil(width / crop_size)
+
+     # If minimum required tiles exceed max_crops, return proportional distribution
+     if min_h * min_w > max_crops:
+         ratio = math.sqrt(max_crops / (min_h * min_w))
+         return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
+
+     # Perfect aspect-ratio tiles that satisfy max_crops
+     h_tiles = math.floor(math.sqrt(max_crops * height / width))
+     w_tiles = math.floor(math.sqrt(max_crops * width / height))
+
+     # Ensure we meet minimum tile requirements
+     h_tiles = max(h_tiles, min_h)
+     w_tiles = max(w_tiles, min_w)
+
+     # If we exceeded max_crops, scale down the larger dimension
+     if h_tiles * w_tiles > max_crops:
+         if w_tiles > h_tiles:
+             w_tiles = math.floor(max_crops / h_tiles)
+         else:
+             h_tiles = math.floor(max_crops / w_tiles)
+
+     return (max(1, h_tiles), max(1, w_tiles))
+
+
+ class OverlapCropOutput(TypedDict):
+     crops: np.ndarray
+     tiling: tuple[int, int]
+
+
+ def overlap_crop_image(
+     image: np.ndarray,
+     overlap_margin: int,
+     max_crops: int,
+     base_size: tuple[int, int] = (378, 378),
+     patch_size: int = 14,
+ ) -> OverlapCropOutput:
+     """
+     Process an image using an overlap-and-resize cropping strategy with margin handling.
+
+     This function takes an input image and creates multiple overlapping crops with
+     consistent margins. It produces:
+     1. A single global crop resized to base_size
+     2. Multiple overlapping local crops that maintain high resolution details
+     3. A patch ordering matrix that tracks correspondence between crops
+
+     The overlap strategy ensures:
+     - Smooth transitions between adjacent crops
+     - No loss of information at crop boundaries
+     - Proper handling of features that cross crop boundaries
+     - Consistent patch indexing across the full image
+
+     Args:
+         image (np.ndarray): Input image as numpy array with shape (H,W,C)
+         base_size (tuple[int,int]): Target size for crops, default (378,378)
+         patch_size (int): Size of patches in pixels, default 14
+         overlap_margin (int): Margin size in patch units, default 4
+         max_crops (int): Maximum number of crops allowed, default 12
+
+     Returns:
+         OverlapCropOutput: Dictionary containing:
+             - crops: A numpy array containing the global crop of the full image (index 0)
+                 followed by the overlapping cropped regions (indices 1+)
+             - tiling: Tuple of (height,width) tile counts
+     """
+     original_h, original_w = image.shape[:2]
+
+     # Convert margin from patch units to pixels
+     margin_pixels = patch_size * overlap_margin
+     total_margin_pixels = margin_pixels * 2  # Both sides
+
+     # Calculate crop parameters
+     crop_patches = base_size[0] // patch_size  # patches per crop dimension
+     crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
+     crop_window_size = crop_window_patches * patch_size  # usable size in pixels
+
+     # Determine tiling
+     tiling = select_tiling(
+         original_h - total_margin_pixels,
+         original_w - total_margin_pixels,
+         crop_window_size,
+         max_crops,
+     )
+
+     # Pre-allocate crops.
+     n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop
+     crops = np.zeros(
+         (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
+     )
+
+     # Resize image to fit tiling
+     target_size = (
+         tiling[0] * crop_window_size + total_margin_pixels,
+         tiling[1] * crop_window_size + total_margin_pixels,
+     )
+
+     if HAS_VIPS:
+         # Convert to vips for resizing
+         vips_image = pyvips.Image.new_from_array(image)
+         scale_x = target_size[1] / image.shape[1]
+         scale_y = target_size[0] / image.shape[0]
+         resized = vips_image.resize(scale_x, vscale=scale_y)
+         image = resized.numpy()
+
+         # Create global crop
+         scale_x = base_size[1] / vips_image.width
+         scale_y = base_size[0] / vips_image.height
+         global_vips = vips_image.resize(scale_x, vscale=scale_y)
+         crops[0] = global_vips.numpy()
+     else:
+         # Fallback to PIL
+         pil_img = Image.fromarray(image)
+         resized = pil_img.resize(
+             (int(target_size[1]), int(target_size[0])),
+             resample=Image.Resampling.LANCZOS,
+         )
+         image = np.asarray(resized)
+
+         # Create global crop
+         global_pil = pil_img.resize(
+             (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
+         )
+         crops[0] = np.asarray(global_pil)
+
+     for i in range(tiling[0]):
+         for j in range(tiling[1]):
+             # Calculate crop coordinates
+             y0 = i * crop_window_size
+             x0 = j * crop_window_size
+
+             # Extract crop with padding if needed
+             y_end = min(y0 + base_size[0], image.shape[0])
+             x_end = min(x0 + base_size[1], image.shape[1])
+
+             crop_region = image[y0:y_end, x0:x_end]
+             crops[
+                 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
+             ] = crop_region
+
+     return {"crops": crops, "tiling": tiling}
+
+
+ def reconstruct_from_crops(
+     crops: torch.Tensor,
+     tiling: tuple[int, int],
+     overlap_margin: int,
+     patch_size: int = 14,
+ ) -> torch.Tensor:
+     """
+     Reconstruct the original image from overlapping crops into a single seamless image.
+
+     Takes a list of overlapping image crops along with their positional metadata and
+     reconstructs them into a single coherent image by carefully stitching together
+     non-overlapping regions. Handles both numpy arrays and PyTorch tensors.
+
+     Args:
+         crops: List of image crops as numpy arrays or PyTorch tensors with shape
+             (H,W,C)
+         tiling: Tuple of (height,width) indicating crop grid layout
+         patch_size: Size in pixels of each patch, default 14
+         overlap_margin: Number of overlapping patches on each edge, default 4
+
+     Returns:
+         Reconstructed image as numpy array or PyTorch tensor matching input type,
+         with shape (H,W,C) where H,W are the original image dimensions
+     """
+     tiling_h, tiling_w = tiling
+     crop_height, crop_width = crops[0].shape[:2]
+     margin_pixels = overlap_margin * patch_size
+
+     # Calculate output size (only adding margins once)
+     output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
+     output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
+
+     reconstructed = torch.zeros(
+         (output_h, output_w, crops[0].shape[2]),
+         device=crops[0].device,
+         dtype=crops[0].dtype,
+     )
+
+     for i, crop in enumerate(crops):
+         tile_y = i // tiling_w
+         tile_x = i % tiling_w
+
+         # For each tile, determine which part to keep
+         # Keep left margin only for first column
+         x_start = 0 if tile_x == 0 else margin_pixels
+         # Keep right margin only for last column
+         x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
+         # Keep top margin only for first row
+         y_start = 0 if tile_y == 0 else margin_pixels
+         # Keep bottom margin only for last row
+         y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
+
+         # Calculate where this piece belongs in the output
+         out_x = tile_x * (crop_width - 2 * margin_pixels)
+         out_y = tile_y * (crop_height - 2 * margin_pixels)
+
+         # Place the piece
+         reconstructed[
+             out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
+         ] = crop[y_start:y_end, x_start:x_end]
+
+     return reconstructed
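The tiling arithmetic in `overlap_crop_image` is easier to follow with concrete numbers. The sketch below repeats `select_tiling` from the file above and wraps the surrounding size calculations in a hypothetical helper, `plan_crops`, that needs no image data. The default arguments mirror the `VisionConfig` values in `config.py` (crop size 378, patch size 14, overlap margin 4, at most 12 crops).

```python
import math


def select_tiling(height: int, width: int, crop_size: int, max_crops: int) -> tuple:
    """Same logic as select_tiling in image_crops.py."""
    if height <= crop_size or width <= crop_size:
        return (1, 1)
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)
    if min_h * min_w > max_crops:
        ratio = math.sqrt(max_crops / (min_h * min_w))
        return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
    h_tiles = max(math.floor(math.sqrt(max_crops * height / width)), min_h)
    w_tiles = max(math.floor(math.sqrt(max_crops * width / height)), min_w)
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)
    return (max(1, h_tiles), max(1, w_tiles))


def plan_crops(image_h, image_w, crop_size=378, patch_size=14,
               overlap_margin=4, max_crops=12):
    """Hypothetical helper: the size bookkeeping of overlap_crop_image, without pixels."""
    margin_px = patch_size * overlap_margin                              # 56 by default
    window_px = (crop_size // patch_size - 2 * overlap_margin) * patch_size  # 266 by default
    tiling = select_tiling(image_h - 2 * margin_px, image_w - 2 * margin_px,
                           window_px, max_crops)
    target = (tiling[0] * window_px + 2 * margin_px,
              tiling[1] * window_px + 2 * margin_px)
    return {"window_px": window_px, "tiling": tiling,
            "target_size": target, "n_crops": tiling[0] * tiling[1] + 1}


large = plan_crops(1000, 1500)  # height 1000, width 1500
small = plan_crops(300, 300)
print(large)
print(small)
```

Each local crop contributes a 266-pixel window of new content plus a 56-pixel margin shared with its neighbours, and the extra crop in `n_crops` is the global view stored at index 0.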
layers.py ADDED
@@ -0,0 +1,259 @@
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Literal, Optional
7
+
8
+ from .lora import (
9
+ DenseLoRALayer,
10
+ MoELoRALayer,
11
+ apply_dense_lora,
12
+ apply_moe_lora_fc1_flat,
13
+ apply_moe_lora_fc2_flat,
14
+ )
15
+
16
+ try:
17
+ from torchao import quantize_
18
+ from torchao.quantization import int4_weight_only
19
+ except ImportError:
20
+
21
+ def quantize_(model, quant_mode):
22
+ raise ImportError(
23
+ "torchao is not installed. Please install it with `pip install torchao`."
24
+ )
25
+
26
+ def int4_weight_only(group_size):
27
+ raise ImportError(
28
+ "torchao is not installed. Please install it with `pip install torchao`."
29
+ )
30
+
31
+
32
+ def gelu_approx(x):
33
+ return F.gelu(x, approximate="tanh")
34
+
35
+
36
+ @dataclass
37
+ class LinearWeights:
38
+ weight: torch.Tensor
39
+ bias: torch.Tensor
40
+
41
+
42
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
43
+ return F.linear(x, w.weight, w.bias)
44
+
45
+
46
+ def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
47
+ _step = W_q.shape[0]
48
+ W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
49
+ W_r[:_step] = (W_q & 0b11110000) >> 4
50
+ W_r[_step:] = W_q & 0b00001111
51
+ W_r.sub_(zero).mul_(scale)
52
+ return W_r.reshape(orig_shape)
53
+
54
+
55
+ class QuantizedLinear(nn.Module):
56
+ def __init__(
57
+ self,
58
+ in_features: int,
59
+ out_features: int,
60
+ dtype: torch.dtype,
61
+ ):
62
+ # TODO: Take group_size as an input instead of hardcoding it here.
63
+ super().__init__()
64
+ self.in_features = in_features
65
+ self.out_features = out_features
66
+ self.weight = nn.ParameterDict(
67
+ {
68
+ "packed": nn.Parameter(
69
+ torch.empty(
70
+ out_features * in_features // (128 * 2), 128, dtype=torch.uint8
71
+ ),
72
+ requires_grad=False,
73
+ ),
74
+ "scale": nn.Parameter(
75
+ torch.empty(out_features * in_features // 128, 1),
76
+ requires_grad=False,
77
+ ),
78
+ "zero_point": nn.Parameter(
79
+ torch.empty(out_features * in_features // 128, 1),
80
+ requires_grad=False,
81
+ ),
82
+ }
83
+ )
84
+ self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
85
+ self.unpacked = False
86
+
87
+ def unpack(self):
88
+ if self.unpacked:
89
+ return
90
+
91
+ self.weight = nn.Parameter(
92
+ dequantize_tensor(
93
+ self.weight["packed"],
94
+ self.weight["scale"],
95
+ self.weight["zero_point"],
96
+ (self.out_features, self.in_features),
97
+ torch.bfloat16,
98
+ )
99
+ )
100
+ with torch.device("meta"):
101
+ self.linear = nn.Linear(
102
+ self.in_features, self.out_features, dtype=torch.bfloat16
103
+ )
104
+ self.linear.weight = self.weight
105
+ self.linear.bias = nn.Parameter(
106
+ self.bias.to(torch.bfloat16), requires_grad=False
107
+ )
108
+
109
+ del self.weight, self.bias
110
+ quantize_(self, int4_weight_only(group_size=128))
111
+ self.unpacked = True
112
+ torch.cuda.empty_cache()
113
+
114
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
115
+ if not self.unpacked:
116
+ self.unpack()
117
+ return self.linear(x)
118
+
119
+
120
+ @dataclass
121
+ class LayerNormWeights:
122
+ weight: torch.Tensor
123
+ bias: torch.Tensor
124
+
125
+
126
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
127
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
128
+
129
+
130
+ @dataclass
131
+ class MLPWeights:
132
+ fc1: LinearWeights
133
+ fc2: LinearWeights
134
+ act: Literal["gelu_approx"] = "gelu_approx"
135
+
136
+
137
+ def mlp(
138
+ x: torch.Tensor, w: MLPWeights, lora: Optional[DenseLoRALayer] = None
139
+ ) -> torch.Tensor:
140
+ x0 = w.fc1(x)
141
+ if lora is not None:
142
+ x = x0 + apply_dense_lora(x, lora.up_a, lora.up_b)
143
+ else:
144
+ x = x0
145
+
146
+ x = gelu_approx(x)
147
+
148
+ x0 = w.fc2(x)
149
+ if lora is not None:
150
+ x = x0 + apply_dense_lora(x, lora.down_a, lora.down_b)
151
+ else:
152
+ x = x0
153
+
154
+ return x
155
+
156
+
157
+ def moe_mlp(
158
+ x: torch.Tensor,
159
+ mlp_module: nn.Module,
160
+ experts_per_token: int,
161
+ lora: Optional[MoELoRALayer] = None,
162
+ ) -> torch.Tensor:
163
+ B, T, C = x.shape
164
+ x = x.reshape(-1, C)
165
+
166
+ # Router computation
167
+ router_logits = mlp_module.router(x)
168
+ topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
169
+ topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
170
+ num_tokens, top_k = topk_idxs.shape
171
+
172
+ if T == 1:
173
+ w1_weight = mlp_module.fc1.weight
174
+ w2_weight = mlp_module.fc2.weight
175
+
176
+ # Flatten to process all token-expert pairs at once
177
+ flat_idxs = topk_idxs.view(-1) # [T*A]
178
+ flat_weights = topk_weights.view(-1) # [T*A]
179
+
180
+ # Select expert weights
181
+ w1_selected = w1_weight[flat_idxs]
182
+ w2_selected = w2_weight[flat_idxs]
183
+
184
+ # Expand input for all token-expert pairs
185
+ x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
186
+
187
+ # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
188
+ x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(-1) # [T*A, H]
189
+ if lora is not None:
190
+ x1_full = x1_full + apply_moe_lora_fc1_flat(x_expanded, lora, flat_idxs)
191
+ x1, g = x1_full.chunk(2, dim=-1)
192
+ x1 = F.gelu(x1) * (g + 1)
193
+
194
+ # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
195
+ expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
196
+ if lora is not None:
197
+ expert_outs = expert_outs + apply_moe_lora_fc2_flat(x1, lora, flat_idxs)
198
+
199
+ # Apply weights and reshape
200
+ weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
201
+ weighted_outs = weighted_outs.view(num_tokens, top_k, C) # [T, A, D]
202
+
203
+ # Sum over experts
204
+ mlp_out = weighted_outs.sum(dim=1) # [T, D]
205
+ mlp_out = mlp_out.view(B, T, C)
206
+
207
+ return mlp_out
208
+ else:
209
+ out = x.new_zeros(x.size())
210
+
211
+ for expert_id in range(mlp_module.fc1.weight.shape[0]):
212
+ token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
213
+ if token_pos.numel() == 0:
214
+ continue
215
+
216
+ x_tok = x.index_select(0, token_pos)
217
+ gate_tok = topk_weights[token_pos, which_k]
218
+
219
+ w1 = mlp_module.fc1.weight[expert_id]
220
+ h_full = F.linear(x_tok, w1)
221
+ if lora is not None:
222
+ lora_up_a = lora.up_a[expert_id]
223
+ lora_up_b = lora.up_b[expert_id]
224
+ lora_mid = F.linear(x_tok, lora_up_a)
225
+ h_full = h_full + F.linear(lora_mid, lora_up_b)
226
+ h, g = h_full.chunk(2, dim=-1)
227
+ h = F.gelu(h) * (g + 1)
228
+ w2 = mlp_module.fc2.weight[expert_id]
229
+ y = F.linear(h, w2)
230
+ if lora is not None:
231
+ lora_down_a = lora.down_a[expert_id]
232
+ lora_down_b = lora.down_b[expert_id]
233
+ lora_mid = F.linear(h, lora_down_a)
234
+ y = y + F.linear(lora_mid, lora_down_b)
235
+
236
+ y.mul_(gate_tok.unsqueeze(-1))
237
+ out.index_add_(0, token_pos, y)
238
+
239
+ return out.view(B, T, C)
240
+
241
+
242
+ @dataclass
243
+ class AttentionWeights:
244
+ qkv: LinearWeights
245
+ proj: LinearWeights
246
+
247
+
248
+ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
249
+ bsz, q_len, d_model = x.shape
250
+ head_dim = d_model // n_heads
251
+
252
+ q, k, v = [
253
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
254
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
255
+ ]
256
+ out = F.scaled_dot_product_attention(q, k, v)
257
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
258
+ out = linear(out, w.proj)
259
+ return out
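
Reviewer note on the MoE path in layers.py above: the routing takes the top-k router logits per token, applies a softmax over only those k logits, passes the token through each selected expert, and sums the expert outputs scaled by the routing weights. The activation between the two expert matmuls is `gelu(h) * (g + 1)`. Below is a minimal single-token sketch in plain Python, assuming no batching, no LoRA, and experts passed in as callables. The names `route_top_k`, `gated_gelu`, and `moe_token` are hypothetical and do not appear in the commit.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def route_top_k(logits: Sequence[float], k: int) -> List[Tuple[int, float]]:
    """Pick the k largest router logits and softmax over just those k."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    peak = max(logits[i] for i in top)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - peak) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]


def gated_gelu(h: Sequence[float], g: Sequence[float]) -> Vector:
    """Exact (erf-based) GELU of h, gated by (g + 1), as in the diff."""
    return [
        0.5 * hv * (1.0 + math.erf(hv / math.sqrt(2.0))) * (gv + 1.0)
        for hv, gv in zip(h, g)
    ]


def moe_token(
    x: Vector,
    logits: Sequence[float],
    experts: Sequence[Callable[[Vector], Vector]],
    k: int,
) -> Vector:
    """Weighted sum of the outputs of the k experts selected for one token."""
    out = [0.0] * len(x)
    for expert_id, weight in route_top_k(logits, k):
        y = experts[expert_id](x)
        for j, value in enumerate(y):
            out[j] += weight * value
    return out
```

The `T == 1` branch in the diff computes the same sum with batched matmuls over gathered expert weights, while the `else` branch loops over experts and scatters results back with `index_add_`.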
lora.py ADDED
@@ -0,0 +1,437 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import shutil
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple
8
+ from urllib.request import Request, urlopen
9
+
10
+ import torch
11
+
12
+ from .config import TextConfig
13
+
14
+
15
+ class AdapterLoadError(RuntimeError):
16
+ pass
17
+
18
+
19
+ def _cache_root() -> Path:
20
+ hf_hub_cache = os.environ.get("HF_HUB_CACHE")
21
+ if hf_hub_cache:
22
+ return Path(hf_hub_cache)
23
+
24
+ hf_home = os.environ.get("HF_HOME")
25
+ if hf_home:
26
+ return Path(hf_home) / "hub"
27
+
28
+ return Path("~/.cache/huggingface/hub").expanduser()
29
+
30
+
31
+ def adapter_cache_dir() -> Path:
32
+ return _cache_root() / "md_finetunes"
33
+
34
+
35
+ def normalize_adapter_id(value: Optional[str]) -> Optional[str]:
36
+ if not value:
37
+ return None
38
+ tail = value.split("/")[-1].strip()
39
+ if "@" not in tail:
40
+ return None
41
+ return tail
42
+
43
+
44
+ def parse_adapter_id(adapter_id: str) -> Tuple[str, str]:
45
+ if not adapter_id or "@" not in adapter_id:
46
+ raise AdapterLoadError(
47
+ f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
48
+ )
49
+ finetune_id, step = adapter_id.split("@", 1)
50
+ if not finetune_id or not step:
51
+ raise AdapterLoadError(
52
+ f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
53
+ )
54
+ return finetune_id, step
55
+
56
+
57
+ def _fetch_presigned_url(finetune_id: str, step: str) -> str:
58
+ endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai").rstrip("/")
59
+ api_key = os.getenv("MOONDREAM_API_KEY")
60
+ if not api_key:
61
+ raise AdapterLoadError("MOONDREAM_API_KEY is required to load finetune adapters.")
62
+
63
+ headers = {"User-Agent": "moondream-torch", "X-Moondream-Auth": api_key}
64
+ url = f"{endpoint}/v1/tuning/finetunes/{finetune_id}/checkpoints/{step}/download"
65
+ req = Request(url, headers=headers)
66
+ try:
67
+ with urlopen(req) as r:
68
+ payload = json.loads(r.read().decode("utf-8"))
69
+ except Exception as e:
70
+ raise AdapterLoadError(f"Failed to fetch adapter URL: {e}") from e
71
+
72
+ presigned = payload.get("url")
73
+ if not presigned:
74
+ raise AdapterLoadError("Adapter URL response missing 'url' field.")
75
+ return presigned
76
+
77
+
78
+ def cached_adapter_path(adapter_id: str) -> Path:
79
+ finetune_id, step = parse_adapter_id(adapter_id)
80
+
81
+ cache_dir = adapter_cache_dir() / finetune_id / step
82
+ cache_dir.mkdir(parents=True, exist_ok=True)
83
+
84
+ for name in ("adapter.pt", "adapter.safetensors"):
85
+ path = cache_dir / name
86
+ if path.exists() and path.stat().st_size > 0:
87
+ return path
88
+
89
+ presigned_url = _fetch_presigned_url(finetune_id, step)
90
+ dest = cache_dir / "adapter.pt"
91
+
92
+ try:
93
+ with urlopen(presigned_url) as r, open(dest, "wb") as f:
94
+ shutil.copyfileobj(r, f)
95
+ except Exception as e:
96
+ raise AdapterLoadError(f"Failed to download adapter: {e}") from e
97
+ return dest
98
+
99
+
100
+ def _load_state_dict(path: Path, device: torch.device) -> Dict[str, Any]:
101
+ if path.suffix == ".safetensors":
102
+ try:
103
+ from safetensors.torch import safe_open
104
+ except Exception as e:
105
+ raise AdapterLoadError(
106
+ "safetensors is required to load .safetensors adapters."
107
+ ) from e
108
+ data = {}
109
+ with safe_open(str(path), framework="pt") as f:
110
+ for key in f.keys():
111
+ data[key] = f.get_tensor(key).to(device=device)
112
+ return data
113
+
114
+ try:
115
+ return torch.load(path, map_location=device, weights_only=True)
116
+ except TypeError:
117
+ return torch.load(path, map_location=device)
118
+
119
+
120
+ @dataclass
121
+ class DenseLoRALayer:
122
+ up_a: torch.Tensor
123
+ up_b: torch.Tensor
124
+ down_a: torch.Tensor
125
+ down_b: torch.Tensor
126
+
127
+
128
+ @dataclass
129
+ class MoELoRALayer:
130
+ up_a: torch.Tensor
131
+ up_b: torch.Tensor
132
+ down_a: torch.Tensor
133
+ down_b: torch.Tensor
134
+
135
+
136
+ class TextLoRA:
137
+ def __init__(
138
+ self,
139
+ text_config: TextConfig,
140
+ *,
141
+ rank: int,
142
+ max_rank: int,
143
+ dtype: torch.dtype,
144
+ device: torch.device,
145
+ adapter_id: Optional[str] = None,
146
+ ) -> None:
147
+ if rank <= 0:
148
+ raise AdapterLoadError("LoRA rank must be positive.")
149
+ if max_rank < rank:
150
+ raise AdapterLoadError("max_rank must be >= rank.")
151
+
152
+ self.text_config = text_config
153
+ self.rank = rank
154
+ self.max_rank = max_rank
155
+ self.adapter_id = adapter_id
156
+
157
+ moe_cfg = text_config.moe
158
+ self.start_layer = moe_cfg.start_layer if moe_cfg else text_config.n_layers
159
+
160
+ if moe_cfg is not None:
161
+ self.rank_per_expert = rank // moe_cfg.experts_per_token
162
+ if self.rank_per_expert < 1:
163
+ raise AdapterLoadError(
164
+ f"rank ({rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
165
+ )
166
+ self.max_rank_per_expert = max_rank // moe_cfg.experts_per_token
167
+ if self.max_rank_per_expert < 1:
168
+ raise AdapterLoadError(
169
+ f"max_rank ({max_rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
170
+ )
171
+ else:
172
+ self.rank_per_expert = 0
173
+ self.max_rank_per_expert = 0
174
+
175
+ d_model = text_config.dim
176
+ d_ffn = text_config.ff_dim
177
+
178
+ self.dense: list[DenseLoRALayer] = []
179
+ for _ in range(self.start_layer):
180
+ self.dense.append(
181
+ DenseLoRALayer(
182
+ up_a=torch.zeros((max_rank, d_model), device=device, dtype=dtype),
183
+ up_b=torch.zeros((d_ffn, max_rank), device=device, dtype=dtype),
184
+ down_a=torch.zeros((max_rank, d_ffn), device=device, dtype=dtype),
185
+ down_b=torch.zeros((d_model, max_rank), device=device, dtype=dtype),
186
+ )
187
+ )
188
+
189
+ self.moe: list[MoELoRALayer] = []
190
+ if moe_cfg is not None:
191
+ num_experts = moe_cfg.num_experts
192
+ d_expert = moe_cfg.expert_inner_dim
193
+ for _ in range(text_config.n_layers - self.start_layer):
194
+ self.moe.append(
195
+ MoELoRALayer(
196
+ up_a=torch.zeros(
197
+ (num_experts, self.max_rank_per_expert, d_model),
198
+ device=device,
199
+ dtype=dtype,
200
+ ),
201
+ up_b=torch.zeros(
202
+ (num_experts, d_expert * 2, self.max_rank_per_expert),
203
+ device=device,
204
+ dtype=dtype,
205
+ ),
206
+ down_a=torch.zeros(
207
+ (num_experts, self.max_rank_per_expert, d_expert),
208
+ device=device,
209
+ dtype=dtype,
210
+ ),
211
+ down_b=torch.zeros(
212
+ (num_experts, d_model, self.max_rank_per_expert),
213
+ device=device,
214
+ dtype=dtype,
215
+ ),
216
+ )
217
+ )
218
+
219
+ def dense_layer(self, layer_idx: int) -> Optional[DenseLoRALayer]:
220
+ if layer_idx < len(self.dense):
221
+ return self.dense[layer_idx]
222
+ return None
223
+
224
+ def moe_layer(self, layer_idx: int) -> Optional[MoELoRALayer]:
225
+ moe_idx = layer_idx - self.start_layer
226
+ if 0 <= moe_idx < len(self.moe):
227
+ return self.moe[moe_idx]
228
+ return None
229
+
230
+ @staticmethod
231
+ def _pad_axis(tensor: torch.Tensor, target: int, axis: int) -> torch.Tensor:
232
+ if tensor.shape[axis] == target:
233
+ return tensor
234
+ if tensor.shape[axis] > target:
235
+ raise AdapterLoadError(
236
+ f"LoRA tensor rank {tensor.shape[axis]} exceeds max {target}"
237
+ )
238
+ pad_shape = list(tensor.shape)
239
+ pad_shape[axis] = target - tensor.shape[axis]
240
+ pad = torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)
241
+ return torch.cat([tensor, pad], dim=axis)
242
+
243
+ @staticmethod
244
+ def detect_rank(state_dict: Dict[str, Any], text_config: TextConfig) -> int:
245
+ for key, tensor in state_dict.items():
246
+ if "dense" in key and "up_a" in key:
247
+ return int(tensor.shape[0])
248
+ for key, tensor in state_dict.items():
249
+ if "moe" in key and "up_a" in key:
250
+ rank_per_expert = int(tensor.shape[1])
251
+ moe_cfg = text_config.moe
252
+ if moe_cfg:
253
+ return rank_per_expert * moe_cfg.experts_per_token
254
+ return rank_per_expert
255
+ raise AdapterLoadError("Could not detect LoRA rank from state dict.")
256
+
257
+ @classmethod
258
+ def from_state_dict(
259
+ cls,
260
+ state_dict: Dict[str, Any],
261
+ *,
262
+ text_config: TextConfig,
263
+ max_rank: int,
264
+ dtype: torch.dtype,
265
+ device: torch.device,
266
+ adapter_id: Optional[str] = None,
267
+ ) -> "TextLoRA":
268
+ rank = cls.detect_rank(state_dict, text_config)
269
+ if rank > max_rank:
270
+ raise AdapterLoadError(
271
+ f"Adapter rank ({rank}) exceeds max_rank ({max_rank})."
272
+ )
273
+
274
+ lora = cls(
275
+ text_config,
276
+ rank=rank,
277
+ max_rank=max_rank,
278
+ dtype=dtype,
279
+ device=device,
280
+ adapter_id=adapter_id,
281
+ )
282
+
283
+ dense_seen = set()
284
+ moe_seen = set()
285
+
286
+ pattern = re.compile(r"(dense|moe)\.(\d+)\.(up_a|up_b|down_a|down_b)$")
287
+ for key, tensor in state_dict.items():
288
+ match = pattern.search(key)
289
+ if not match:
290
+ continue
291
+ kind, idx_str, name = match.group(1), match.group(2), match.group(3)
292
+ idx = int(idx_str)
293
+ arr = tensor.to(device=device, dtype=dtype)
294
+
295
+ if kind == "dense":
296
+ if idx >= len(lora.dense):
297
+ raise AdapterLoadError(f"Dense LoRA layer index {idx} out of range.")
298
+ layer = lora.dense[idx]
299
+ if name in ("up_a", "down_a"):
300
+ arr = cls._pad_axis(arr, lora.max_rank, axis=0)
301
+ else:
302
+ arr = cls._pad_axis(arr, lora.max_rank, axis=1)
303
+ setattr(layer, name, arr)
304
+ dense_seen.add((idx, name))
305
+ else:
306
+ if idx >= len(lora.moe):
307
+ raise AdapterLoadError(f"MoE LoRA layer index {idx} out of range.")
308
+ layer = lora.moe[idx]
309
+ if name in ("up_a", "down_a"):
310
+ arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=1)
311
+ else:
312
+ arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=2)
313
+ setattr(layer, name, arr)
314
+ moe_seen.add((idx, name))
315
+
316
+ for layer_idx in range(len(lora.dense)):
317
+ for name in ("up_a", "up_b", "down_a", "down_b"):
318
+ if (layer_idx, name) not in dense_seen:
319
+ raise AdapterLoadError(
320
+ f"Adapter missing dense LoRA for layer {layer_idx} ({name})."
321
+ )
322
+ for layer_idx in range(len(lora.moe)):
323
+ for name in ("up_a", "up_b", "down_a", "down_b"):
324
+ if (layer_idx, name) not in moe_seen:
325
+ raise AdapterLoadError(
326
+ f"Adapter missing MoE LoRA for layer {layer_idx} ({name})."
327
+ )
328
+
329
+ return lora
330
+
331
+
332
+ def select_layer_lora(
333
+ lora: Optional[TextLoRA], layer_idx: int, *, is_moe: bool
334
+ ) -> Optional[object]:
335
+ if lora is None:
336
+ return None
337
+ return lora.moe_layer(layer_idx) if is_moe else lora.dense_layer(layer_idx)
338
+
339
+
340
+ def apply_dense_lora(
341
+ x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor
342
+ ) -> torch.Tensor:
343
+ b, t, c = x.shape
344
+ x_flat = x.reshape(-1, c)
345
+ lora_mid = torch.matmul(x_flat, lora_a.t())
346
+ lora_out = torch.matmul(lora_mid, lora_b.t())
347
+ return lora_out.reshape(b, t, -1)
348
+
349
+
350
+ def apply_moe_lora_fc1_flat(
351
+ x_expanded: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
352
+ ) -> torch.Tensor:
353
+ lora_up_a = lora.up_a[flat_idxs]
354
+ lora_up_b = lora.up_b[flat_idxs]
355
+ lora_mid = torch.bmm(lora_up_a, x_expanded.unsqueeze(-1)).squeeze(-1)
356
+ lora_up = torch.bmm(lora_up_b, lora_mid.unsqueeze(-1)).squeeze(-1)
357
+ return lora_up
358
+
359
+
360
+ def apply_moe_lora_fc2_flat(
361
+ h: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
362
+ ) -> torch.Tensor:
363
+ lora_down_a = lora.down_a[flat_idxs]
364
+ lora_down_b = lora.down_b[flat_idxs]
365
+ lora_mid = torch.bmm(lora_down_a, h.unsqueeze(-1)).squeeze(-1)
366
+ lora_down = torch.bmm(lora_down_b, lora_mid.unsqueeze(-1)).squeeze(-1)
367
+ return lora_down
368
+
369
+
370
+ _ADAPTER_CACHE: Dict[Tuple[str, str, str, Tuple], TextLoRA] = {}
371
+ _CACHE_ORDER: list[Tuple[str, str, str, Tuple]] = []
372
+ _CACHE_SIZE = 8
373
+
374
+
375
+ def _config_key(text_config: TextConfig) -> Tuple:
376
+ moe = text_config.moe
377
+ moe_key = None
378
+ if moe is not None:
379
+ moe_key = (
380
+ moe.num_experts,
381
+ moe.start_layer,
382
+ moe.experts_per_token,
383
+ moe.expert_inner_dim,
384
+ )
385
+ return (
386
+ text_config.dim,
387
+ text_config.ff_dim,
388
+ text_config.n_layers,
389
+ moe_key,
390
+ )
391
+
392
+
393
+ def load_adapter(
394
+ adapter_id: Optional[str],
395
+ *,
396
+ text_config: TextConfig,
397
+ device: torch.device,
398
+ dtype: torch.dtype,
399
+ max_rank: int = 16,
400
+ ) -> Optional[TextLoRA]:
401
+ if adapter_id is None:
402
+ return None
403
+
404
+ adapter_id = normalize_adapter_id(adapter_id)
405
+ if adapter_id is None:
406
+ return None
407
+
408
+ key = (adapter_id, str(device), str(dtype), _config_key(text_config))
409
+ cached = _ADAPTER_CACHE.get(key)
410
+ if cached is not None:
411
+ return cached
412
+
413
+ path = cached_adapter_path(adapter_id)
414
+ checkpoint = _load_state_dict(path, device)
415
+ if not isinstance(checkpoint, dict):
416
+ raise AdapterLoadError("Invalid adapter checkpoint format.")
417
+
418
+ state_dict = checkpoint.get("lora_state_dict", checkpoint)
419
+ if not isinstance(state_dict, dict):
420
+ raise AdapterLoadError("Adapter checkpoint missing lora_state_dict.")
421
+
422
+ lora = TextLoRA.from_state_dict(
423
+ state_dict,
424
+ text_config=text_config,
425
+ max_rank=max_rank,
426
+ dtype=dtype,
427
+ device=device,
428
+ adapter_id=adapter_id,
429
+ )
430
+
431
+ _ADAPTER_CACHE[key] = lora
432
+ _CACHE_ORDER.append(key)
433
+ if len(_CACHE_ORDER) > _CACHE_SIZE:
434
+ old = _CACHE_ORDER.pop(0)
435
+ _ADAPTER_CACHE.pop(old, None)
436
+
437
+ return lora
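
Reviewer note on lora.py above: adapter ids have the form `finetune_id@step`, taken from the last `/`-separated segment of the value, and loaded adapters are kept in a small first-in-first-out cache (`_CACHE_SIZE = 8`) in which a cache hit does not refresh an entry's position. The sketch below illustrates both behaviours in isolation, using only the standard library. `split_adapter_id` and `FifoCache` are hypothetical names, and the sketch returns `None` for a malformed id where the commit's `parse_adapter_id` raises `AdapterLoadError`.

```python
from collections import OrderedDict
from typing import Any, Hashable, Optional, Tuple


def split_adapter_id(value: str) -> Optional[Tuple[str, str]]:
    """Return (finetune_id, step) from the last path segment, or None."""
    tail = value.split("/")[-1].strip()
    if "@" not in tail:
        return None
    finetune_id, step = tail.split("@", 1)
    if not finetune_id or not step:
        return None
    return finetune_id, step


class FifoCache:
    """Bounded cache that evicts the oldest inserted key (FIFO, not LRU)."""

    def __init__(self, capacity: int = 8) -> None:
        self.capacity = capacity
        self._items: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get(self, key: Hashable) -> Any:
        # A hit does not move the key, so eviction order is insertion order.
        return self._items.get(key)

    def put(self, key: Hashable, value: Any) -> None:
        self._items[key] = value
        if len(self._items) > self.capacity:
            self._items.popitem(last=False)  # drop the oldest insertion
```

A FIFO policy keeps the bookkeeping trivial. The trade-off is that a frequently used adapter is still evicted once eight newer ones have been loaded.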
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b120e9db73fb74c13016f96e2a6217b25b3ec8b866a614456406b555feaba90
3
+ size 5256085089
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84a2b3dba409a84b1e847cb25ca6b84956b17518b33e471a58d06590c97bb454
3
+ size 4720107260
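
Reviewer note on the two weight shards above: each `.safetensors` entry in this commit is a Git LFS pointer (three `key value` lines: `version`, `oid`, `size`), not the weights themselves. The sketch below parses the two pointers as they appear in this diff and compares their combined size with the `total_size` value recorded in `model.safetensors.index.json`. `parse_lfs_pointer` is a hypothetical helper name, and reading the small difference as per-file header overhead is an assumption, not something the commit states.

```python
from typing import Dict, Union


def parse_lfs_pointer(text: str) -> Dict[str, Union[str, int]]:
    """Parse the key/value lines of a Git LFS pointer file."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


SHARD_1 = """version https://git-lfs.github.com/spec/v1
oid sha256:3b120e9db73fb74c13016f96e2a6217b25b3ec8b866a614456406b555feaba90
size 5256085089
"""
SHARD_2 = """version https://git-lfs.github.com/spec/v1
oid sha256:84a2b3dba409a84b1e847cb25ca6b84956b17518b33e471a58d06590c97bb454
size 4720107260
"""
INDEX_TOTAL_SIZE = 9976074848  # metadata.total_size from the index file

shards = [parse_lfs_pointer(SHARD_1), parse_lfs_pointer(SHARD_2)]
on_disk = sum(shard["size"] for shard in shards)
# Bytes on disk beyond the index's total; assumed to be per-file headers.
overhead = on_disk - INDEX_TOTAL_SIZE
```

With the values in this commit, the two shards add up to 9,976,192,349 bytes, which is 117,501 bytes more than the index's `total_size`.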
model.safetensors.index.json ADDED
@@ -0,0 +1,1059 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 9976074848
4
+ },
5
+ "weight_map": {
6
+ "text.lm_head.bias": "model-00002-of-00002.safetensors",
7
+ "text.lm_head.biases": "model-00002-of-00002.safetensors",
8
+ "text.lm_head.scales": "model-00002-of-00002.safetensors",
9
+ "text.lm_head.weight": "model-00002-of-00002.safetensors",
10
+ "text.model.blocks.0.attn.proj.bias": "model-00001-of-00002.safetensors",
11
+ "text.model.blocks.0.attn.proj.biases": "model-00001-of-00002.safetensors",
12
+ "text.model.blocks.0.attn.proj.scales": "model-00001-of-00002.safetensors",
13
+ "text.model.blocks.0.attn.proj.weight": "model-00001-of-00002.safetensors",
14
+ "text.model.blocks.0.attn.qkv.bias": "model-00001-of-00002.safetensors",
15
+ "text.model.blocks.0.attn.qkv.biases": "model-00001-of-00002.safetensors",
16
+ "text.model.blocks.0.attn.qkv.scales": "model-00001-of-00002.safetensors",
17
+ "text.model.blocks.0.attn.qkv.weight": "model-00001-of-00002.safetensors",
18
+ "text.model.blocks.0.attn.tau.alpha": "model-00001-of-00002.safetensors",
19
+ "text.model.blocks.0.attn.tau.wq": "model-00001-of-00002.safetensors",
20
+ "text.model.blocks.0.attn.tau.wv": "model-00001-of-00002.safetensors",
21
+ "text.model.blocks.0.ln.bias": "model-00001-of-00002.safetensors",
22
+ "text.model.blocks.0.ln.weight": "model-00001-of-00002.safetensors",
23
+ "text.model.blocks.0.mlp.fc1.bias": "model-00001-of-00002.safetensors",
24
+ "text.model.blocks.0.mlp.fc1.biases": "model-00001-of-00002.safetensors",
25
+ "text.model.blocks.0.mlp.fc1.scales": "model-00001-of-00002.safetensors",
26
+ "text.model.blocks.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
27
+ "text.model.blocks.0.mlp.fc2.bias": "model-00001-of-00002.safetensors",
28
+ "text.model.blocks.0.mlp.fc2.biases": "model-00001-of-00002.safetensors",
29
+ "text.model.blocks.0.mlp.fc2.scales": "model-00001-of-00002.safetensors",
30
+ "text.model.blocks.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
31
+ "text.model.blocks.1.attn.proj.bias": "model-00001-of-00002.safetensors",
32
+ "text.model.blocks.1.attn.proj.biases": "model-00001-of-00002.safetensors",
33
+ "text.model.blocks.1.attn.proj.scales": "model-00001-of-00002.safetensors",
34
+ "text.model.blocks.1.attn.proj.weight": "model-00001-of-00002.safetensors",
35
+ "text.model.blocks.1.attn.qkv.bias": "model-00001-of-00002.safetensors",
36
+ "text.model.blocks.1.attn.qkv.biases": "model-00001-of-00002.safetensors",
37
+ "text.model.blocks.1.attn.qkv.scales": "model-00001-of-00002.safetensors",
38
+ "text.model.blocks.1.attn.qkv.weight": "model-00001-of-00002.safetensors",
39
+ "text.model.blocks.1.attn.tau.alpha": "model-00001-of-00002.safetensors",
40
+ "text.model.blocks.1.attn.tau.wq": "model-00001-of-00002.safetensors",
41
+ "text.model.blocks.1.attn.tau.wv": "model-00001-of-00002.safetensors",
42
+ "text.model.blocks.1.ln.bias": "model-00001-of-00002.safetensors",
43
+ "text.model.blocks.1.ln.weight": "model-00001-of-00002.safetensors",
44
+ "text.model.blocks.1.mlp.fc1.bias": "model-00001-of-00002.safetensors",
45
+ "text.model.blocks.1.mlp.fc1.biases": "model-00001-of-00002.safetensors",
46
+ "text.model.blocks.1.mlp.fc1.scales": "model-00001-of-00002.safetensors",
47
+ "text.model.blocks.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
48
+ "text.model.blocks.1.mlp.fc2.bias": "model-00001-of-00002.safetensors",
49
+ "text.model.blocks.1.mlp.fc2.biases": "model-00001-of-00002.safetensors",
50
+ "text.model.blocks.1.mlp.fc2.scales": "model-00001-of-00002.safetensors",
51
+ "text.model.blocks.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
52
+ "text.model.blocks.10.attn.proj.bias": "model-00001-of-00002.safetensors",
53
+ "text.model.blocks.10.attn.proj.biases": "model-00001-of-00002.safetensors",
54
+ "text.model.blocks.10.attn.proj.scales": "model-00001-of-00002.safetensors",
55
+ "text.model.blocks.10.attn.proj.weight": "model-00001-of-00002.safetensors",
56
+ "text.model.blocks.10.attn.qkv.bias": "model-00001-of-00002.safetensors",
57
+ "text.model.blocks.10.attn.qkv.biases": "model-00001-of-00002.safetensors",
58
+ "text.model.blocks.10.attn.qkv.scales": "model-00001-of-00002.safetensors",
59
+ "text.model.blocks.10.attn.qkv.weight": "model-00001-of-00002.safetensors",
60
+ "text.model.blocks.10.attn.tau.alpha": "model-00001-of-00002.safetensors",
61
+ "text.model.blocks.10.attn.tau.wq": "model-00001-of-00002.safetensors",
62
+ "text.model.blocks.10.attn.tau.wv": "model-00001-of-00002.safetensors",
63
+ "text.model.blocks.10.ln.bias": "model-00001-of-00002.safetensors",
64
+ "text.model.blocks.10.ln.weight": "model-00001-of-00002.safetensors",
65
+ "text.model.blocks.10.mlp.fc1.biases": "model-00001-of-00002.safetensors",
66
+ "text.model.blocks.10.mlp.fc1.scales": "model-00001-of-00002.safetensors",
67
+ "text.model.blocks.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
68
+ "text.model.blocks.10.mlp.fc2.biases": "model-00001-of-00002.safetensors",
69
+ "text.model.blocks.10.mlp.fc2.scales": "model-00001-of-00002.safetensors",
70
+ "text.model.blocks.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
71
+ "text.model.blocks.10.mlp.router.bias": "model-00001-of-00002.safetensors",
72
+ "text.model.blocks.10.mlp.router.biases": "model-00001-of-00002.safetensors",
73
+ "text.model.blocks.10.mlp.router.scales": "model-00001-of-00002.safetensors",
74
+ "text.model.blocks.10.mlp.router.weight": "model-00001-of-00002.safetensors",
75
+ "text.model.blocks.11.attn.proj.bias": "model-00001-of-00002.safetensors",
76
+ "text.model.blocks.11.attn.proj.biases": "model-00001-of-00002.safetensors",
77
+ "text.model.blocks.11.attn.proj.scales": "model-00001-of-00002.safetensors",
78
+ "text.model.blocks.11.attn.proj.weight": "model-00001-of-00002.safetensors",
79
+ "text.model.blocks.11.attn.qkv.bias": "model-00001-of-00002.safetensors",
80
+ "text.model.blocks.11.attn.qkv.biases": "model-00001-of-00002.safetensors",
81
+ "text.model.blocks.11.attn.qkv.scales": "model-00001-of-00002.safetensors",
82
+ "text.model.blocks.11.attn.qkv.weight": "model-00001-of-00002.safetensors",
83
+ "text.model.blocks.11.attn.tau.alpha": "model-00001-of-00002.safetensors",
84
+ "text.model.blocks.11.attn.tau.wq": "model-00001-of-00002.safetensors",
85
+ "text.model.blocks.11.attn.tau.wv": "model-00001-of-00002.safetensors",
86
+ "text.model.blocks.11.ln.bias": "model-00001-of-00002.safetensors",
87
+ "text.model.blocks.11.ln.weight": "model-00001-of-00002.safetensors",
88
+ "text.model.blocks.11.mlp.fc1.biases": "model-00001-of-00002.safetensors",
89
+ "text.model.blocks.11.mlp.fc1.scales": "model-00001-of-00002.safetensors",
90
+ "text.model.blocks.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
91
+ "text.model.blocks.11.mlp.fc2.biases": "model-00001-of-00002.safetensors",
92
+ "text.model.blocks.11.mlp.fc2.scales": "model-00001-of-00002.safetensors",
93
+ "text.model.blocks.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
94
+ "text.model.blocks.11.mlp.router.bias": "model-00001-of-00002.safetensors",
95
+ "text.model.blocks.11.mlp.router.biases": "model-00001-of-00002.safetensors",
96
+ "text.model.blocks.11.mlp.router.scales": "model-00001-of-00002.safetensors",
97
+ "text.model.blocks.11.mlp.router.weight": "model-00001-of-00002.safetensors",
98
+ "text.model.blocks.12.attn.proj.bias": "model-00001-of-00002.safetensors",
99
+ "text.model.blocks.12.attn.proj.biases": "model-00001-of-00002.safetensors",
100
+ "text.model.blocks.12.attn.proj.scales": "model-00001-of-00002.safetensors",
101
+ "text.model.blocks.12.attn.proj.weight": "model-00001-of-00002.safetensors",
102
+ "text.model.blocks.12.attn.qkv.bias": "model-00001-of-00002.safetensors",
103
+ "text.model.blocks.12.attn.qkv.biases": "model-00001-of-00002.safetensors",
104
+ "text.model.blocks.12.attn.qkv.scales": "model-00001-of-00002.safetensors",
105
+ "text.model.blocks.12.attn.qkv.weight": "model-00001-of-00002.safetensors",
106
+ "text.model.blocks.12.attn.tau.alpha": "model-00001-of-00002.safetensors",
107
+ "text.model.blocks.12.attn.tau.wq": "model-00001-of-00002.safetensors",
108
+ "text.model.blocks.12.attn.tau.wv": "model-00001-of-00002.safetensors",
109
+ "text.model.blocks.12.ln.bias": "model-00001-of-00002.safetensors",
110
+ "text.model.blocks.12.ln.weight": "model-00001-of-00002.safetensors",
111
+ "text.model.blocks.12.mlp.fc1.biases": "model-00001-of-00002.safetensors",
112
+ "text.model.blocks.12.mlp.fc1.scales": "model-00001-of-00002.safetensors",
113
+ "text.model.blocks.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
114
+ "text.model.blocks.12.mlp.fc2.biases": "model-00001-of-00002.safetensors",
115
+ "text.model.blocks.12.mlp.fc2.scales": "model-00001-of-00002.safetensors",
116
+ "text.model.blocks.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
117
+ "text.model.blocks.12.mlp.router.bias": "model-00001-of-00002.safetensors",
118
+ "text.model.blocks.12.mlp.router.biases": "model-00001-of-00002.safetensors",
119
+ "text.model.blocks.12.mlp.router.scales": "model-00001-of-00002.safetensors",
120
+ "text.model.blocks.12.mlp.router.weight": "model-00001-of-00002.safetensors",
121
+ "text.model.blocks.13.attn.proj.bias": "model-00001-of-00002.safetensors",
122
+ "text.model.blocks.13.attn.proj.biases": "model-00001-of-00002.safetensors",
123
+ "text.model.blocks.13.attn.proj.scales": "model-00001-of-00002.safetensors",
124
+ "text.model.blocks.13.attn.proj.weight": "model-00001-of-00002.safetensors",
125
+ "text.model.blocks.13.attn.qkv.bias": "model-00001-of-00002.safetensors",
126
+ "text.model.blocks.13.attn.qkv.biases": "model-00001-of-00002.safetensors",
127
+ "text.model.blocks.13.attn.qkv.scales": "model-00001-of-00002.safetensors",
128
+ "text.model.blocks.13.attn.qkv.weight": "model-00001-of-00002.safetensors",
129
+ "text.model.blocks.13.attn.tau.alpha": "model-00001-of-00002.safetensors",
130
+ "text.model.blocks.13.attn.tau.wq": "model-00001-of-00002.safetensors",
131
+ "text.model.blocks.13.attn.tau.wv": "model-00001-of-00002.safetensors",
132
+ "text.model.blocks.13.ln.bias": "model-00001-of-00002.safetensors",
133
+ "text.model.blocks.13.ln.weight": "model-00001-of-00002.safetensors",
134
+ "text.model.blocks.13.mlp.fc1.biases": "model-00001-of-00002.safetensors",
135
+ "text.model.blocks.13.mlp.fc1.scales": "model-00001-of-00002.safetensors",
136
+ "text.model.blocks.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
137
+ "text.model.blocks.13.mlp.fc2.biases": "model-00002-of-00002.safetensors",
138
+ "text.model.blocks.13.mlp.fc2.scales": "model-00002-of-00002.safetensors",
139
+ "text.model.blocks.13.mlp.fc2.weight": "model-00002-of-00002.safetensors",
140
+ "text.model.blocks.13.mlp.router.bias": "model-00001-of-00002.safetensors",
141
+ "text.model.blocks.13.mlp.router.biases": "model-00001-of-00002.safetensors",
142
+ "text.model.blocks.13.mlp.router.scales": "model-00001-of-00002.safetensors",
143
+ "text.model.blocks.13.mlp.router.weight": "model-00001-of-00002.safetensors",
144
+ "text.model.blocks.14.attn.proj.bias": "model-00002-of-00002.safetensors",
145
+ "text.model.blocks.14.attn.proj.biases": "model-00002-of-00002.safetensors",
146
+ "text.model.blocks.14.attn.proj.scales": "model-00002-of-00002.safetensors",
147
+ "text.model.blocks.14.attn.proj.weight": "model-00002-of-00002.safetensors",
148
+ "text.model.blocks.14.attn.qkv.bias": "model-00002-of-00002.safetensors",
149
+ "text.model.blocks.14.attn.qkv.biases": "model-00002-of-00002.safetensors",
150
+ "text.model.blocks.14.attn.qkv.scales": "model-00002-of-00002.safetensors",
151
+ "text.model.blocks.14.attn.qkv.weight": "model-00002-of-00002.safetensors",
152
+ "text.model.blocks.14.attn.tau.alpha": "model-00002-of-00002.safetensors",
153
+ "text.model.blocks.14.attn.tau.wq": "model-00002-of-00002.safetensors",
154
+ "text.model.blocks.14.attn.tau.wv": "model-00002-of-00002.safetensors",
155
+ "text.model.blocks.14.ln.bias": "model-00002-of-00002.safetensors",
156
+ "text.model.blocks.14.ln.weight": "model-00002-of-00002.safetensors",
157
+ "text.model.blocks.14.mlp.fc1.biases": "model-00002-of-00002.safetensors",
158
+ "text.model.blocks.14.mlp.fc1.scales": "model-00002-of-00002.safetensors",
159
+ "text.model.blocks.14.mlp.fc1.weight": "model-00002-of-00002.safetensors",
160
+ "text.model.blocks.14.mlp.fc2.biases": "model-00002-of-00002.safetensors",
161
+ "text.model.blocks.14.mlp.fc2.scales": "model-00002-of-00002.safetensors",
162
+ "text.model.blocks.14.mlp.fc2.weight": "model-00002-of-00002.safetensors",
163
+ "text.model.blocks.14.mlp.router.bias": "model-00002-of-00002.safetensors",
164
+ "text.model.blocks.14.mlp.router.biases": "model-00002-of-00002.safetensors",
165
+ "text.model.blocks.14.mlp.router.scales": "model-00002-of-00002.safetensors",
166
+ "text.model.blocks.14.mlp.router.weight": "model-00002-of-00002.safetensors",
167
+ "text.model.blocks.15.attn.proj.bias": "model-00002-of-00002.safetensors",
168
+ "text.model.blocks.15.attn.proj.biases": "model-00002-of-00002.safetensors",
169
+ "text.model.blocks.15.attn.proj.scales": "model-00002-of-00002.safetensors",
170
+ "text.model.blocks.15.attn.proj.weight": "model-00002-of-00002.safetensors",
171
+ "text.model.blocks.15.attn.qkv.bias": "model-00002-of-00002.safetensors",
172
+ "text.model.blocks.15.attn.qkv.biases": "model-00002-of-00002.safetensors",
173
+ "text.model.blocks.15.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.15.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.16.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.17.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.18.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.19.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.2.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.20.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.20.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.21.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.22.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.proj.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.proj.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.proj.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.proj.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.qkv.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.qkv.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.qkv.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.qkv.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.tau.alpha": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.tau.wq": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.attn.tau.wv": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.fc1.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.fc1.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.fc2.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.fc2.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.router.bias": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.router.biases": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.router.scales": "model-00002-of-00002.safetensors",
+ "text.model.blocks.23.mlp.router.weight": "model-00002-of-00002.safetensors",
+ "text.model.blocks.3.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.router.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.router.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.router.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.4.mlp.router.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.router.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.router.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.router.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.5.mlp.router.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.router.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.router.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.router.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.6.mlp.router.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.router.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.router.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.router.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.7.mlp.router.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.router.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.router.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.router.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.8.mlp.router.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.tau.alpha": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.tau.wq": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.attn.tau.wv": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.ln.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.ln.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.fc2.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.fc2.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.router.bias": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.router.biases": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.router.scales": "model-00001-of-00002.safetensors",
+ "text.model.blocks.9.mlp.router.weight": "model-00001-of-00002.safetensors",
+ "text.model.post_ln.bias": "model-00002-of-00002.safetensors",
+ "text.model.post_ln.weight": "model-00002-of-00002.safetensors",
+ "text.model.wte.biases": "model-00001-of-00002.safetensors",
+ "text.model.wte.scales": "model-00001-of-00002.safetensors",
+ "text.model.wte.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.18.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.19.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.20.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.21.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.22.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.23.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.24.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.25.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.26.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.proj.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.proj.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.proj.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.proj.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.qkv.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.qkv.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.qkv.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.attn.qkv.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.ln1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.ln1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.ln2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.ln2.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.mlp.fc1.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.mlp.fc1.biases": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.mlp.fc1.scales": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.mlp.fc2.bias": "model-00001-of-00002.safetensors",
+ "vision.encoder.blocks.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
955
+ "vision.encoder.blocks.5.attn.proj.bias": "model-00001-of-00002.safetensors",
956
+ "vision.encoder.blocks.5.attn.proj.biases": "model-00001-of-00002.safetensors",
957
+ "vision.encoder.blocks.5.attn.proj.scales": "model-00001-of-00002.safetensors",
958
+ "vision.encoder.blocks.5.attn.proj.weight": "model-00001-of-00002.safetensors",
959
+ "vision.encoder.blocks.5.attn.qkv.bias": "model-00001-of-00002.safetensors",
960
+ "vision.encoder.blocks.5.attn.qkv.biases": "model-00001-of-00002.safetensors",
961
+ "vision.encoder.blocks.5.attn.qkv.scales": "model-00001-of-00002.safetensors",
962
+ "vision.encoder.blocks.5.attn.qkv.weight": "model-00001-of-00002.safetensors",
963
+ "vision.encoder.blocks.5.ln1.bias": "model-00001-of-00002.safetensors",
964
+ "vision.encoder.blocks.5.ln1.weight": "model-00001-of-00002.safetensors",
965
+ "vision.encoder.blocks.5.ln2.bias": "model-00001-of-00002.safetensors",
966
+ "vision.encoder.blocks.5.ln2.weight": "model-00001-of-00002.safetensors",
967
+ "vision.encoder.blocks.5.mlp.fc1.bias": "model-00001-of-00002.safetensors",
968
+ "vision.encoder.blocks.5.mlp.fc1.biases": "model-00001-of-00002.safetensors",
969
+ "vision.encoder.blocks.5.mlp.fc1.scales": "model-00001-of-00002.safetensors",
970
+ "vision.encoder.blocks.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
971
+ "vision.encoder.blocks.5.mlp.fc2.bias": "model-00001-of-00002.safetensors",
972
+ "vision.encoder.blocks.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
973
+ "vision.encoder.blocks.6.attn.proj.bias": "model-00001-of-00002.safetensors",
974
+ "vision.encoder.blocks.6.attn.proj.biases": "model-00001-of-00002.safetensors",
975
+ "vision.encoder.blocks.6.attn.proj.scales": "model-00001-of-00002.safetensors",
976
+ "vision.encoder.blocks.6.attn.proj.weight": "model-00001-of-00002.safetensors",
977
+ "vision.encoder.blocks.6.attn.qkv.bias": "model-00001-of-00002.safetensors",
978
+ "vision.encoder.blocks.6.attn.qkv.biases": "model-00001-of-00002.safetensors",
979
+ "vision.encoder.blocks.6.attn.qkv.scales": "model-00001-of-00002.safetensors",
980
+ "vision.encoder.blocks.6.attn.qkv.weight": "model-00001-of-00002.safetensors",
981
+ "vision.encoder.blocks.6.ln1.bias": "model-00001-of-00002.safetensors",
982
+ "vision.encoder.blocks.6.ln1.weight": "model-00001-of-00002.safetensors",
983
+ "vision.encoder.blocks.6.ln2.bias": "model-00001-of-00002.safetensors",
984
+ "vision.encoder.blocks.6.ln2.weight": "model-00001-of-00002.safetensors",
985
+ "vision.encoder.blocks.6.mlp.fc1.bias": "model-00001-of-00002.safetensors",
986
+ "vision.encoder.blocks.6.mlp.fc1.biases": "model-00001-of-00002.safetensors",
987
+ "vision.encoder.blocks.6.mlp.fc1.scales": "model-00001-of-00002.safetensors",
988
+ "vision.encoder.blocks.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
989
+ "vision.encoder.blocks.6.mlp.fc2.bias": "model-00001-of-00002.safetensors",
990
+ "vision.encoder.blocks.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
991
+ "vision.encoder.blocks.7.attn.proj.bias": "model-00001-of-00002.safetensors",
992
+ "vision.encoder.blocks.7.attn.proj.biases": "model-00001-of-00002.safetensors",
993
+ "vision.encoder.blocks.7.attn.proj.scales": "model-00001-of-00002.safetensors",
994
+ "vision.encoder.blocks.7.attn.proj.weight": "model-00001-of-00002.safetensors",
995
+ "vision.encoder.blocks.7.attn.qkv.bias": "model-00001-of-00002.safetensors",
996
+ "vision.encoder.blocks.7.attn.qkv.biases": "model-00001-of-00002.safetensors",
997
+ "vision.encoder.blocks.7.attn.qkv.scales": "model-00001-of-00002.safetensors",
998
+ "vision.encoder.blocks.7.attn.qkv.weight": "model-00001-of-00002.safetensors",
999
+ "vision.encoder.blocks.7.ln1.bias": "model-00001-of-00002.safetensors",
1000
+ "vision.encoder.blocks.7.ln1.weight": "model-00001-of-00002.safetensors",
1001
+ "vision.encoder.blocks.7.ln2.bias": "model-00001-of-00002.safetensors",
1002
+ "vision.encoder.blocks.7.ln2.weight": "model-00001-of-00002.safetensors",
1003
+ "vision.encoder.blocks.7.mlp.fc1.bias": "model-00001-of-00002.safetensors",
1004
+ "vision.encoder.blocks.7.mlp.fc1.biases": "model-00001-of-00002.safetensors",
1005
+ "vision.encoder.blocks.7.mlp.fc1.scales": "model-00001-of-00002.safetensors",
1006
+ "vision.encoder.blocks.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
1007
+ "vision.encoder.blocks.7.mlp.fc2.bias": "model-00001-of-00002.safetensors",
1008
+ "vision.encoder.blocks.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
1009
+ "vision.encoder.blocks.8.attn.proj.bias": "model-00001-of-00002.safetensors",
1010
+ "vision.encoder.blocks.8.attn.proj.biases": "model-00001-of-00002.safetensors",
1011
+ "vision.encoder.blocks.8.attn.proj.scales": "model-00001-of-00002.safetensors",
1012
+ "vision.encoder.blocks.8.attn.proj.weight": "model-00001-of-00002.safetensors",
1013
+ "vision.encoder.blocks.8.attn.qkv.bias": "model-00001-of-00002.safetensors",
1014
+ "vision.encoder.blocks.8.attn.qkv.biases": "model-00001-of-00002.safetensors",
1015
+ "vision.encoder.blocks.8.attn.qkv.scales": "model-00001-of-00002.safetensors",
1016
+ "vision.encoder.blocks.8.attn.qkv.weight": "model-00001-of-00002.safetensors",
1017
+ "vision.encoder.blocks.8.ln1.bias": "model-00001-of-00002.safetensors",
1018
+ "vision.encoder.blocks.8.ln1.weight": "model-00001-of-00002.safetensors",
1019
+ "vision.encoder.blocks.8.ln2.bias": "model-00001-of-00002.safetensors",
1020
+ "vision.encoder.blocks.8.ln2.weight": "model-00001-of-00002.safetensors",
1021
+ "vision.encoder.blocks.8.mlp.fc1.bias": "model-00001-of-00002.safetensors",
1022
+ "vision.encoder.blocks.8.mlp.fc1.biases": "model-00001-of-00002.safetensors",
1023
+ "vision.encoder.blocks.8.mlp.fc1.scales": "model-00001-of-00002.safetensors",
1024
+ "vision.encoder.blocks.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
1025
+ "vision.encoder.blocks.8.mlp.fc2.bias": "model-00001-of-00002.safetensors",
1026
+ "vision.encoder.blocks.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
1027
+ "vision.encoder.blocks.9.attn.proj.bias": "model-00001-of-00002.safetensors",
1028
+ "vision.encoder.blocks.9.attn.proj.biases": "model-00001-of-00002.safetensors",
1029
+ "vision.encoder.blocks.9.attn.proj.scales": "model-00001-of-00002.safetensors",
1030
+ "vision.encoder.blocks.9.attn.proj.weight": "model-00001-of-00002.safetensors",
1031
+ "vision.encoder.blocks.9.attn.qkv.bias": "model-00001-of-00002.safetensors",
1032
+ "vision.encoder.blocks.9.attn.qkv.biases": "model-00001-of-00002.safetensors",
1033
+ "vision.encoder.blocks.9.attn.qkv.scales": "model-00001-of-00002.safetensors",
1034
+ "vision.encoder.blocks.9.attn.qkv.weight": "model-00001-of-00002.safetensors",
1035
+ "vision.encoder.blocks.9.ln1.bias": "model-00001-of-00002.safetensors",
1036
+ "vision.encoder.blocks.9.ln1.weight": "model-00001-of-00002.safetensors",
1037
+ "vision.encoder.blocks.9.ln2.bias": "model-00001-of-00002.safetensors",
1038
+ "vision.encoder.blocks.9.ln2.weight": "model-00001-of-00002.safetensors",
1039
+ "vision.encoder.blocks.9.mlp.fc1.bias": "model-00001-of-00002.safetensors",
1040
+ "vision.encoder.blocks.9.mlp.fc1.biases": "model-00001-of-00002.safetensors",
1041
+ "vision.encoder.blocks.9.mlp.fc1.scales": "model-00001-of-00002.safetensors",
1042
+ "vision.encoder.blocks.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
1043
+ "vision.encoder.blocks.9.mlp.fc2.bias": "model-00001-of-00002.safetensors",
1044
+ "vision.encoder.blocks.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
1045
+ "vision.encoder.patch_emb.bias": "model-00001-of-00002.safetensors",
1046
+ "vision.encoder.patch_emb.weight": "model-00001-of-00002.safetensors",
1047
+ "vision.encoder.pos_emb": "model-00001-of-00002.safetensors",
1048
+ "vision.encoder.post_ln.bias": "model-00001-of-00002.safetensors",
1049
+ "vision.encoder.post_ln.weight": "model-00001-of-00002.safetensors",
1050
+ "vision.proj_mlp.fc1.bias": "model-00001-of-00002.safetensors",
1051
+ "vision.proj_mlp.fc1.biases": "model-00001-of-00002.safetensors",
1052
+ "vision.proj_mlp.fc1.scales": "model-00001-of-00002.safetensors",
1053
+ "vision.proj_mlp.fc1.weight": "model-00001-of-00002.safetensors",
1054
+ "vision.proj_mlp.fc2.bias": "model-00001-of-00002.safetensors",
1055
+ "vision.proj_mlp.fc2.biases": "model-00001-of-00002.safetensors",
1056
+ "vision.proj_mlp.fc2.scales": "model-00001-of-00002.safetensors",
1057
+ "vision.proj_mlp.fc2.weight": "model-00001-of-00002.safetensors"
1058
+ }
1059
+ }
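The index above maps each tensor name to the shard that stores it. Modules that carry `.scales` and `.biases` entries next to `.weight` (for example `mlp.fc1`) appear to be the quantized ones, while modules with only `.weight` and `.bias` (for example `mlp.fc2`) appear to have been left unquantized. The following is a minimal sketch of how one might list the quantized modules from such an index. It assumes the usual sharded-checkpoint layout with a top-level `weight_map` key, which is not visible in this part of the diff; the helper names `load_weight_map` and `quantized_modules` are hypothetical and not part of this repository.

```python
import json


def load_weight_map(index_path):
    """Read the tensor-name -> shard-file mapping from an index file.

    Assumes a top-level "weight_map" key (the common sharded-checkpoint layout).
    """
    with open(index_path) as f:
        return json.load(f)["weight_map"]


def quantized_modules(weight_map):
    """Return sorted module prefixes that have both `.scales` and `.biases` entries."""
    scales = {k[: -len(".scales")] for k in weight_map if k.endswith(".scales")}
    biases = {k[: -len(".biases")] for k in weight_map if k.endswith(".biases")}
    return sorted(scales & biases)


# A few entries copied from the index above, used as sample input.
shard = "model-00001-of-00002.safetensors"
sample = {
    "vision.encoder.blocks.9.mlp.fc1.bias": shard,
    "vision.encoder.blocks.9.mlp.fc1.biases": shard,
    "vision.encoder.blocks.9.mlp.fc1.scales": shard,
    "vision.encoder.blocks.9.mlp.fc1.weight": shard,
    "vision.encoder.blocks.9.mlp.fc2.bias": shard,
    "vision.encoder.blocks.9.mlp.fc2.weight": shard,
}

print(quantized_modules(sample))
```

In practice one would pass the result of `load_weight_map("model.safetensors.index.json")` instead of the hand-written sample.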
moondream.py ADDED
@@ -0,0 +1,1097 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+
5
+ from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
6
+ from PIL import Image
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
+ from torch.nn.attention.flex_attention import create_block_mask
10
+
11
+ from .config import MoondreamConfig
12
+ from .image_crops import reconstruct_from_crops
13
+ from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
14
+ from .text import build_text_model, text_encoder, lm_head, text_decoder
15
+ from .region import (
16
+ decode_coordinate,
17
+ encode_coordinate,
18
+ decode_size,
19
+ encode_size,
20
+ encode_spatial_refs,
21
+ SpatialRefs,
22
+ )
23
+ from .layers import QuantizedLinear
24
+ from .lora import load_adapter, normalize_adapter_id
25
+ from .rope import precompute_freqs_cis
26
+ from .utils import remove_outlier_points
27
+
28
+ ImageEncodingSettings = TypedDict(
29
+ "ImageEncodingSettings",
30
+ {"adapter": str, "model": str},
31
+ total=False,
32
+ )
33
+
34
+ TextSamplingSettings = TypedDict(
35
+ "TextSamplingSettings",
36
+ {
37
+ "max_tokens": int,
38
+ "temperature": float,
39
+ "top_p": float,
40
+ "adapter": str,
41
+ "model": str,
42
+ },
43
+ total=False,
44
+ )
45
+
46
+ ObjectSamplingSettings = TypedDict(
47
+ "ObjectSamplingSettings",
48
+ {"max_objects": int, "adapter": str, "model": str},
49
+ total=False,
50
+ )
51
+
52
+
53
+ DEFAULT_MAX_TOKENS = 768
54
+ DEFAULT_TEMPERATURE = 0.5
55
+ DEFAULT_TOP_P = 0.9
56
+ DEFAULT_MAX_OBJECTS = 150
57
+
58
+
59
+ @dataclass(frozen=True)
60
+ class EncodedImage:
61
+ pos: int
62
+ caches: List[Tuple[torch.Tensor, torch.Tensor]]
63
+
64
+
65
+ class KVCache(nn.Module):
66
+
67
+ def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
68
+ super().__init__()
69
+ cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
70
+ self.register_buffer(
71
+ "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
72
+ )
73
+ self.register_buffer(
74
+ "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
75
+ )
76
+
77
+ def update(self, pos_ids, k, v):
78
+ kout, vout = self.k_cache, self.v_cache
79
+ kout[:, :, pos_ids, :] = k
80
+ vout[:, :, pos_ids, :] = v
81
+ return kout, vout
82
+
83
+
84
+ def causal_mask(b, h, q_idx, kv_idx):
85
+ return q_idx >= kv_idx
86
+
87
+
88
+ def get_mask_mod(mask_mod, offset):
89
+ def _mask_mod(b, h, q, kv):
90
+ return mask_mod(b, h, q + offset, kv)
91
+
92
+ return _mask_mod
93
+
94
+
95
+ class MoondreamModel(nn.Module):
96
+
97
+ def __init__(
98
+ self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
99
+ ):
100
+ super().__init__()
101
+ self.config = config
102
+
103
+ self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
104
+ self.vision = build_vision_model(config.vision, dtype)
105
+ self.text = build_text_model(config.text, dtype)
106
+
107
+ # Region Model
108
+ linear_cls = (
109
+ QuantizedLinear if config.region.group_size is not None else nn.Linear
110
+ )
111
+ self.region = nn.ModuleDict(
112
+ {
113
+ "coord_encoder": linear_cls(
114
+ config.region.coord_feat_dim, config.region.dim, dtype=dtype
115
+ ),
116
+ "coord_decoder": linear_cls(
117
+ config.region.dim, config.region.coord_out_dim, dtype=dtype
118
+ ),
119
+ "size_encoder": linear_cls(
120
+ config.region.size_feat_dim, config.region.dim, dtype=dtype
121
+ ),
122
+ "size_decoder": linear_cls(
123
+ config.region.dim, config.region.size_out_dim, dtype=dtype
124
+ ),
125
+ "ln": nn.LayerNorm(config.region.dim, dtype=dtype),
126
+ }
127
+ )
128
+ self.region.coord_features = nn.Parameter(
129
+ torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
130
+ )
131
+ self.region.size_features = nn.Parameter(
132
+ torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
133
+ )
134
+
135
+ attn_mask = torch.tril(
136
+ torch.ones(
137
+ 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
138
+ )
139
+ )
140
+ patch_w = config.vision.crop_size // config.vision.enc_patch_size
141
+ prefix_attn_len = 1 + patch_w**2
142
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
143
+ self.register_buffer("attn_mask", attn_mask, persistent=False)
144
+
145
+ self.use_flex_decoding = True
146
+ self._causal_block_mask = None
147
+ self._point_gen_indices = None
148
+
149
+ # Initialize KV caches.
150
+ if setup_caches:
151
+ self._setup_caches()
152
+
153
+ @property
154
+ def causal_block_mask(self):
155
+ # Built lazily on first access instead of in __init__, as a workaround for ZeroGPU.
156
+ if self._causal_block_mask is None:
157
+ self._causal_block_mask = create_block_mask(
158
+ causal_mask,
159
+ B=None,
160
+ H=None,
161
+ Q_LEN=self.config.text.max_context,
162
+ KV_LEN=self.config.text.max_context,
163
+ )
164
+ return self._causal_block_mask
165
+
166
+ @property
167
+ def point_gen_indices(self):
168
+ if self._point_gen_indices is None:
169
+ self._point_gen_indices = torch.tensor(
170
+ [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],
171
+ device=self.device,
172
+ )
173
+ return self._point_gen_indices
174
+
175
+ def _refresh_runtime_buffers(self):
176
+ attn_mask = torch.tril(
177
+ torch.ones(
178
+ 1,
179
+ 1,
180
+ self.config.text.max_context,
181
+ self.config.text.max_context,
182
+ dtype=torch.bool,
183
+ device=self.device,
184
+ )
185
+ )
186
+ patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
187
+ prefix_attn_len = 1 + patch_w**2
188
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
189
+ self.attn_mask = attn_mask
190
+ self.text.freqs_cis = precompute_freqs_cis(
191
+ self.config.text.dim // (2 * self.config.text.n_heads),
192
+ self.config.text.max_context,
193
+ ).to(device=self.device)
194
+
195
+ def _setup_caches(self):
196
+ c = self.config.text
197
+ for b in self.text.blocks:
198
+ b.kv_cache = KVCache(
199
+ c.n_heads,
200
+ c.n_kv_heads,
201
+ c.max_context,
202
+ c.dim,
203
+ device=self.device,
204
+ dtype=self.vision.pos_emb.dtype,
205
+ )
206
+
207
+ def _adapter_id_from_settings(self, settings: Optional[dict]) -> Optional[str]:
208
+ if settings is None:
209
+ return None
210
+ adapter = settings.get("adapter")
211
+ if adapter is not None:
212
+ return normalize_adapter_id(adapter)
213
+
214
+ model_value = settings.get("model")
215
+ if isinstance(model_value, str):
216
+ return normalize_adapter_id(model_value)
217
+ return None
218
+
219
+ def _resolve_lora(self, settings: Optional[dict]) -> Optional[object]:
220
+ adapter_id = self._adapter_id_from_settings(settings)
221
+ if adapter_id is None:
222
+ return None
223
+ return load_adapter(
224
+ adapter_id,
225
+ text_config=self.config.text,
226
+ device=self.device,
227
+ dtype=self.vision.pos_emb.dtype,
228
+ )
229
+
230
+ @property
231
+ def device(self):
232
+ return self.vision.pos_emb.device
233
+
234
+ def _vis_enc(self, x: torch.Tensor):
235
+ return vision_encoder(x, self.vision, self.config.vision)
236
+
237
+ def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
238
+ return vision_projection(g, r, self.vision, self.config.vision)
239
+
240
+ def _prefill(
241
+ self,
242
+ x: torch.Tensor,
243
+ attn_mask: torch.Tensor,
244
+ pos_ids: torch.Tensor,
245
+ lora: Optional[torch.Tensor],
246
+ ):
247
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
248
+
249
+ def _decode_one_tok(
250
+ self,
251
+ x: torch.Tensor,
252
+ attn_mask: torch.Tensor,
253
+ pos_ids: torch.Tensor,
254
+ lora: Optional[torch.Tensor],
255
+ lm_head_indices: Optional[torch.Tensor] = None,
256
+ ):
257
+ if self.use_flex_decoding:
258
+ torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape")
259
+ block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]
260
+ mask = self.causal_block_mask[:, :, block_index]
261
+ mask.seq_lengths = (1, mask.seq_lengths[1])
262
+ mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])
263
+ else:
264
+ mask = None
265
+
266
+ hidden = text_decoder(
267
+ x,
268
+ self.text,
269
+ attn_mask,
270
+ pos_ids,
271
+ self.config.text,
272
+ lora=lora,
273
+ flex_block_mask_slice=mask,
274
+ )
275
+ logits = lm_head(hidden, self.text, indices=lm_head_indices)
276
+ return logits, hidden
277
+
278
+ def compile(self):
279
+ for module in self.modules():
280
+ if isinstance(module, QuantizedLinear):
281
+ module.unpack()
282
+
283
+ # Initialize lazy properties to avoid first-call overhead
284
+ self.causal_block_mask
285
+ self.point_gen_indices
286
+
287
+ # TODO: vision_projection and _prefill are not being compiled
288
+ self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
289
+ self._decode_one_tok = torch.compile(
290
+ self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
291
+ )
292
+
293
+ # Warm up compiled methods with dummy forward passes
294
+ device = self.device
295
+ dtype = self.vision.pos_emb.dtype
296
+ with torch.no_grad():
297
+ # Warmup vision encoder
298
+ dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype)
299
+ self._vis_enc(dummy_crops)
300
+
301
+ # Warmup _decode_one_tok (both normal and point generation modes)
302
+ dummy_emb = torch.randn(
303
+ 1, 1, self.config.text.dim, device=device, dtype=dtype
304
+ )
305
+ dummy_mask = torch.ones(
306
+ 1, 1, self.config.text.max_context, device=device, dtype=torch.bool
307
+ )
308
+ dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long)
309
+ self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None)
310
+ self._decode_one_tok(
311
+ dummy_emb,
312
+ dummy_mask,
313
+ dummy_pos_ids,
314
+ None,
315
+ lm_head_indices=self.point_gen_indices,
316
+ )
317
+
318
+ def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
319
+ all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
320
+
321
+ torch._dynamo.mark_dynamic(all_crops, 0)
322
+
323
+ outputs = self._vis_enc(all_crops)
324
+
325
+ global_features = outputs[0]
326
+ local_features = outputs[1:].view(
327
+ -1,
328
+ self.config.vision.enc_n_layers,
329
+ self.config.vision.enc_n_layers,
330
+ self.config.vision.enc_dim,
331
+ )
332
+
333
+ reconstructed = reconstruct_from_crops(
334
+ local_features,
335
+ tiling,
336
+ patch_size=1,
337
+ overlap_margin=self.config.vision.overlap_margin,
338
+ )
339
+
340
+ return self._vis_proj(global_features, reconstructed)
341
+
342
+ def encode_image(
343
+ self,
344
+ image: Union[Image.Image, EncodedImage],
345
+ settings: Optional[ImageEncodingSettings] = None,
346
+ ) -> EncodedImage:
347
+ if isinstance(image, EncodedImage):
348
+ return image
349
+ elif not isinstance(image, Image.Image):
350
+ raise ValueError("image must be a PIL Image or EncodedImage")
351
+
352
+ lora = self._resolve_lora(settings)
353
+
354
+ # Run through text model in addition to the vision encoder, to minimize
355
+ # re-computation if multiple queries are performed on this image.
356
+ with torch.inference_mode():
357
+ img_emb = self._run_vision_encoder(image)
358
+ bos_emb = text_encoder(
359
+ torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
360
+ self.text,
361
+ )
362
+ inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
363
+ mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
364
+ pos_ids = torch.arange(
365
+ inputs_embeds.size(1), dtype=torch.long, device=self.device
366
+ )
367
+ self._prefill(inputs_embeds, mask, pos_ids, lora)
368
+
369
+ return EncodedImage(
370
+ pos=inputs_embeds.size(1),
371
+ caches=[
372
+ (
373
+ b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
374
+ b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
375
+ )
376
+ for b in self.text.blocks
377
+ ],
378
+ )
379
+
380
+ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
381
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
382
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
383
+ mask = probs_sum - probs_sort > top_p
384
+ probs_sort[mask] = 0.0
385
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
386
+ next_probs = torch.zeros_like(probs)
387
+ next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
388
+ return next_probs
389
+
390
+ def _prefill_prompt(
391
+ self,
392
+ prompt_tokens: torch.Tensor,
393
+ pos: int,
394
+ temperature: float,
395
+ top_p: float,
396
+ spatial_refs: Optional[SpatialRefs] = None,
397
+ attn_mask: Optional[torch.Tensor] = None,
398
+ lora: Optional[dict] = None,
399
+ ):
400
+ with torch.inference_mode():
401
+ prompt_emb = text_encoder(prompt_tokens, self.text)
402
+
403
+ if spatial_refs:
404
+ encoded_refs = encode_spatial_refs(spatial_refs, self.region)
405
+ prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
406
+ encoded_refs["coords"]
407
+ )
408
+ if encoded_refs["sizes"] is not None:
409
+ prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
410
+ encoded_refs["sizes"]
411
+ )
412
+
413
+ torch._dynamo.mark_dynamic(prompt_emb, 1)
414
+
415
+ if attn_mask is None:
416
+ attn_mask = self.attn_mask
417
+
418
+ mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
419
+ pos_ids = torch.arange(
420
+ pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device
421
+ )
422
+ hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
423
+ logits_BV = lm_head(hidden_BC, self.text)
424
+
425
+ if temperature == 0:
426
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
427
+ else:
428
+ probs = torch.softmax(logits_BV / temperature, dim=-1)
429
+ probs = self._apply_top_p(probs, top_p)
430
+ next_token = torch.multinomial(probs, num_samples=1)
431
+
432
+ pos = pos + prompt_emb.size(1)
433
+ return logits_BV, hidden_BC, next_token, pos
434
+
435
+ def _generate_reasoning(
436
+ self,
437
+ prompt_tokens,
438
+ pos,
439
+ settings: Optional[TextSamplingSettings] = None,
440
+ spatial_refs: Optional[SpatialRefs] = None,
441
+ attn_mask: Optional[torch.Tensor] = None,
442
+ ) -> Tuple[int, str, List[dict]]:
443
+ max_tokens = (
444
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
445
+ if settings
446
+ else DEFAULT_MAX_TOKENS
447
+ )
448
+ temperature = (
449
+ settings.get("temperature", DEFAULT_TEMPERATURE)
450
+ if settings
451
+ else DEFAULT_TEMPERATURE
452
+ )
453
+ lora = self._resolve_lora(settings)
454
+
455
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
456
+ eos_id = self.config.tokenizer.answer_id
457
+
458
+ _, last_hidden_BC, next_token, pos = self._prefill_prompt(
459
+ prompt_tokens,
460
+ pos,
461
+ temperature,
462
+ top_p,
463
+ spatial_refs,
464
+ attn_mask=attn_mask,
465
+ lora=lora,
466
+ )
467
+
468
+ text_token_chunks = [[]]
469
+ grounding_chunks = [[]]
470
+
471
+ mask = torch.zeros(
472
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
473
+ )
474
+ mask[:, :, :pos] = 1
475
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
476
+ generated_tokens = 0
477
+
478
+ while (
479
+ next_token_id := next_token.item()
480
+ ) != eos_id and generated_tokens < max_tokens:
481
+ if (
482
+ next_token_id == self.config.tokenizer.start_ground_points_id
483
+ or next_token_id == self.config.tokenizer.end_ground_id
484
+ ):
485
+ text_token_chunks.append([])
486
+ grounding_chunks.append([])
487
+
488
+ text_token_chunks[-1].append(next_token_id)
489
+
490
+ with torch.inference_mode():
491
+ if next_token_id == self.config.tokenizer.coord_id:
492
+ coord_logits = decode_coordinate(last_hidden_BC, self.region)
493
+ coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
494
+ grounding_chunks[-1].append(coord.item())
495
+
496
+ next_emb = encode_coordinate(
497
+ coord.to(dtype=coord_logits.dtype), self.region
498
+ ).unsqueeze(0)
499
+ else:
500
+ next_emb = text_encoder(next_token, self.text)
501
+
502
+ mask[:, :, pos], pos_ids[0] = 1, pos
503
+
504
+ logits_BV, last_hidden_BC = self._decode_one_tok(
505
+ next_emb, mask, pos_ids, lora
506
+ )
507
+ logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
508
+ logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
509
+
510
+ pos += 1
511
+
512
+ if temperature == 0:
513
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
514
+ else:
515
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
516
+ probs = self._apply_top_p(probs, top_p)
517
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
518
+
519
+ generated_tokens += 1
520
+
521
+ text_chunks = [
522
+ self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
523
+ ]
524
+ text = "".join(text_chunks)
525
+
526
+ start_idx = 0
527
+ grounding = []
528
+ for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
529
+ if len(grounding_chunk) > 1:
530
+ points = []
531
+ for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
532
+ points.append((grounding_chunk[i], grounding_chunk[i + 1]))
533
+ grounding.append(
534
+ {
535
+ "start_idx": start_idx,
536
+ "end_idx": start_idx + len(text_chunk),
537
+ "points": points,
538
+ }
539
+ )
540
+ start_idx += len(text_chunk)
541
+
542
+ return pos, text, grounding
543
+
544
+ def _generate_answer(
545
+ self,
546
+ prompt_tokens: torch.Tensor,
547
+ pos: int,
548
+ settings: Optional[TextSamplingSettings] = None,
549
+ spatial_refs: Optional[SpatialRefs] = None,
550
+ eos_id: Optional[int] = None,
551
+ attn_mask: Optional[torch.Tensor] = None,
552
+ ):
553
+ max_tokens = (
554
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
555
+ if settings
556
+ else DEFAULT_MAX_TOKENS
557
+ )
558
+ temperature = (
559
+ settings.get("temperature", DEFAULT_TEMPERATURE)
560
+ if settings
561
+ else DEFAULT_TEMPERATURE
562
+ )
563
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
564
+ eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
565
+ lora = self._resolve_lora(settings)
566
+
567
+ _, _, next_token, pos = self._prefill_prompt(
568
+ prompt_tokens,
569
+ pos,
570
+ temperature,
571
+ top_p,
572
+ spatial_refs,
573
+ attn_mask=attn_mask,
574
+ lora=lora,
575
+ )
576
+
577
+ def generator(next_token, pos):
578
+ mask = torch.zeros(
579
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
580
+ )
581
+ mask[:, :, :pos] = 1
582
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
583
+ generated_tokens = 0
584
+
585
+ # For properly handling token streaming with Unicode
586
+ token_cache = []
587
+ print_len = 0
588
+
589
+ while (
590
+ next_token_id := next_token.item()
591
+ ) != eos_id and generated_tokens < max_tokens:
592
+ # Add token to our cache
593
+ token_cache.append(next_token_id)
594
+
595
+ # Decode all tokens collected so far
596
+ text = self.tokenizer.decode(token_cache)
597
+
598
+ # After a newline, we flush the cache completely
599
+ if text.endswith("\n"):
600
+ printable_text = text[print_len:]
601
+ token_cache = []
602
+ print_len = 0
603
+ if printable_text:
604
+ yield printable_text
605
+ # If the last decoded character is a CJK character, we can safely print it
606
+ elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
607
+ printable_text = text[print_len:]
608
+ print_len += len(printable_text)
609
+ if printable_text:
610
+ yield printable_text
611
+ # Otherwise, only yield up to the last space to avoid cutting words
612
+ else:
613
+ last_space_idx = text.rfind(" ", print_len)
614
+ if last_space_idx >= print_len:
615
+ printable_text = text[print_len : last_space_idx + 1]
616
+ print_len += len(printable_text)
617
+ if printable_text:
618
+ yield printable_text
619
+
620
+ with torch.inference_mode():
621
+ next_emb = text_encoder(next_token, self.text)
622
+ mask[:, :, pos], pos_ids[0] = 1, pos
623
+
624
+ logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
625
+ logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
626
+
627
+ pos += 1
628
+
629
+ if temperature == 0:
630
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
631
+ 1
632
+ ) # (1, 1)
633
+ else:
634
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
635
+ probs = self._apply_top_p(probs, top_p)
636
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
637
+
638
+ generated_tokens += 1
639
+
640
+ # Flush any remaining text in the cache
641
+ if token_cache:
642
+ text = self.tokenizer.decode(token_cache)
643
+ printable_text = text[print_len:]
644
+ if printable_text:
645
+ yield printable_text
646
+
647
+ return generator(next_token, pos)
648
+
649
+ def query(
650
+ self,
651
+ image: Optional[Union[Image.Image, EncodedImage]] = None,
652
+ question: Optional[str] = None,
653
+ reasoning: bool = True,
654
+ spatial_refs: Optional[SpatialRefs] = None,
655
+ stream: bool = False,
656
+ settings: Optional[TextSamplingSettings] = None,
657
+ ):
658
+ if self.config.tokenizer.templates["query"] is None:
659
+ raise NotImplementedError("Model does not support querying.")
660
+
661
+ if question is None:
662
+ raise ValueError("question must be provided.")
663
+
664
+ if spatial_refs and image is None:
665
+ raise ValueError("spatial_refs can only be used with an image.")
666
+
667
+ attn_mask = self.attn_mask
668
+ if image is not None:
669
+ image = self.encode_image(image, settings)
670
+ self.load_encoded_image(image)
671
+ pos = image.pos
672
+ prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
673
+ else:
674
+ self._setup_caches()
675
+ pos = 0
676
+ prompt_toks = [
677
+ self.config.tokenizer.bos_id
678
+ ] + self.config.tokenizer.templates["query"]["prefix"]
679
+ max_context = self.config.text.max_context
680
+ attn_mask = torch.tril(
681
+ torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
682
+ ).to(self.device)
683
+
684
+ spatial_toks = []
685
+ if spatial_refs:
686
+ for ref in spatial_refs:
687
+ coord_id = self.config.tokenizer.coord_id
688
+ size_id = self.config.tokenizer.size_id
689
+ if len(ref) == 2:
690
+ spatial_toks.extend([coord_id, coord_id])
691
+ else:
692
+ spatial_toks.extend([coord_id, coord_id, size_id])
693
+
694
+ prompt_tokens = [
695
+ prompt_toks + spatial_toks + self.tokenizer.encode(question).ids
696
+ ]
697
+
698
+ if reasoning:
699
+ prompt_tokens[0] += [self.config.tokenizer.thinking_id]
700
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
701
+ pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
702
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
703
+ )
704
+ prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
705
+ reasoning_dict = {
706
+ "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
707
+ }
708
+ spatial_refs = None
709
+ else:
710
+ prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
711
+ reasoning_dict = {}
712
+
713
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
714
+
715
+ def generator():
716
+ for token in self._generate_answer(
717
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
718
+ ):
719
+ yield token
720
+
721
+ if stream:
722
+ return {**reasoning_dict, "answer": generator()}
723
+ else:
724
+ return {**reasoning_dict, "answer": "".join(list(generator()))}
725
+
726
+ def load_encoded_image(self, encoded_image: EncodedImage):
727
+ for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
728
+ b.kv_cache.k_cache[:, :, : k.size(2), :] = k
729
+ b.kv_cache.v_cache[:, :, : v.size(2), :] = v
730
+
731
+ def caption(
732
+ self,
733
+ image: Union[Image.Image, EncodedImage],
734
+ length: Literal["normal", "short", "long"] = "normal",
735
+ stream: bool = False,
736
+ settings: Optional[TextSamplingSettings] = None,
737
+ ):
738
+ if self.config.tokenizer.templates["caption"] is None:
739
+ raise NotImplementedError("Model does not support captioning.")
740
+ if length not in self.config.tokenizer.templates["caption"]:
741
+ raise ValueError(f"Model does not support caption length '{length}'.")
742
+
743
+ image = self.encode_image(image, settings)
744
+ self.load_encoded_image(image)
745
+
746
+ prompt_tokens = torch.tensor(
747
+ [self.config.tokenizer.templates["caption"][length]], device=self.device
748
+ )
749
+
750
+ def generator():
751
+ for token in self._generate_answer(prompt_tokens, image.pos, settings):
752
+ yield token
753
+
754
+ if stream:
755
+ return {"caption": generator()}
756
+ else:
757
+ return {"caption": "".join(list(generator()))}
758
+
759
+ def _generate_points(
760
+ self,
761
+ hidden: torch.Tensor,
762
+ next_token: torch.Tensor,
763
+ pos: int,
764
+ include_size: bool = True,
765
+ max_objects: int = DEFAULT_MAX_OBJECTS,
766
+ lora: Optional[dict] = None,
767
+ ):
768
+ out = []
769
+ mask = torch.zeros(
770
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
771
+ )
772
+ mask[:, :, :pos] = 1
773
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
774
+
775
+ with torch.inference_mode():
776
+ while (
777
+ next_token.item() != self.config.tokenizer.eos_id
778
+ and len(out) < max_objects
779
+ ):
780
+ x_logits = decode_coordinate(hidden, self.region)
781
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
782
+ next_emb = encode_coordinate(
783
+ x_center.to(dtype=x_logits.dtype), self.region
784
+ ).unsqueeze(0)
785
+
786
+ # Decode y-coordinate
787
+ mask[:, :, pos], pos_ids[0] = 1, pos
788
+ _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
789
+ pos += 1
790
+ y_logits = decode_coordinate(hidden, self.region)
791
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
792
+ next_emb = encode_coordinate(
793
+ y_center.to(dtype=y_logits.dtype), self.region
794
+ ).unsqueeze(0)
795
+
796
+ # Decode size
797
+ if include_size:
798
+ mask[:, :, pos], pos_ids[0] = 1, pos
799
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
800
+ pos += 1
801
+ size_logits = decode_size(hidden, self.region)
802
+
803
+ # Get bin indices from the logits
804
+ w_bin = torch.argmax(size_logits[0], dim=-1)
805
+ h_bin = torch.argmax(size_logits[1], dim=-1)
806
+
807
+ # Convert from bin indices to actual size values using the inverse of the log-scale mapping
808
+ # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
809
+ w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
810
+ h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
811
+
812
+ next_emb = (
813
+ encode_size(
814
+ torch.tensor(
815
+ [w, h], device=self.device, dtype=size_logits.dtype
816
+ ),
817
+ self.region,
818
+ )
819
+ .unsqueeze(0)
820
+ .unsqueeze(0)
821
+ )
822
+
823
+ # Add object
824
+ out.append(
825
+ {
826
+ "x_min": x_center.item() - w.item() / 2,
827
+ "y_min": y_center.item() - h.item() / 2,
828
+ "x_max": x_center.item() + w.item() / 2,
829
+ "y_max": y_center.item() + h.item() / 2,
830
+ }
831
+ )
832
+ else:
833
+ out.append({"x": x_center.item(), "y": y_center.item()})
834
+
835
+ # Decode next token (x-coordinate, or eos)
836
+ mask[:, :, pos], pos_ids[0] = 1, pos
837
+ logits, hidden = self._decode_one_tok(
838
+ next_emb,
839
+ mask,
840
+ pos_ids,
841
+ lora,
842
+ lm_head_indices=self.point_gen_indices,
843
+ )
844
+ pos += 1
845
+ # Map back: index 0 -> coord_id, index 1 -> eos_id
846
+ next_token_idx = torch.argmax(logits, dim=-1)
847
+ next_token = self.point_gen_indices[next_token_idx]
848
+
849
+ return out
850
+
851
+ def detect(
852
+ self,
853
+ image: Union[Image.Image, EncodedImage],
854
+ object: str,
855
+ settings: Optional[ObjectSamplingSettings] = None,
856
+ ):
857
+ if self.config.tokenizer.templates["detect"] is None:
858
+ raise NotImplementedError("Model does not support object detection.")
859
+
860
+ image = self.encode_image(image, settings)
861
+ self.load_encoded_image(image)
862
+
863
+ prompt_tokens = torch.tensor(
864
+ [
865
+ self.config.tokenizer.templates["detect"]["prefix"]
866
+ + self.tokenizer.encode(" " + object).ids
867
+ + self.config.tokenizer.templates["detect"]["suffix"]
868
+ ],
869
+ device=self.device,
870
+ )
871
+
872
+ lora = self._resolve_lora(settings)
873
+
874
+ _, hidden, next_token, pos = self._prefill_prompt(
875
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
876
+ )
877
+ hidden = hidden[:, -1:, :]
878
+
879
+ max_objects = (
880
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
881
+ if settings
882
+ else DEFAULT_MAX_OBJECTS
883
+ )
884
+ objects = self._generate_points(
885
+ hidden,
886
+ next_token,
887
+ pos,
888
+ include_size=True,
889
+ max_objects=max_objects,
890
+ lora=lora,
891
+ )
892
+
893
+ return {"objects": objects}
894
+
895
+ def point(
896
+ self,
897
+ image: Union[Image.Image, EncodedImage],
898
+ object: str,
899
+ settings: Optional[ObjectSamplingSettings] = None,
900
+ ):
901
+ if self.config.tokenizer.templates["point"] is None:
902
+ raise NotImplementedError("Model does not support pointing.")
903
+
904
+ image = self.encode_image(image, settings)
905
+ self.load_encoded_image(image)
906
+
907
+ prompt_tokens = torch.tensor(
908
+ [
909
+ self.config.tokenizer.templates["point"]["prefix"]
910
+ + self.tokenizer.encode(" " + object).ids
911
+ + self.config.tokenizer.templates["point"]["suffix"]
912
+ ],
913
+ device=self.device,
914
+ )
915
+
916
+ lora = self._resolve_lora(settings)
917
+
918
+ _, hidden, next_token, pos = self._prefill_prompt(
919
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
920
+ )
921
+ hidden = hidden[:, -1:, :]
922
+
923
+ max_objects = (
924
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
925
+ if settings
926
+ else DEFAULT_MAX_OBJECTS
927
+ )
928
+ objects = self._generate_points(
929
+ hidden,
930
+ next_token,
931
+ pos,
932
+ include_size=False,
933
+ max_objects=max_objects,
934
+ lora=lora,
935
+ )
936
+
937
+ return {"points": objects}
938
+
939
+ def _detect_gaze(
940
+ self,
941
+ image: EncodedImage,
942
+ source: Tuple[float, float],
943
+ force_detect: bool = False,
944
+ ):
945
+ with torch.inference_mode():
946
+ before_emb = text_encoder(
947
+ torch.tensor(
948
+ [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
949
+ ),
950
+ self.text,
951
+ )
952
+ after_emb = text_encoder(
953
+ torch.tensor(
954
+ [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
955
+ ),
956
+ self.text,
957
+ )
958
+ x_emb = encode_coordinate(
959
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
960
+ self.region,
961
+ )
962
+ y_emb = encode_coordinate(
963
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
964
+ self.region,
965
+ )
966
+
967
+ prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
968
+
969
+ self.load_encoded_image(image)
970
+
971
+ mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
972
+ pos_ids = torch.arange(
973
+ image.pos,
974
+ image.pos + prompt_emb.size(1),
975
+ dtype=torch.long,
976
+ device=self.device,
977
+ )
978
+ hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
979
+ logits = lm_head(hidden, self.text)
980
+ next_token = torch.argmax(logits, dim=-1)
981
+ pos = image.pos + prompt_emb.size(1)
982
+ hidden = hidden[:, -1:, :]
983
+
984
+ if force_detect:
985
+ next_token = torch.tensor([[0]], device=self.device)
986
+
987
+ if next_token.item() == self.config.tokenizer.eos_id:
988
+ return None
989
+
990
+ gaze = self._generate_points(
991
+ hidden, next_token, pos, include_size=False, max_objects=1
992
+ )
993
+ return gaze[0]
994
+
995
+ def detect_gaze(
996
+ self,
997
+ image: Union[Image.Image, EncodedImage],
998
+ eye: Optional[Tuple[float, float]] = None,
999
+ face: Optional[Dict[str, float]] = None,
1000
+ unstable_settings: Dict[str, Any] = {},
1001
+ ):
1002
+ if "force_detect" in unstable_settings:
1003
+ force_detect = unstable_settings["force_detect"]
1004
+ else:
1005
+ force_detect = False
1006
+
1007
+ if "prioritize_accuracy" in unstable_settings:
1008
+ prioritize_accuracy = unstable_settings["prioritize_accuracy"]
1009
+ else:
1010
+ prioritize_accuracy = False
1011
+
1012
+ if not prioritize_accuracy:
1013
+ if eye is None:
1014
+ raise ValueError("eye must be provided when prioritize_accuracy=False")
1015
+ image = self.encode_image(image)
1016
+ return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
1017
+ else:
1018
+ if (
1019
+ not isinstance(image, Image.Image)
1020
+ and "flip_enc_img" not in unstable_settings
1021
+ ):
1022
+ raise ValueError(
1023
+ "image must be a PIL Image when prioritize_accuracy=True, "
1024
+ "or flip_enc_img must be provided"
1025
+ )
1026
+ if face is None:
1027
+ raise ValueError("face must be provided when prioritize_accuracy=True")
1028
+
1029
+ encoded_image = self.encode_image(image)
1030
+ if (
1031
+ isinstance(image, Image.Image)
1032
+ and "flip_enc_img" not in unstable_settings
1033
+ ):
1034
+ flipped_pil = image.copy()
1035
+ flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
1036
+ encoded_flipped_image = self.encode_image(flipped_pil)
1037
+ else:
1038
+ encoded_flipped_image = unstable_settings["flip_enc_img"]
1039
+
1040
+ N = 10
1041
+
1042
+ detections = [
1043
+ self._detect_gaze(
1044
+ encoded_image,
1045
+ (
1046
+ random.uniform(face["x_min"], face["x_max"]),
1047
+ random.uniform(face["y_min"], face["y_max"]),
1048
+ ),
1049
+ force_detect=force_detect,
1050
+ )
1051
+ for _ in range(N)
1052
+ ]
1053
+ detections = [
1054
+ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
1055
+ ]
1056
+ flipped_detections = [
1057
+ self._detect_gaze(
1058
+ encoded_flipped_image,
1059
+ (
1060
+ 1 - random.uniform(face["x_min"], face["x_max"]),
1061
+ random.uniform(face["y_min"], face["y_max"]),
1062
+ ),
1063
+ force_detect=force_detect,
1064
+ )
1065
+ for _ in range(N)
1066
+ ]
1067
+ detections.extend(
1068
+ [
1069
+ (1 - gaze["x"], gaze["y"])
1070
+ for gaze in flipped_detections
1071
+ if gaze is not None
1072
+ ]
1073
+ )
1074
+
1075
+ if len(detections) < N:
1076
+ return {"gaze": None}
1077
+
1078
+ detections = remove_outlier_points(detections)
1079
+ mean_gaze = (
1080
+ sum(gaze[0] for gaze in detections) / len(detections),
1081
+ sum(gaze[1] for gaze in detections) / len(detections),
1082
+ )
1083
+
1084
+ return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
1085
+
1086
+
1087
+ def _is_cjk_char(cp):
1088
+ """Checks whether CP is the codepoint of a CJK character."""
1089
+ # This defines a "chinese character" as anything in the CJK Unicode block:
1090
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
1091
+ if (
1092
+ (cp >= 0x4E00 and cp <= 0x9FFF)
1093
+ or (cp >= 0x3400 and cp <= 0x4DBF)
1094
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
1095
+ ):
1096
+ return True
1097
+ return False
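The `prioritize_accuracy` branch of `detect_gaze` above is a form of test-time augmentation: it samples source points inside the face box, runs the detector on the original and on a horizontally mirrored image, maps the mirrored predictions back with `1 - x`, and averages. The plain-Python sketch below shows only that coordinate bookkeeping. The names `mirrored_gaze_mean` and `fake_detector` are hypothetical stand-ins (the real model call is `_detect_gaze`), and the outlier-removal step is left out.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Point = Tuple[float, float]


def mirrored_gaze_mean(
    detect: Callable[[Point, bool], Optional[Point]],
    face: Dict[str, float],
    n: int = 10,
    rng: Optional[random.Random] = None,
) -> Optional[Point]:
    """Average gaze predictions from the original and the mirrored image.

    `detect(source, flipped)` is a hypothetical stand-in for the model call.
    It receives a source point and returns a gaze point (or None), both in
    the coordinate frame of the image it was run on.
    """
    rng = rng or random.Random()
    points: List[Point] = []
    for flipped in (False, True):
        for _ in range(n):
            x = rng.uniform(face["x_min"], face["x_max"])
            y = rng.uniform(face["y_min"], face["y_max"])
            # In the mirrored image the source point sits at 1 - x.
            source = (1 - x, y) if flipped else (x, y)
            gaze = detect(source, flipped)
            if gaze is None:
                continue
            gx, gy = gaze
            # Map mirrored predictions back into the original frame.
            points.append((1 - gx, gy) if flipped else (gx, gy))
    if len(points) < n:
        return None  # too few successful detections to trust the mean
    mean_x = sum(p[0] for p in points) / len(points)
    mean_y = sum(p[1] for p in points) / len(points)
    return (mean_x, mean_y)


def fake_detector(source: Point, flipped: bool) -> Optional[Point]:
    # Toy model: the true target is at (0.8, 0.3) in the original frame,
    # so it appears at (0.2, 0.3) in the mirrored frame.
    return (0.2, 0.3) if flipped else (0.8, 0.3)
```

With the toy detector, both passes agree once the mirrored predictions are flipped back, so the mean lands on the true target. A detector that always returns `None` yields `None`, mirroring the `len(detections) < N` check in the diff.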
region.py ADDED
@@ -0,0 +1,136 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ from typing import List, Tuple, Union
+
+ SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
+
+
+ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+     """
+     Applies Fourier feature mapping to input tensor x using frequency matrix w. This
+     projects inputs through sinusoidal functions to create higher dimensional features
+     that help mitigate spectral bias - the tendency of neural networks to learn
+     low-frequency functions more easily than high-frequency ones. By explicitly
+     mapping inputs to higher frequencies through sin/cos transformations, we enable
+     better learning of fine details and higher frequency patterns.
+
+     Args:
+         x: Input tensor to transform
+         w: Matrix of frequencies for the Fourier features transformation
+
+     Returns:
+         Concatenated cosine and sine transformed features as a tensor
+     """
+     f = 2 * math.pi * x @ w
+     return torch.cat([f.cos(), f.sin()], dim=-1)
+
+
+ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
+     """
+     Takes as input a tensor containing a single float coordinate value (x or y)
+     and encodes it into hidden states for input to the text model.
+
+     Args:
+         coord: Tensor with single float coordinate value
+
+     Returns:
+         Encoded hidden states tensor for input to text model
+     """
+     return w.coord_encoder(fourier_features(coord, w.coord_features))
+
+
+ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
+     """
+     Takes as input the last hidden state from the text model and outputs logits over
+     coordinate bins for either an x or y coordinate prediction. Callers take the
+     argmax and divide by the number of bins to get a normalized coordinate.
+
+     Args:
+         hidden_state: The final hidden state tensor from the text model.
+
+     Returns:
+         Logits over coordinate bins for the predicted coordinate value (x or y)
+     """
+     hidden_state = w.ln(hidden_state)
+     return w.coord_decoder(hidden_state)
+
+
+ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
+     """
+     Takes a tensor containing width and height values and encodes them into
+     hidden states for input to the text model.
+
+     Args:
+         size: Tensor with two floats for width and height
+
+     Returns:
+         Encoded hidden states tensor for input to text model
+     """
+     return w.size_encoder(fourier_features(size, w.size_features))
+
+
+ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
+     """
+     Takes as input the last hidden state from the text model and outputs logits
+     for 1024 bins representing width and height in log-scale.
+
+     The bins are distributed according to the formula:
+         bin = (log2(size) + 10.0) / 10.0 * 1023.0
+     where size values are clamped to be at least 1/1024.
+
+     To convert from bin back to size:
+         size = 2^((bin / 1023.0) * 10.0 - 10.0)
+
+     Args:
+         hidden_state: The final hidden state tensor from the text model.
+
+     Returns:
+         A tensor containing logits for 1024 bins for width and height.
+         Shape is (2, 1024) where the first dimension corresponds to width and height.
+     """
+     hidden_state = w.ln(hidden_state)
+     return w.size_decoder(hidden_state).view(2, -1)
+
+
+ def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> dict:
+     """
+     Takes a list of spatial references (points or regions) and encodes them into
+     hidden states for input to the text model.
+
+     Args:
+         spatial_refs: List of spatial references (points or boxes)
+             - Points are represented as normalized (x, y) tuples
+             - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
+
+     Returns:
+         {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
+     """
+     coords, sizes = [], []
+     for ref in spatial_refs:
+         if len(ref) == 2:
+             coords.append(ref[0])
+             coords.append(ref[1])
+         else:
+             x_c = (ref[0] + ref[2]) / 2
+             y_c = (ref[1] + ref[3]) / 2
+             width = ref[2] - ref[0]
+             height = ref[3] - ref[1]
+             coords.append(x_c)
+             coords.append(y_c)
+             sizes.append([width, height])
+
+     coords = torch.tensor(
+         coords, device=w.coord_features.device, dtype=w.coord_features.dtype
+     ).view(-1, 1)
+     coords = encode_coordinate(coords, w)
+
+     if sizes:
+         sizes = torch.tensor(
+             sizes, device=w.size_features.device, dtype=w.size_features.dtype
+         )
+         sizes = encode_size(sizes, w)
+     else:
+         sizes = None
+
+     return {"coords": coords, "sizes": sizes}
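The `decode_size` docstring in region.py describes a log-scale mapping between normalized sizes and 1024 bins. A minimal plain-Python sketch of that round trip follows. The function names are hypothetical, and two details are assumptions rather than something the diff states: the forward mapping rounds to the nearest bin, and sizes are capped at 1.0.

```python
import math

NUM_BINS = 1024  # bins 0..1023, as in the decode_size docstring


def size_to_bin(size: float) -> int:
    """Map a normalized size to a log-scale bin index."""
    size = max(size, 1.0 / 1024.0)  # clamp so that log2(size) >= -10
    size = min(size, 1.0)  # assumption: sizes are normalized to at most 1
    # Rounding to the nearest bin is an assumption of this sketch.
    return round((math.log2(size) + 10.0) / 10.0 * (NUM_BINS - 1))


def bin_to_size(bin_index: int) -> float:
    """Invert size_to_bin using the formula from the docstring."""
    return 2.0 ** ((bin_index / (NUM_BINS - 1)) * 10.0 - 10.0)
```

Because the bins are spaced evenly in `log2(size)`, adjacent bins differ by a constant ratio rather than a constant offset, so small boxes get finer absolute resolution than large ones.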
rope.py ADDED
@@ -0,0 +1,47 @@
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
+
+ import torch
+
+
+ def precompute_freqs_cis(
+     dim: int,
+     end: int,
+     theta: float = 1500000.0,
+     dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
+     t = torch.arange(end, dtype=dtype).unsqueeze(1)
+     freqs = t * freqs.unsqueeze(0)
+     freqs = torch.exp(1j * freqs)
+     return torch.stack([freqs.real, freqs.imag], dim=-1)
+
+
+ def apply_rotary_emb(
+     x: torch.Tensor,
+     freqs_cis: torch.Tensor,
+     position_ids: torch.Tensor,
+     num_heads: int,
+     rot_dim: int = 32,
+     interleave: bool = False,
+ ) -> torch.Tensor:
+     assert rot_dim == freqs_cis.shape[-2] * 2
+     assert num_heads == x.shape[1]
+
+     x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+
+     if interleave:
+         xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
+         xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
+     else:
+         d_q = x_rot.shape[-1] // 2
+         xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
+
+     freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
+     freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
+
+     # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
+     xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+     xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+     xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
+
+     return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
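The rotary embedding in rope.py treats pairs of features as complex numbers and multiplies each pair by `exp(i * position * freq)`. The sketch below reproduces that arithmetic for a single vector in plain Python with `cmath`, following the non-interleaved input layout (first half real parts, second half imaginary parts). The function names are hypothetical, and the sketch ignores batching, heads, and the pass-through features beyond `rot_dim`.

```python
import cmath
from typing import List


def rope_frequencies(dim: int, theta: float = 1500000.0) -> List[float]:
    """One frequency per rotated pair: theta ** (-i / dim) for i = 0, 2, 4, ..."""
    return [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]


def apply_rope_split(
    x: List[float], position: int, theta: float = 1500000.0
) -> List[float]:
    """Rotate a vector laid out as [real parts..., imaginary parts...].

    The result is returned interleaved as [r0, i0, r1, i1, ...], which mirrors
    the stack(..., dim=-1).flatten(-2) step in the diff.
    """
    half = len(x) // 2
    freqs = rope_frequencies(len(x), theta)
    out: List[float] = []
    for k in range(half):
        z = complex(x[k], x[half + k])
        # Multiplying by a unit complex number rotates the pair by the angle.
        z *= cmath.exp(1j * position * freqs[k])
        out.extend([z.real, z.imag])
    return out
```

Since each pair is multiplied by a unit-magnitude complex number, the rotation changes direction but not length, and position 0 leaves the values unchanged apart from the output layout.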
text.py ADDED
@@ -0,0 +1,223 @@
+ import torch
+ import torch.nn as nn
+
+ from torch.nn import functional as F
+ from torch.nn.attention.flex_attention import flex_attention
+ from typing import Optional
+
+ from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
+ from .rope import apply_rotary_emb, precompute_freqs_cis
+ from .config import TextConfig
+ from .lora import select_layer_lora
+
+
+ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
+     return F.embedding(input_ids, w.wte)
+
+
+ def attn(
+     x: torch.Tensor,
+     w: nn.Module,
+     freqs_cis: torch.Tensor,
+     kv_cache: nn.Module,
+     attn_mask: torch.Tensor,
+     n_heads: int,
+     n_kv_heads: int,
+     position_ids: torch.Tensor,
+     flex_block_mask_slice=None,
+ ):
+     bsz, q_len, d_model = x.shape
+     head_dim = d_model // n_heads
+
+     qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+     q_dim = n_heads * head_dim
+     kv_dim = n_kv_heads * head_dim
+     q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
+
+     q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+     k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+     v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+
+     if hasattr(w, "tau") and w.tau is not None:
+         tok_feat = F.gelu(qkv_out)
+         tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1)
+         tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1)
+         pos = position_ids.to(q.dtype) + 1
+         tau_pos = 1 + (
+             torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5
+         )  # (H,S)
+         tau_q = (tok_q + tau_pos[None]).unsqueeze(-1)  # (B,H,S,1)
+         tau_v = (tok_v + tau_pos[None]).unsqueeze(-1)
+         q = q * tau_q
+         v = v * tau_v
+
+     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+     k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+
+     if kv_cache is not None:
+         k, v = kv_cache.update(position_ids, k, v)
+
+     if flex_block_mask_slice is not None:
+         torch._assert(n_heads == n_kv_heads, "gqa not supported yet")
+         out = flex_attention(q, k, v, block_mask=flex_block_mask_slice)
+     else:
+         out = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+         )
+
+     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+
+     return w.proj(out)
+
+
+ def text_decoder(
+     x: torch.Tensor,
+     w: nn.Module,
+     attn_mask: torch.Tensor,
+     position_ids: torch.Tensor,
+     config: TextConfig,
+     lora: Optional[object] = None,
+     flex_block_mask_slice=None,
+ ):
+     for i, block in enumerate(w.blocks):
+         layer_lora = select_layer_lora(
+             lora, i, is_moe=config.moe is not None and i >= config.moe.start_layer
+         )
+
+         l_in = layer_norm(x, block.ln)
+         l_attn = attn(
+             l_in,
+             block.attn,
+             freqs_cis=w.freqs_cis,
+             kv_cache=block.kv_cache,
+             attn_mask=attn_mask,
+             n_heads=config.n_heads,
+             n_kv_heads=config.n_kv_heads,
+             position_ids=position_ids,
+             flex_block_mask_slice=flex_block_mask_slice,
+         )
+
+         if config.moe is not None and i >= config.moe.start_layer:
+             l_mlp = moe_mlp(
+                 l_in, block.mlp, config.moe.experts_per_token, lora=layer_lora
+             )
+         else:
+             l_mlp = mlp(l_in, block.mlp, lora=layer_lora)
+
+         x = x + l_attn + l_mlp
+
+     return x
+
+
+ def lm_head(
+     hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None
+ ):
+     hidden_BC = hidden_BTC[:, -1, :]
+     hidden_BC = layer_norm(hidden_BC, w.post_ln)
+     if indices is not None:
+         # Only compute logits for specified token indices
+         logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices]
+     else:
+         logits = w.lm_head(hidden_BC)
+     return logits
+
+
+ def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
+     return nn.ModuleDict(
+         {
+             "fc1": linear_cls(d_model, d_ffn, dtype=dtype),
+             "fc2": linear_cls(d_ffn, d_model, dtype=dtype),
+         }
+     )
+
+
+ def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
+     # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
+     mlp = nn.ModuleDict(
+         {
+             "router": nn.Linear(d_model, n_experts, dtype=dtype),
+             "fc1": nn.ParameterDict(
+                 {
+                     "weight": nn.Parameter(
+                         torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype)
+                     )
+                 }
+             ),
+             "fc2": nn.ParameterDict(
+                 {
+                     "weight": nn.Parameter(
+                         torch.empty(n_experts, d_model, d_ffn, dtype=dtype)
+                     )
+                 }
+             ),
+         }
+     )
+     return mlp
+
+
+ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+     qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+     linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
+
+     text = nn.ModuleDict(
+         {
+             "blocks": nn.ModuleList(
+                 [
+                     nn.ModuleDict(
+                         {
+                             "ln": nn.LayerNorm(config.dim, dtype=dtype),
+                             "attn": nn.ModuleDict(
+                                 {
+                                     "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
+                                     "proj": linear_cls(
+                                         config.dim, config.dim, dtype=dtype
+                                     ),
+                                     "tau": nn.ParameterDict(
+                                         {
+                                             "wq": nn.Parameter(
+                                                 torch.empty(
+                                                     config.n_heads, qkv_dim, dtype=dtype
+                                                 )
+                                             ),
+                                             "wv": nn.Parameter(
+                                                 torch.empty(
+                                                     config.n_heads, qkv_dim, dtype=dtype
+                                                 )
+                                             ),
+                                             "alpha": nn.Parameter(
+                                                 torch.empty(config.n_heads, dtype=dtype)
+                                             ),
+                                         }
+                                     ),
+                                 }
+                             ),
+                             "mlp": (
+                                 build_moe_mlp(
+                                     config.dim,
+                                     config.moe.expert_inner_dim,
+                                     config.moe.num_experts,
+                                     dtype,
+                                 )
+                                 if config.moe is not None
+                                 and layer_idx >= config.moe.start_layer
+                                 else build_dense_mlp(
+                                     config.dim, config.ff_dim, dtype, linear_cls
+                                 )
+                             ),
+                         }
+                     )
+                     for layer_idx in range(config.n_layers)
+                 ]
+             ),
+             "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
+             "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
+         }
+     )
+     text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
+     text.register_buffer(
+         "freqs_cis",
+         precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
+         persistent=False,
+     )
+
+     return text
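In text.py, `build_text_model` sizes the fused `qkv` projection as `dim * (1 + 2 * n_kv_heads / n_heads)`, and `attn` splits its output into one query chunk and two key/value chunks. The sketch below checks that the two expressions agree. The function names and the example dimensions are hypothetical and are not read from this commit's config.json.

```python
from typing import Tuple


def qkv_split_sizes(dim: int, n_heads: int, n_kv_heads: int) -> Tuple[int, int, int]:
    """Sizes passed to split() in attn: (q_dim, kv_dim, kv_dim)."""
    head_dim = dim // n_heads
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
    return q_dim, kv_dim, kv_dim


def fused_qkv_dim(dim: int, n_heads: int, n_kv_heads: int) -> int:
    """Output width of the fused projection, as computed in build_text_model."""
    return int(dim * (1 + 2 * n_kv_heads / n_heads))
```

With equal head counts the fused width is `3 * dim`. With fewer key/value heads (grouped-query attention) the key and value chunks shrink proportionally, which is the case the `enable_gqa` flag in the diff covers.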
utils.py ADDED
@@ -0,0 +1,41 @@
+ import numpy as np
+
+
+ def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
+     """
+     Robust outlier detection for list of (x,y) tuples.
+     Only requires numpy.
+
+     Args:
+         points_tuples: list of (x,y) tuples
+         k_nearest: number of neighbors to consider
+         threshold: multiplier for median distance
+
+     Returns:
+         list: filtered list of (x,y) tuples with outliers removed
+     """
+     points = np.array(points_tuples)
+     n_points = len(points)
+
+     # Calculate pairwise distances manually
+     dist_matrix = np.zeros((n_points, n_points))
+     for i in range(n_points):
+         for j in range(i + 1, n_points):
+             # Euclidean distance between points i and j
+             dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
+             dist_matrix[i, j] = dist
+             dist_matrix[j, i] = dist
+
+     # Get k nearest neighbors' distances
+     # Note: each row's zero self-distance is among the k smallest values
+     k = min(k_nearest, n_points - 1)
+     neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
+     avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
+
+     # Calculate mask using median distance
+     median_dist = np.median(avg_neighbor_dist)
+     mask = avg_neighbor_dist <= threshold * median_dist
+
+     # Return only the filtered tuples (the mask is not returned)
+     filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
+     return filtered_tuples
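The filter in utils.py keeps a point when the mean of its `k` smallest distances is at most `threshold` times the median of that statistic over all points. The plain-Python sketch below follows the same rule without numpy, including the zero self-distance that the distance matrix contributes to each row. The name `remove_outliers` is hypothetical, and returning the input unchanged for fewer than two points is a choice made for this sketch, not behaviour taken from the diff.

```python
import math
import statistics
from typing import List, Tuple

Point = Tuple[float, float]


def remove_outliers(
    points: List[Point], k_nearest: int = 2, threshold: float = 2.0
) -> List[Point]:
    """Drop points whose neighbourhood is much sparser than the median one."""
    n = len(points)
    k = min(k_nearest, n - 1)
    if k <= 0:
        return list(points)  # sketch-only guard for zero or one point
    avg_dists = []
    for p in points:
        # Sorted distances from p to every point, including dist(p, p) == 0.
        row = sorted(math.dist(p, q) for q in points)
        avg_dists.append(sum(row[:k]) / k)
    median = statistics.median(avg_dists)
    return [p for p, d in zip(points, avg_dists) if d <= threshold * median]
```

For a tight cluster plus one distant point, the cluster members share a small neighbourhood statistic, the median sits at that value, and the distant point exceeds the threshold and is dropped.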
vision.py ADDED
@@ -0,0 +1,147 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ from typing import Union, Tuple
+ from PIL import Image
+
+ from .layers import attn, layer_norm, mlp
+ from .image_crops import overlap_crop_image
+ from .config import VisionConfig
+
+ if torch.backends.mps.is_available():
+     # Non-divisible input sizes are not implemented on MPS device yet.
+     # https://github.com/pytorch/pytorch/issues/96056
+     def adaptive_avg_pool2d(input, output_size):
+         return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
+
+ else:
+     adaptive_avg_pool2d = F.adaptive_avg_pool2d
+
+ DeviceLike = Union[str, torch.device, int]
+
+
+ def prepare_crops(
+     image: Image.Image, config: VisionConfig, device: DeviceLike
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
+     np_image = np.array(image.convert("RGB"))
+     overlap_crops = overlap_crop_image(
+         np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
+     )
+     all_crops = overlap_crops["crops"]
+     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
+     all_crops = (
+         torch.from_numpy(all_crops)
+         .to(device=device, dtype=torch.bfloat16)
+         .div_(255.0)
+         .sub_(0.5)
+         .div_(0.5)
+     )
+     return all_crops, overlap_crops["tiling"]
+
+
+ def create_patches(x, patch_size):
+     # Original shape: [B, C, H, W]
+     B, C, H, W = x.shape
+     P1 = P2 = patch_size
+
+     # Step 1: Split H and W dimensions into patches
+     # [B, C, H/P1, P1, W/P2, P2]
+     x = x.reshape(B, C, H // P1, P1, W // P2, P2)
+
+     # Step 2: Rearrange dimensions to match target shape
+     # [B, H/P1, W/P2, C, P1, P2]
+     x = x.permute(0, 2, 4, 1, 3, 5)
+
+     # Step 3: Combine dimensions to get final shape
+     # [B, (H/P1)*(W/P2), C*P1*P2]
+     x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
+
+     return x
+
+
+ def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
+     x = create_patches(input_BCHW, config.enc_patch_size)
+
+     x = w.patch_emb(x)
+     x = x + w.pos_emb
+     for block in w.blocks:
+         x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
+         x = x + mlp(layer_norm(x, block.ln2), block.mlp)
+     x = layer_norm(x, w.post_ln)
+
+     return x
+
+
+ def vision_projection(
+     global_features: torch.Tensor,
+     reconstructed: torch.Tensor,
+     w: nn.Module,
+     config: VisionConfig,
+ ):
+     reconstructed = reconstructed.permute(2, 0, 1)
+     # Note: the pooled grid side is taken from enc_n_layers; the view(729, ...)
+     # below only works when that grid is 27 x 27 (27 * 27 = 729).
+     reconstructed = adaptive_avg_pool2d(
+         reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
+     )
+     reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
+     final_features = torch.cat([global_features, reconstructed], dim=-1)
+     return mlp(final_features, w.proj_mlp)
+
+
+ def build_vision_model(config: VisionConfig, dtype: torch.dtype):
+     patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
+     grid_size = config.crop_size // config.enc_patch_size
+     num_patches = grid_size * grid_size
+
+     vision = nn.ModuleDict(
+         {
+             "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
+             "blocks": nn.ModuleList(
+                 [
+                     nn.ModuleDict(
+                         {
+                             "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
+                             "attn": nn.ModuleDict(
+                                 {
+                                     "qkv": nn.Linear(
+                                         config.enc_dim, 3 * config.enc_dim, dtype=dtype
+                                     ),
+                                     "proj": nn.Linear(
+                                         config.enc_dim, config.enc_dim, dtype=dtype
+                                     ),
+                                 }
+                             ),
+                             "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
+                             "mlp": nn.ModuleDict(
+                                 {
+                                     "fc1": nn.Linear(
+                                         config.enc_dim, config.enc_ff_dim, dtype=dtype
+                                     ),
+                                     "fc2": nn.Linear(
+                                         config.enc_ff_dim, config.enc_dim, dtype=dtype
+                                     ),
+                                 }
+                             ),
+                         }
+                     )
+                     for _ in range(config.enc_n_layers)
+                 ]
+             ),
+             "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
+             "proj_mlp": nn.ModuleDict(
+                 {
+                     "fc1": nn.Linear(
+                         config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
+                     ),
+                     "fc2": nn.Linear(
+                         config.proj_inner_dim, config.proj_out_dim, dtype=dtype
+                     ),
+                 }
+             ),
+         }
+     )
+     vision.pos_emb = nn.Parameter(
+         torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
+     )
+     return vision
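Two pieces of arithmetic in vision.py are easy to check by hand: the pixel normalization in `prepare_crops` (divide by 255, subtract 0.5, divide by 0.5) and the patch counts in `build_vision_model`. The sketch below restates both in plain Python. The function names are hypothetical, and the example values 378 and 14 are assumptions chosen because they give the 729 patches that `vision_projection` hard-codes; config.py is not shown in this part of the diff.

```python
from typing import Tuple


def normalize_pixel(value: int) -> float:
    """Map an 8-bit channel value to [-1, 1], following prepare_crops."""
    return (value / 255.0 - 0.5) / 0.5


def patch_grid(crop_size: int, patch_size: int, in_channels: int = 3) -> Tuple[int, int]:
    """Return (num_patches, patch_dim) for a square crop cut into square patches."""
    if crop_size % patch_size != 0:
        raise ValueError("crop_size must be divisible by patch_size")
    grid = crop_size // patch_size
    return grid * grid, in_channels * patch_size * patch_size
```

Under those assumed values, a 378-pixel crop with 14-pixel patches gives a 27 by 27 grid, and each flattened patch carries 3 * 14 * 14 values into `patch_emb`. The diff performs the normalization in place in bfloat16, so its results are rounded more coarsely than this float version.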