EQX55 committed
Commit 8c21234 · verified · 1 Parent(s): 0047d0a

Upload 24 files

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ structured_outputs.png filter=lfs diff=lfs merge=lfs -text
37
+ open_vocab_detect.png filter=lfs diff=lfs merge=lfs -text
38
+ point_count.png filter=lfs diff=lfs merge=lfs -text
39
+ visual_reasoning.png filter=lfs diff=lfs merge=lfs -text
LICENSE.md ADDED
@@ -0,0 +1,25 @@
1
+ | License | Business Source License (BSL 1.1) |
2
+ | --- | --- |
3
+ | Licensor | M87 Labs, Inc. |
4
+ | Licensed Work | “Moondream 3 (Preview)” including Model Weights and any Derivatives (“Derivatives” include fine-tunes, merges, quantizations, weight deltas, and other weight-level modifications or conversions.) |
5
+ | Additional Use Grant | You may use the Licensed Work and Derivatives for any purpose, including commercial use, and you may self-host them for your or your organization’s internal use. You may not provide the Licensed Work, Derivatives, or any service that exposes their functionality to third parties (including via API, website, application, model hub, or dataset/model redistribution) without a separate commercial agreement with the Licensor. |
6
+ | Change Date | Two years after the first public release of this version of the Licensed Work |
7
+ | Change License | Apache License, Version 2.0 |
8
+
9
+ The text of the Business Source License 1.1 follows. License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved. “Business Source License” is a trademark of MariaDB Corporation Ab.
10
+
11
+ ## Terms
12
+
13
+ The Licensor hereby grants you the right to copy, modify, create derivative works, redistribute, and make non-production use of the Licensed Work. The Licensor may make an Additional Use Grant, above, permitting limited production use.
14
+
15
+ Effective on the Change Date, or the fourth anniversary of the first publicly available distribution of a specific version of the Licensed Work under this License, whichever comes first, the Licensor hereby grants you rights under the terms of the Change License, and the rights granted in the paragraph above terminate.
16
+
17
+ If your use of the Licensed Work does not comply with the requirements currently in effect as described in this License, you must purchase a commercial license from the Licensor, its affiliated entities, or authorized resellers, or you must refrain from using the Licensed Work.
18
+
19
+ All copies of the original and modified Licensed Work, and derivative works of the Licensed Work, are subject to this License. This License applies separately for each version of the Licensed Work and the Change Date may vary for each version of the Licensed Work released by Licensor.
20
+
21
+ You must conspicuously display this License on each original or modified copy of the Licensed Work. If you receive the Licensed Work in original or modified form from a third party, the terms and conditions set forth in this License apply to your use of that work.
22
+
23
+ Any use of the Licensed Work in violation of this License will automatically terminate your rights under this License for the current and all other versions of the Licensed Work.
24
+
25
+ This License does not grant you any right in any trademark or logo of Licensor or its affiliates (provided that you may use a trademark or logo of Licensor as expressly required by this License). TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND TITLE.
README.md ADDED
@@ -0,0 +1,204 @@
1
+ ---
2
+ library_name: transformers
3
+ pipeline_tag: image-text-to-text
4
+ license: other
5
+ ---
6
+
7
+ **Moondream 3 (Preview)** is a vision language model with a mixture-of-experts architecture (9B total parameters, 2B active). This model makes no compromises, delivering state-of-the-art visual reasoning while still retaining our efficient and deployment-friendly ethos.
8
+
9
+ [✨ Demo](https://moondream.ai/c/playground)   ·   [☁️ Cloud API](https://moondream.ai/c/docs/quickstart)   ·   [📝 Release notes](https://moondream.ai/blog/moondream-3-preview)
10
+
11
+ ![](https://huggingface.co/moondream/moondream3-preview/resolve/main/open_vocab_detect.png)
12
+ ![](https://huggingface.co/moondream/moondream3-preview/resolve/main/visual_reasoning.png)
13
+ ![](https://huggingface.co/moondream/moondream3-preview/resolve/main/point_count.png)
14
+ ![](https://huggingface.co/moondream/moondream3-preview/resolve/main/structured_outputs.png)
15
+
16
+ ## Architecture
17
+
18
+ 1. 24 layers; the first four are dense, the rest have MoE FFNs with 64 experts, 8 activated per token
19
+ 2. MoE FFNs have GeGLU architecture, with inner/gate dim of 1024. The model's hidden dim is 2048.
20
+ 3. Usable context length increased to 32K, with [a custom efficient SuperBPE tokenizer](https://huggingface.co/moondream/starmie-v1)
21
+ 4. Multi-headed attention with learned position- and data-dependent temperature scaling
22
+ 5. SigLIP-based vision encoder, with multi-crop channel concatenation for token-efficient high resolution image processing
23
+
24
+ For more details, please refer to the [release notes](https://moondream.ai/blog/moondream-3-preview), or try the model out in our [playground demo](https://moondream.ai/c/playground).
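+
+ As a rough illustration of the MoE FFN shape listed above (hidden dim 2048, inner/gate dim 1024, GeGLU activation), here is a minimal sketch of a single expert. The `GeGLUExpert` class and standalone module are our own illustrative naming, not the shipped implementation, though the GeGLU form mirrors `moe_mlp` in `layers.py`:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class GeGLUExpert(nn.Module):
+     """Illustrative single MoE expert FFN; not the shipped implementation."""
+
+     def __init__(self, dim: int = 2048, inner_dim: int = 1024):
+         super().__init__()
+         self.fc1 = nn.Linear(dim, 2 * inner_dim, bias=False)  # value and gate halves
+         self.fc2 = nn.Linear(inner_dim, dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         v, g = self.fc1(x).chunk(2, dim=-1)
+         return self.fc2(F.gelu(v) * (g + 1))  # GeGLU form used in layers.py
+
+ # Each token is routed to 8 of the 64 such experts, and their outputs are
+ # combined with softmax router weights (see moe_mlp in layers.py).
+ x = torch.randn(1, 2048)
+ print(GeGLUExpert()(x).shape)  # torch.Size([1, 2048])
+ ```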
25
+
26
+ The following instructions demonstrate how to run the model locally using Transformers. We also offer a [cloud API](https://moondream.ai/c/docs/quickstart) with a generous free tier that can help you get started quicker!
27
+
28
+ ## Usage
29
+
30
+ Load the model and prepare it for inference. We use [FlexAttention for inference](https://pytorch.org/blog/flexattention-for-inference/), so calling `.compile()` is critical for fast decoding. Our `compile` implementation also handles warmup, so you can start making requests directly once it returns.
31
+
32
+ ```python
33
+ import torch
34
+ from transformers import AutoModelForCausalLM
35
+
36
+ moondream = AutoModelForCausalLM.from_pretrained(
37
+ "moondream/moondream3-preview",
38
+ trust_remote_code=True,
39
+ dtype=torch.bfloat16,
40
+ device_map={"": "cuda"},
41
+ )
42
+ moondream.compile()
43
+ ```
44
+
45
+ The model comes with four skills, tailored towards different visual understanding tasks.
46
+
47
+ ### Query
48
+
49
+ The `query` skill can be used to ask open-ended questions about images.
50
+
51
+ ```python
52
+ from PIL import Image
53
+
54
+ # Simple VQA
55
+ image = Image.open("photo.jpg")
56
+ result = moondream.query(image=image, question="What's in this image?")
57
+ print(result["answer"])
58
+ ```
59
+
60
+ By default, `query` runs in reasoning mode, allowing the model to "think" about the question before generating an answer. This is helpful for more complicated tasks, but sometimes the task you're running is simple and doesn't benefit from reasoning. To save on inference cost when this is the case, you can disable reasoning:
61
+
62
+ ```python
63
+ # Without reasoning for simple questions
64
+ result = moondream.query(
65
+ image=image,
66
+ question="What color is the sky?",
67
+ reasoning=False
68
+ )
69
+ print(result["answer"])
70
+ ```
71
+
72
+ If you want to stream outputs, pass in `stream=True`. You can control the temperature, top-p, and maximum number of tokens generated by passing in optional settings.
73
+
74
+ ```python
75
+ # Streaming with custom settings
76
+ settings = {
77
+ "temperature": 0.7,
78
+ "top_p": 0.95,
79
+ "max_tokens": 512
80
+ }
81
+
82
+ result = moondream.query(
83
+ image=image,
84
+ question="Describe what's happening in detail",
85
+ stream=True,
86
+ settings=settings
87
+ )
88
+
89
+ # Stream the answer
90
+ for chunk in result["answer"]:
91
+ print(chunk, end="", flush=True)
92
+ ```
93
+
94
+ Note that this isn't just for images; Moondream is also a strong general-purpose text model.
95
+
96
+ ```python
97
+ # Text-only example (no image)
98
+ result = moondream.query(
99
+ question="Explain the concept of machine learning in simple terms"
100
+ )
101
+ print(result["answer"])
102
+ ```
103
+
104
+ ### Caption
105
+
106
+ Whether you want short, normal-sized, or long descriptions of images, the `caption` skill has you covered.
107
+
108
+ ```python
109
+ # Different caption lengths
110
+ image = Image.open("landscape.jpg")
111
+
112
+ # Short caption
113
+ short = moondream.caption(image, length="short")
114
+ print(f"Short: {short['caption']}")
115
+
116
+ # Normal caption (default)
117
+ normal = moondream.caption(image, length="normal")
118
+ print(f"Normal: {normal['caption']}")
119
+
120
+ # Long caption
121
+ long = moondream.caption(image, length="long")
122
+ print(f"Long: {long['caption']}")
123
+ ```
124
+
125
+ It accepts the same streaming and sampling settings (temperature, top-p, etc.) as the `query` skill.
126
+
127
+ ```python
128
+ # Streaming caption with custom settings
129
+ result = moondream.caption(
130
+ image,
131
+ length="long",
132
+ stream=True,
133
+ settings={"temperature": 0.3}
134
+ )
135
+
136
+ for chunk in result["caption"]:
137
+ print(chunk, end="", flush=True)
138
+ ```
139
+
140
+ ### Point
141
+
142
+ The `point` skill identifies specific points (x, y coordinates) for objects in an image.
143
+
144
+ ```python
145
+ # Find points for specific objects
146
+ image = Image.open("crowd.jpg")
147
+ result = moondream.point(image, "person wearing a red shirt")
148
+
149
+ # Points are normalized coordinates (0-1)
150
+ for i, point in enumerate(result["points"]):
151
+ print(f"Point {i+1}: x={point['x']:.3f}, y={point['y']:.3f}")
152
+ ```
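+
+ Because the returned points are normalized to the 0-1 range, converting them to pixel coordinates is just a multiplication by the image dimensions. A small illustrative sketch, reusing the `image` and `result` from above:
+
+ ```python
+ # Scale normalized points to pixel coordinates (illustrative)
+ width, height = image.size
+ for point in result["points"]:
+     px, py = point["x"] * width, point["y"] * height
+     print(f"Pixel location: ({px:.0f}, {py:.0f})")
+ ```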
153
+
154
+ ### Detect
155
+
156
+ The `detect` skill provides bounding boxes for objects in an image.
157
+
158
+ ```python
159
+ # Detect objects with bounding boxes
160
+ image = Image.open("street_scene.jpg")
161
+ result = moondream.detect(image, "car")
162
+
163
+ # Bounding boxes are normalized coordinates (0-1)
164
+ for i, obj in enumerate(result["objects"]):
165
+ print(f"Object {i+1}: "
166
+ f"x_min={obj['x_min']:.3f}, y_min={obj['y_min']:.3f}, "
167
+ f"x_max={obj['x_max']:.3f}, y_max={obj['y_max']:.3f}")
168
+
169
+ # Control maximum number of objects
170
+ settings = {"max_objects": 10}
171
+ result = moondream.detect(image, "person", settings=settings)
172
+ ```
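+
+ To visualize the detections, the normalized boxes can be scaled back to pixel coordinates and drawn with Pillow. A minimal sketch, reusing the `image` and `result` from above (the output filename is arbitrary):
+
+ ```python
+ from PIL import ImageDraw
+
+ # Draw detected boxes on a copy of the image (illustrative)
+ annotated = image.copy()
+ draw = ImageDraw.Draw(annotated)
+ w, h = image.size
+ for obj in result["objects"]:
+     box = (obj["x_min"] * w, obj["y_min"] * h, obj["x_max"] * w, obj["y_max"] * h)
+     draw.rectangle(box, outline="red", width=3)
+ annotated.save("detections.png")
+ ```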
173
+
174
+ ### Caching image encodings (advanced)
175
+
176
+ If you're planning to run multiple inferences on the same image, you can pre-encode it once and reuse the encoding for better performance.
177
+
178
+ ```python
179
+ # Encode image once
180
+ image = Image.open("complex_scene.jpg")
181
+ encoded = moondream.encode_image(image)
182
+
183
+ # Reuse the encoding for multiple queries
184
+ questions = [
185
+ "How many people are in this image?",
186
+ "What time of day was this taken?",
187
+ "What's the weather like?"
188
+ ]
189
+
190
+ for q in questions:
191
+ result = moondream.query(image=encoded, question=q, reasoning=False)
192
+ print(f"Q: {q}")
193
+ print(f"A: {result['answer']}\n")
194
+
195
+ # Also works with other skills
196
+ caption = moondream.caption(encoded, length="normal")
197
+ objects = moondream.detect(encoded, "vehicle")
198
+ ```
199
+
200
+ ---
201
+
202
+ Copyright (c) 2025 M87 Labs, Inc.
203
+
204
+ This distribution includes Model Weights licensed under the [Business Source License 1.1 with an Additional Use Grant (No Third-Party Service)](https://huggingface.co/moondream/moondream3-preview/blob/main/LICENSE.md). Commercial hosting or rehosting requires an agreement with <contact@m87.ai>.
config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "architectures": [
3
+ "HfMoondream"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "hf_moondream.HfConfig",
7
+ "AutoModelForCausalLM": "hf_moondream.HfMoondream"
8
+ },
9
+ "config": {},
10
+ "model_type": "moondream1",
11
+ "torch_dtype": "bfloat16",
12
+ "transformers_version": "4.51.1"
13
+ }
config.py ADDED
@@ -0,0 +1,102 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class TextMoeConfig:
7
+ num_experts: int = 64
8
+ start_layer: int = 4
9
+ experts_per_token: int = 8
10
+ expert_inner_dim: int = 1024
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class TextConfig:
15
+ dim: int = 2048
16
+ ff_dim: int = 8192
17
+ n_layers: int = 24
18
+ vocab_size: int = 51200
19
+ max_context: int = 4096
20
+ n_heads: int = 32
21
+ n_kv_heads: int = 32
22
+ prefix_attn: int = 730
23
+ group_size: Optional[int] = None
24
+ moe: Optional[TextMoeConfig] = TextMoeConfig()
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class VisionConfig:
29
+ enc_dim: int = 1152
30
+ enc_patch_size: int = 14
31
+ enc_n_layers: int = 27
32
+ enc_ff_dim: int = 4304
33
+ enc_n_heads: int = 16
34
+ proj_out_dim: int = 2048
35
+ crop_size: int = 378
36
+ in_channels: int = 3
37
+ max_crops: int = 12
38
+ overlap_margin: int = 4
39
+ proj_inner_dim: int = 8192
40
+
41
+
42
+ @dataclass(frozen=True)
43
+ class RegionConfig:
44
+ dim: int = 2048
45
+ coord_feat_dim: int = 256
46
+ coord_out_dim: int = 1024
47
+ size_feat_dim: int = 512
48
+ size_out_dim: int = 2048
49
+ group_size: Optional[int] = None
50
+
51
+
52
+ @dataclass(frozen=True)
53
+ class TokenizerConfig:
54
+ bos_id: int = 0
55
+ eos_id: int = 0
56
+ answer_id: int = 3
57
+ thinking_id: int = 4
58
+ coord_id: int = 5
59
+ size_id: int = 6
60
+ start_ground_points_id: int = 7
61
+ end_ground_id: int = 9
62
+ templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
63
+ default_factory=lambda: {
64
+ "caption": {
65
+ "short": [1, 32708, 2, 12492, 3],
66
+ "normal": [1, 32708, 2, 6382, 3],
67
+ "long": [1, 32708, 2, 4059, 3],
68
+ },
69
+ "query": {"prefix": [1, 15381, 2], "suffix": [3]},
70
+ "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
71
+ "point": {"prefix": [1, 2581, 2], "suffix": [3]},
72
+ }
73
+ )
74
+
75
+
76
+ @dataclass(frozen=True)
77
+ class MoondreamConfig:
78
+ text: TextConfig = TextConfig()
79
+ vision: VisionConfig = VisionConfig()
80
+ region: RegionConfig = RegionConfig()
81
+ tokenizer: TokenizerConfig = TokenizerConfig()
82
+
83
+ @classmethod
84
+ def from_dict(cls, config_dict: dict):
85
+ text_config = TextConfig(**config_dict.get("text", {}))
86
+ vision_config = VisionConfig(**config_dict.get("vision", {}))
87
+ region_config = RegionConfig(**config_dict.get("region", {}))
88
+ tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
89
+ return cls(
90
+ text=text_config,
91
+ vision=vision_config,
92
+ region=region_config,
93
+ tokenizer=tokenizer_config,
94
+ )
95
+
96
+ def to_dict(self):
97
+ return {
98
+ "text": self.text.__dict__,
99
+ "vision": self.vision.__dict__,
100
+ "region": self.region.__dict__,
101
+ "tokenizer": self.tokenizer.__dict__,
102
+ }
hf_moondream.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import PreTrainedModel, PretrainedConfig
5
+ from typing import Union
6
+
7
+ from .config import MoondreamConfig
8
+ from .moondream import MoondreamModel
9
+
10
+ # Files sometimes don't get loaded without these...
11
+ from .image_crops import *
12
+ from .vision import *
13
+ from .text import *
14
+ from .region import *
15
+ from .utils import *
16
+
17
+
18
+ def extract_question(text):
19
+ prefix = "<image>\n\nQuestion: "
20
+ suffix = "\n\nAnswer:"
21
+
22
+ if text.startswith(prefix) and text.endswith(suffix):
23
+ return text[len(prefix) : -len(suffix)]
24
+ else:
25
+ return None
26
+
27
+
28
+ class HfConfig(PretrainedConfig):
29
+ _auto_class = "AutoConfig"
30
+ model_type = "moondream1"
31
+
32
+ def __init__(self, **kwargs):
33
+ super().__init__(**kwargs)
34
+ self.config = {}
35
+
36
+
37
+ class HfMoondream(PreTrainedModel):
38
+ _auto_class = "AutoModelForCausalLM"
39
+ config_class = HfConfig
40
+
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.model = MoondreamModel(
44
+ MoondreamConfig.from_dict(config.config), setup_caches=False
45
+ )
46
+ self._is_kv_cache_setup = False
47
+
48
+ def _setup_caches(self):
49
+ if not self._is_kv_cache_setup:
50
+ self.model._setup_caches()
51
+ self._is_kv_cache_setup = True
52
+
53
+ @property
54
+ def encode_image(self):
55
+ self._setup_caches()
56
+ return self.model.encode_image
57
+
58
+ @property
59
+ def query(self):
60
+ self._setup_caches()
61
+ return self.model.query
62
+
63
+ @property
64
+ def caption(self):
65
+ self._setup_caches()
66
+ return self.model.caption
67
+
68
+ @property
69
+ def detect(self):
70
+ self._setup_caches()
71
+ return self.model.detect
72
+
73
+ @property
74
+ def point(self):
75
+ self._setup_caches()
76
+ return self.model.point
77
+
78
+ @property
79
+ def detect_gaze(self):
80
+ self._setup_caches()
81
+ return self.model.detect_gaze
82
+
83
+ def answer_question(
84
+ self,
85
+ image_embeds,
86
+ question,
87
+ tokenizer=None,
88
+ chat_history="",
89
+ result_queue=None,
90
+ max_new_tokens=256,
91
+ **kwargs
92
+ ):
93
+ answer = self.query(image_embeds, question)["answer"].strip()
94
+
95
+ if result_queue is not None:
96
+ result_queue.put(answer)
97
+ return answer
98
+
99
+ def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
100
+ answers = []
101
+ for image, prompt in zip(images, prompts):
102
+ answers.append(self.query(image, prompt)["answer"].strip())
103
+ return answers
104
+
105
+ def _unsupported_exception(self):
106
+ raise NotImplementedError(
107
+ "This method is not supported in the latest version of moondream. "
108
+ "Consider upgrading to the updated API spec, or alternately pin "
109
+ "to 'revision=2024-08-26'."
110
+ )
111
+
112
+ def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
113
+ """
114
+ Function definition remains unchanged for backwards compatibility.
115
+ Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
116
+ """
117
+ prompt_extracted = extract_question(prompt)
118
+ if prompt_extracted is not None:
119
+ answer = self.model.query(
120
+ image=image_embeds, question=prompt_extracted, stream=False
121
+ )["answer"]
122
+ else:
123
+ image_embeds = self.encode_image(image_embeds)
124
+ prompt_tokens = torch.tensor(
125
+ [self.model.tokenizer.encode(prompt).ids],
126
+ device=self.device,
127
+ )
128
+
129
+ def generator():
130
+ for token in self.model._generate_answer(
131
+ prompt_tokens,
132
+ image_embeds.kv_cache,
133
+ image_embeds.pos,
134
+ max_new_tokens,
135
+ ):
136
+ yield token
137
+
138
+ answer = "".join(list(generator()))
139
+
140
+ return [answer]
141
+
142
+ def get_input_embeddings(self) -> nn.Embedding:
143
+ """
144
+ Lazily wrap the raw parameter `self.model.text.wte` in a real
145
+ `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
146
+ **shares** the weight tensor—no copy is made.
147
+ """
148
+ if not hasattr(self, "_input_embeddings"):
149
+ self._input_embeddings = nn.Embedding.from_pretrained(
150
+ self.model.text.wte, # tensor created in text.py
151
+ freeze=True, # set to False if you need it trainable
152
+ )
153
+ return self._input_embeddings
154
+
155
+ def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
156
+ """
157
+ Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
158
+ embeddings and keeps everything tied to `self.model.text.wte`.
159
+ """
160
+ # 1. point the low-level parameter to the new weight matrix
161
+ self.model.text.wte = value.weight
162
+ # 2. keep a reference for get_input_embeddings()
163
+ self._input_embeddings = value
164
+
165
+ def input_embeds(
166
+ self,
167
+ input_ids: Union[torch.LongTensor, list, tuple],
168
+ *,
169
+ device: torch.device | None = None
170
+ ) -> torch.FloatTensor:
171
+ """
172
+ Back-compat wrapper that turns token IDs into embeddings.
173
+
174
+ Example:
175
+ ids = torch.tensor([[1, 2, 3]])
176
+ embeds = model.input_embeds(ids) # (1, 3, hidden_dim)
177
+ """
178
+ if not torch.is_tensor(input_ids):
179
+ input_ids = torch.as_tensor(input_ids)
180
+ if device is not None:
181
+ input_ids = input_ids.to(device)
182
+
183
+ return self.get_input_embeddings()(input_ids)
image_crops.py ADDED
@@ -0,0 +1,231 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+ from typing import TypedDict
6
+
7
+ try:
8
+ import pyvips
9
+
10
+ HAS_VIPS = True
11
+ except ImportError:
12
+ from PIL import Image
13
+
14
+ HAS_VIPS = False
15
+
16
+
17
+ def select_tiling(
18
+ height: int, width: int, crop_size: int, max_crops: int
19
+ ) -> tuple[int, int]:
20
+ """
21
+ Determine the optimal number of tiles to cover an image with overlapping crops.
22
+ """
23
+ if height <= crop_size or width <= crop_size:
24
+ return (1, 1)
25
+
26
+ # Minimum required tiles in each dimension
27
+ min_h = math.ceil(height / crop_size)
28
+ min_w = math.ceil(width / crop_size)
29
+
30
+ # If minimum required tiles exceed max_crops, return proportional distribution
31
+ if min_h * min_w > max_crops:
32
+ ratio = math.sqrt(max_crops / (min_h * min_w))
33
+ return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
34
+
35
+ # Perfect aspect-ratio tiles that satisfy max_crops
36
+ h_tiles = math.floor(math.sqrt(max_crops * height / width))
37
+ w_tiles = math.floor(math.sqrt(max_crops * width / height))
38
+
39
+ # Ensure we meet minimum tile requirements
40
+ h_tiles = max(h_tiles, min_h)
41
+ w_tiles = max(w_tiles, min_w)
42
+
43
+ # If we exceeded max_crops, scale down the larger dimension
44
+ if h_tiles * w_tiles > max_crops:
45
+ if w_tiles > h_tiles:
46
+ w_tiles = math.floor(max_crops / h_tiles)
47
+ else:
48
+ h_tiles = math.floor(max_crops / w_tiles)
49
+
50
+ return (max(1, h_tiles), max(1, w_tiles))
51
+
52
+
53
+ class OverlapCropOutput(TypedDict):
54
+ crops: np.ndarray
55
+ tiling: tuple[int, int]
56
+
57
+
58
+ def overlap_crop_image(
59
+ image: np.ndarray,
60
+ overlap_margin: int,
61
+ max_crops: int,
62
+ base_size: tuple[int, int] = (378, 378),
63
+ patch_size: int = 14,
64
+ ) -> OverlapCropOutput:
65
+ """
66
+ Process an image using an overlap-and-resize cropping strategy with margin handling.
67
+
68
+ This function takes an input image and creates multiple overlapping crops with
69
+ consistent margins. It produces:
70
+ 1. A single global crop resized to base_size
71
+ 2. Multiple overlapping local crops that maintain high resolution details
72
+ 3. A patch ordering matrix that tracks correspondence between crops
73
+
74
+ The overlap strategy ensures:
75
+ - Smooth transitions between adjacent crops
76
+ - No loss of information at crop boundaries
77
+ - Proper handling of features that cross crop boundaries
78
+ - Consistent patch indexing across the full image
79
+
80
+ Args:
81
+ image (np.ndarray): Input image as numpy array with shape (H,W,C)
82
+ base_size (tuple[int,int]): Target size for crops, default (378,378)
83
+ patch_size (int): Size of patches in pixels, default 14
84
+ overlap_margin (int): Margin size in patch units, default 4
85
+ max_crops (int): Maximum number of crops allowed, default 12
86
+
87
+ Returns:
88
+ OverlapCropOutput: Dictionary containing:
89
+ - crops: A numpy array containing the global crop of the full image (index 0)
90
+ followed by the overlapping cropped regions (indices 1+)
91
+ - tiling: Tuple of (height,width) tile counts
92
+ """
93
+ original_h, original_w = image.shape[:2]
94
+
95
+ # Convert margin from patch units to pixels
96
+ margin_pixels = patch_size * overlap_margin
97
+ total_margin_pixels = margin_pixels * 2 # Both sides
98
+
99
+ # Calculate crop parameters
100
+ crop_patches = base_size[0] // patch_size # patches per crop dimension
101
+ crop_window_patches = crop_patches - (2 * overlap_margin) # usable patches
102
+ crop_window_size = crop_window_patches * patch_size # usable size in pixels
103
+
104
+ # Determine tiling
105
+ tiling = select_tiling(
106
+ original_h - total_margin_pixels,
107
+ original_w - total_margin_pixels,
108
+ crop_window_size,
109
+ max_crops,
110
+ )
111
+
112
+ # Pre-allocate crops.
113
+ n_crops = tiling[0] * tiling[1] + 1 # 1 = global crop
114
+ crops = np.zeros(
115
+ (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
116
+ )
117
+
118
+ # Resize image to fit tiling
119
+ target_size = (
120
+ tiling[0] * crop_window_size + total_margin_pixels,
121
+ tiling[1] * crop_window_size + total_margin_pixels,
122
+ )
123
+
124
+ if HAS_VIPS:
125
+ # Convert to vips for resizing
126
+ vips_image = pyvips.Image.new_from_array(image)
127
+ scale_x = target_size[1] / image.shape[1]
128
+ scale_y = target_size[0] / image.shape[0]
129
+ resized = vips_image.resize(scale_x, vscale=scale_y)
130
+ image = resized.numpy()
131
+
132
+ # Create global crop
133
+ scale_x = base_size[1] / vips_image.width
134
+ scale_y = base_size[0] / vips_image.height
135
+ global_vips = vips_image.resize(scale_x, vscale=scale_y)
136
+ crops[0] = global_vips.numpy()
137
+ else:
138
+ # Fallback to PIL
139
+ pil_img = Image.fromarray(image)
140
+ resized = pil_img.resize(
141
+ (int(target_size[1]), int(target_size[0])),
142
+ resample=Image.Resampling.LANCZOS,
143
+ )
144
+ image = np.asarray(resized)
145
+
146
+ # Create global crop
147
+ global_pil = pil_img.resize(
148
+ (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
149
+ )
150
+ crops[0] = np.asarray(global_pil)
151
+
152
+ for i in range(tiling[0]):
153
+ for j in range(tiling[1]):
154
+ # Calculate crop coordinates
155
+ y0 = i * crop_window_size
156
+ x0 = j * crop_window_size
157
+
158
+ # Extract crop with padding if needed
159
+ y_end = min(y0 + base_size[0], image.shape[0])
160
+ x_end = min(x0 + base_size[1], image.shape[1])
161
+
162
+ crop_region = image[y0:y_end, x0:x_end]
163
+ crops[
164
+ 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
165
+ ] = crop_region
166
+
167
+ return {"crops": crops, "tiling": tiling}
168
+
169
+
170
+ def reconstruct_from_crops(
171
+ crops: torch.Tensor,
172
+ tiling: tuple[int, int],
173
+ overlap_margin: int,
174
+ patch_size: int = 14,
175
+ ) -> torch.Tensor:
176
+ """
177
+ Reconstruct the original image from overlapping crops into a single seamless image.
178
+
179
+ Takes a list of overlapping image crops along with their positional metadata and
180
+ reconstructs them into a single coherent image by carefully stitching together
181
+ non-overlapping regions. Handles both numpy arrays and PyTorch tensors.
182
+
183
+ Args:
184
+ crops: List of image crops as numpy arrays or PyTorch tensors with shape
185
+ (H,W,C)
186
+ tiling: Tuple of (height,width) indicating crop grid layout
187
+ patch_size: Size in pixels of each patch, default 14
188
+ overlap_margin: Number of overlapping patches on each edge, default 4
189
+
190
+ Returns:
191
+ Reconstructed image as numpy array or PyTorch tensor matching input type,
192
+ with shape (H,W,C) where H,W are the original image dimensions
193
+ """
194
+ tiling_h, tiling_w = tiling
195
+ crop_height, crop_width = crops[0].shape[:2]
196
+ margin_pixels = overlap_margin * patch_size
197
+
198
+ # Calculate output size (only adding margins once)
199
+ output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
200
+ output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
201
+
202
+ reconstructed = torch.zeros(
203
+ (output_h, output_w, crops[0].shape[2]),
204
+ device=crops[0].device,
205
+ dtype=crops[0].dtype,
206
+ )
207
+
208
+ for i, crop in enumerate(crops):
209
+ tile_y = i // tiling_w
210
+ tile_x = i % tiling_w
211
+
212
+ # For each tile, determine which part to keep
213
+ # Keep left margin only for first column
214
+ x_start = 0 if tile_x == 0 else margin_pixels
215
+ # Keep right margin only for last column
216
+ x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
217
+ # Keep top margin only for first row
218
+ y_start = 0 if tile_y == 0 else margin_pixels
219
+ # Keep bottom margin only for last row
220
+ y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
221
+
222
+ # Calculate where this piece belongs in the output
223
+ out_x = tile_x * (crop_width - 2 * margin_pixels)
224
+ out_y = tile_y * (crop_height - 2 * margin_pixels)
225
+
226
+ # Place the piece
227
+ reconstructed[
228
+ out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
229
+ ] = crop[y_start:y_end, x_start:x_end]
230
+
231
+ return reconstructed
layers.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Literal, Optional
7
+
8
+ try:
9
+ from torchao import quantize_
10
+ from torchao.quantization import int4_weight_only
11
+ except ImportError:
12
+
13
+ def quantize_(model, quant_mode):
14
+ raise ImportError(
15
+ "torchao is not installed. Please install it with `pip install torchao`."
16
+ )
17
+
18
+ def int4_weight_only(group_size):
19
+ raise ImportError(
20
+ "torchao is not installed. Please install it with `pip install torchao`."
21
+ )
22
+
23
+
24
+ def gelu_approx(x):
25
+ return F.gelu(x, approximate="tanh")
26
+
27
+
28
+ @dataclass
29
+ class LinearWeights:
30
+ weight: torch.Tensor
31
+ bias: torch.Tensor
32
+
33
+
34
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
35
+ return F.linear(x, w.weight, w.bias)
36
+
37
+
38
+ def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
39
+ _step = W_q.shape[0]
40
+ W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
41
+ W_r[:_step] = (W_q & 0b11110000) >> 4
42
+ W_r[_step:] = W_q & 0b00001111
43
+ W_r.sub_(zero).mul_(scale)
44
+ return W_r.reshape(orig_shape)
45
+
46
+
47
+ class QuantizedLinear(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features: int,
51
+ out_features: int,
52
+ dtype: torch.dtype,
53
+ ):
54
+ # TODO: Take group_size as an input instead of hardcoding it here.
55
+ super().__init__()
56
+ self.in_features = in_features
57
+ self.out_features = out_features
58
+ self.weight = nn.ParameterDict(
59
+ {
60
+ "packed": nn.Parameter(
61
+ torch.empty(
62
+ out_features * in_features // (128 * 2), 128, dtype=torch.uint8
63
+ ),
64
+ requires_grad=False,
65
+ ),
66
+ "scale": nn.Parameter(
67
+ torch.empty(out_features * in_features // 128, 1),
68
+ requires_grad=False,
69
+ ),
70
+ "zero_point": nn.Parameter(
71
+ torch.empty(out_features * in_features // 128, 1),
72
+ requires_grad=False,
73
+ ),
74
+ }
75
+ )
76
+ self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
77
+ self.unpacked = False
78
+
79
+ def unpack(self):
80
+ if self.unpacked:
81
+ return
82
+
83
+ self.weight = nn.Parameter(
84
+ dequantize_tensor(
85
+ self.weight["packed"],
86
+ self.weight["scale"],
87
+ self.weight["zero_point"],
88
+ (self.out_features, self.in_features),
89
+ torch.bfloat16,
90
+ )
91
+ )
92
+ with torch.device("meta"):
93
+ self.linear = nn.Linear(
94
+ self.in_features, self.out_features, dtype=torch.bfloat16
95
+ )
96
+ self.linear.weight = self.weight
97
+ self.linear.bias = nn.Parameter(
98
+ self.bias.to(torch.bfloat16), requires_grad=False
99
+ )
100
+
101
+ del self.weight, self.bias
102
+ quantize_(self, int4_weight_only(group_size=128))
103
+ self.unpacked = True
104
+ torch.cuda.empty_cache()
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ if not self.unpacked:
108
+ self.unpack()
109
+ return self.linear(x)
110
+
111
+
112
+ @dataclass
113
+ class LayerNormWeights:
114
+ weight: torch.Tensor
115
+ bias: torch.Tensor
116
+
117
+
118
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
119
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
120
+
121
+
122
+ @dataclass
123
+ class MLPWeights:
124
+ fc1: LinearWeights
125
+ fc2: LinearWeights
126
+ act: Literal["gelu_approx"] = "gelu_approx"
127
+
128
+
129
+ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
130
+ x0 = w.fc1(x)
131
+ if lora is not None:
132
+ x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
133
+ x = x0 + x1
134
+ else:
135
+ x = x0
136
+
137
+ x = gelu_approx(x)
138
+
139
+ x0 = w.fc2(x)
140
+ if lora is not None:
141
+ x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
142
+ x = x0 + x1
143
+ else:
144
+ x = x0
145
+
146
+ return x
147
+
148
+
149
+ def moe_mlp(
150
+ x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
151
+ ) -> torch.Tensor:
152
+ B, T, C = x.shape
153
+ x = x.reshape(-1, C)
154
+
155
+ # Router computation
156
+ router_logits = mlp_module.router(x)
157
+ topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
158
+ topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
159
+ num_tokens, top_k = topk_idxs.shape
160
+
161
+ if T == 1:
162
+ w1_weight = mlp_module.fc1.weight
163
+ w2_weight = mlp_module.fc2.weight
164
+
165
+ # Flatten to process all token-expert pairs at once
166
+ flat_idxs = topk_idxs.view(-1) # [T*A]
167
+ flat_weights = topk_weights.view(-1) # [T*A]
168
+
169
+ # Select expert weights
170
+ w1_selected = w1_weight[flat_idxs] # [T*A, H, D]
171
+ w2_selected = w2_weight[flat_idxs] # [T*A, D, H]
172
+
173
+ # Expand input for all token-expert pairs
174
+ x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
175
+
176
+ # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
177
+ x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
178
+ -1
179
+ ) # [T*A, H]
180
+ x1, g = x1_full.chunk(2, dim=-1)
181
+ x1 = F.gelu(x1) * (g + 1)
182
+
183
+ # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
184
+ expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
185
+
186
+ # Apply weights and reshape
187
+ weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
188
+ weighted_outs = weighted_outs.view(num_tokens, top_k, C) # [T, A, D]
189
+
190
+ # Sum over experts
191
+ mlp_out = weighted_outs.sum(dim=1) # [T, D]
192
+ mlp_out = mlp_out.view(B, T, C)
193
+
194
+ return mlp_out
195
+ else:
196
+ out = x.new_zeros(x.size())
197
+
198
+ for expert_id in range(mlp_module.fc1.weight.shape[0]):
199
+ token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
200
+ if token_pos.numel() == 0:
201
+ continue
202
+
203
+ x_tok = x.index_select(0, token_pos)
204
+ gate_tok = topk_weights[token_pos, which_k]
205
+
206
+ h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
207
+ h, g = h_full.chunk(2, dim=-1)
208
+ h = F.gelu(h) * (g + 1)
209
+ y = F.linear(h, mlp_module.fc2.weight[expert_id])
210
+
211
+ y.mul_(gate_tok.unsqueeze(-1))
212
+ out.index_add_(0, token_pos, y)
213
+
214
+ return out.view(B, T, C)
215
+
216
+
217
+ @dataclass
218
+ class AttentionWeights:
219
+ qkv: LinearWeights
220
+ proj: LinearWeights
221
+
222
+
223
+ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
224
+ bsz, q_len, d_model = x.shape
225
+ head_dim = d_model // n_heads
226
+
227
+ q, k, v = [
228
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
229
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
230
+ ]
231
+ out = F.scaled_dot_product_attention(q, k, v)
232
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
233
+ out = linear(out, w.proj)
234
+ return out
lora.py ADDED
@@ -0,0 +1,82 @@
1
+ import functools
2
+ import os
3
+ import shutil
4
+ import torch
5
+
6
+ from pathlib import Path
7
+ from urllib.request import Request, urlopen
8
+ from typing import Optional
9
+
10
+
11
+ def variant_cache_dir():
12
+ hf_hub_cache = os.environ.get("HF_HUB_CACHE")
13
+ if hf_hub_cache is not None:
14
+ return Path(hf_hub_cache) / "md_variants"
15
+
16
+ hf_home = os.environ.get("HF_HOME")
17
+ if hf_home is not None:
18
+ return Path(hf_home) / "hub" / "md_variants"
19
+
20
+ return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
21
+
22
+
23
+ def cached_variant_path(variant_id: str):
24
+ variant, *rest = variant_id.split("/", 1)
25
+ step = rest[0] if rest else "final"
26
+
27
+ cache_dir = variant_cache_dir() / variant
28
+ os.makedirs(cache_dir, exist_ok=True)
29
+ dest = cache_dir / f"{step}.pt"
30
+ if dest.exists():
31
+ return dest
32
+
33
+ md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
34
+
35
+ headers = {"User-Agent": "moondream-torch"}
36
+ api_key = os.getenv("MOONDREAM_API_KEY")
37
+ if api_key is not None:
38
+ headers["X-Moondream-Auth"] = api_key
39
+
40
+ req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
41
+ with urlopen(req) as r, open(dest, "wb") as f:
42
+ shutil.copyfileobj(r, f)
43
+ return dest
44
+
45
+
46
+ def nest(flat):
47
+ tree = {}
48
+ for k, v in flat.items():
49
+ parts = k.split(".")
50
+ d = tree
51
+ for p in parts[:-1]:
52
+ d = d.setdefault(p, {})
53
+ d[parts[-1]] = v
54
+ return tree
55
+
56
+
57
+ @functools.lru_cache(maxsize=5)
58
+ def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
59
+ if variant_id is None:
60
+ return None
61
+
62
+ state_dict = torch.load(
63
+ cached_variant_path(variant_id), map_location=device, weights_only=True
64
+ )
65
+
66
+ # TODO: Move these into the training code that saves checkpoints...
67
+ rename_rules = [
68
+ ("text_model.transformer.h", "text.blocks"),
69
+ (".mixer", ".attn"),
70
+ (".out_proj", ".proj"),
71
+ (".Wqkv", ".qkv"),
72
+ (".parametrizations.weight.0", ""),
73
+ ]
74
+ new_state_dict = {}
75
+ for key, tensor in state_dict.items():
76
+ new_key = key
77
+ for old, new in rename_rules:
78
+ if old in new_key:
79
+ new_key = new_key.replace(old, new)
80
+ new_state_dict[new_key] = tensor
81
+
82
+ return nest(new_state_dict)
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd4b3d0d6daae9c4212056cd64f02f408ff083bbb0244114eecd05fcba30037e
3
+ size 4907406296
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cf6d17391db58801b61173510ba629875679dbcbe4bfd3cb38ac0958b3c70a0
3
+ size 4736548872
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4391c6d6b46ed49aa00afddf1f7df9dd0845cbc681fdaf424e727b01ea2d3e4
3
+ size 4502742464
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6af14858bdd7cdea5d19d786726e48434b02a3c0c52a771a0f25b6a8ca640187
3
+ size 4390620392
model.safetensors.index.json ADDED
@@ -0,0 +1,667 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18537245664
4
+ },
5
+ "weight_map": {
6
+ "model.region.coord_decoder.bias": "model-00004-of-00004.safetensors",
7
+ "model.region.coord_decoder.weight": "model-00004-of-00004.safetensors",
8
+ "model.region.coord_encoder.bias": "model-00004-of-00004.safetensors",
9
+ "model.region.coord_encoder.weight": "model-00004-of-00004.safetensors",
10
+ "model.region.coord_features": "model-00004-of-00004.safetensors",
11
+ "model.region.size_decoder.bias": "model-00004-of-00004.safetensors",
12
+ "model.region.size_decoder.weight": "model-00004-of-00004.safetensors",
13
+ "model.region.size_encoder.bias": "model-00004-of-00004.safetensors",
14
+ "model.region.size_encoder.weight": "model-00004-of-00004.safetensors",
15
+ "model.region.size_features": "model-00004-of-00004.safetensors",
16
+ "model.text.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
17
+ "model.text.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.text.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
19
+ "model.text.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
20
+ "model.text.blocks.0.attn.tau.alpha": "model-00001-of-00004.safetensors",
21
+ "model.text.blocks.0.attn.tau.wq": "model-00001-of-00004.safetensors",
22
+ "model.text.blocks.0.attn.tau.wv": "model-00001-of-00004.safetensors",
23
+ "model.text.blocks.0.ln.bias": "model-00001-of-00004.safetensors",
24
+ "model.text.blocks.0.ln.weight": "model-00001-of-00004.safetensors",
25
+ "model.text.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
26
+ "model.text.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
27
+ "model.text.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
28
+ "model.text.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
29
+ "model.text.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
30
+ "model.text.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.text.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
32
+ "model.text.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
33
+ "model.text.blocks.1.attn.tau.alpha": "model-00001-of-00004.safetensors",
34
+ "model.text.blocks.1.attn.tau.wq": "model-00001-of-00004.safetensors",
35
+ "model.text.blocks.1.attn.tau.wv": "model-00001-of-00004.safetensors",
36
+ "model.text.blocks.1.ln.bias": "model-00001-of-00004.safetensors",
37
+ "model.text.blocks.1.ln.weight": "model-00001-of-00004.safetensors",
38
+ "model.text.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
39
+ "model.text.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
40
+ "model.text.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
41
+ "model.text.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
42
+ "model.text.blocks.10.attn.proj.bias": "model-00002-of-00004.safetensors",
43
+ "model.text.blocks.10.attn.proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.text.blocks.10.attn.qkv.bias": "model-00002-of-00004.safetensors",
45
+ "model.text.blocks.10.attn.qkv.weight": "model-00002-of-00004.safetensors",
46
+ "model.text.blocks.10.attn.tau.alpha": "model-00002-of-00004.safetensors",
47
+ "model.text.blocks.10.attn.tau.wq": "model-00002-of-00004.safetensors",
48
+ "model.text.blocks.10.attn.tau.wv": "model-00002-of-00004.safetensors",
49
+ "model.text.blocks.10.ln.bias": "model-00002-of-00004.safetensors",
50
+ "model.text.blocks.10.ln.weight": "model-00002-of-00004.safetensors",
51
+ "model.text.blocks.10.mlp.fc1.weight": "model-00002-of-00004.safetensors",
52
+ "model.text.blocks.10.mlp.fc2.weight": "model-00002-of-00004.safetensors",
53
+ "model.text.blocks.10.mlp.router.bias": "model-00002-of-00004.safetensors",
54
+ "model.text.blocks.10.mlp.router.weight": "model-00002-of-00004.safetensors",
55
+ "model.text.blocks.11.attn.proj.bias": "model-00002-of-00004.safetensors",
56
+ "model.text.blocks.11.attn.proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.text.blocks.11.attn.qkv.bias": "model-00002-of-00004.safetensors",
58
+ "model.text.blocks.11.attn.qkv.weight": "model-00002-of-00004.safetensors",
59
+ "model.text.blocks.11.attn.tau.alpha": "model-00002-of-00004.safetensors",
60
+ "model.text.blocks.11.attn.tau.wq": "model-00002-of-00004.safetensors",
61
+ "model.text.blocks.11.attn.tau.wv": "model-00002-of-00004.safetensors",
62
+ "model.text.blocks.11.ln.bias": "model-00002-of-00004.safetensors",
63
+ "model.text.blocks.11.ln.weight": "model-00002-of-00004.safetensors",
64
+ "model.text.blocks.11.mlp.fc1.weight": "model-00002-of-00004.safetensors",
65
+ "model.text.blocks.11.mlp.fc2.weight": "model-00002-of-00004.safetensors",
66
+ "model.text.blocks.11.mlp.router.bias": "model-00002-of-00004.safetensors",
67
+ "model.text.blocks.11.mlp.router.weight": "model-00002-of-00004.safetensors",
68
+ "model.text.blocks.12.attn.proj.bias": "model-00002-of-00004.safetensors",
69
+ "model.text.blocks.12.attn.proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.text.blocks.12.attn.qkv.bias": "model-00002-of-00004.safetensors",
71
+ "model.text.blocks.12.attn.qkv.weight": "model-00002-of-00004.safetensors",
72
+ "model.text.blocks.12.attn.tau.alpha": "model-00002-of-00004.safetensors",
73
+ "model.text.blocks.12.attn.tau.wq": "model-00002-of-00004.safetensors",
74
+ "model.text.blocks.12.attn.tau.wv": "model-00002-of-00004.safetensors",
75
+ "model.text.blocks.12.ln.bias": "model-00002-of-00004.safetensors",
76
+ "model.text.blocks.12.ln.weight": "model-00002-of-00004.safetensors",
77
+ "model.text.blocks.12.mlp.fc1.weight": "model-00002-of-00004.safetensors",
78
+ "model.text.blocks.12.mlp.fc2.weight": "model-00002-of-00004.safetensors",
79
+ "model.text.blocks.12.mlp.router.bias": "model-00002-of-00004.safetensors",
80
+ "model.text.blocks.12.mlp.router.weight": "model-00002-of-00004.safetensors",
81
+ "model.text.blocks.13.attn.proj.bias": "model-00002-of-00004.safetensors",
82
+ "model.text.blocks.13.attn.proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.text.blocks.13.attn.qkv.bias": "model-00002-of-00004.safetensors",
84
+ "model.text.blocks.13.attn.qkv.weight": "model-00002-of-00004.safetensors",
85
+ "model.text.blocks.13.attn.tau.alpha": "model-00002-of-00004.safetensors",
86
+ "model.text.blocks.13.attn.tau.wq": "model-00002-of-00004.safetensors",
87
+ "model.text.blocks.13.attn.tau.wv": "model-00002-of-00004.safetensors",
88
+ "model.text.blocks.13.ln.bias": "model-00002-of-00004.safetensors",
89
+ "model.text.blocks.13.ln.weight": "model-00002-of-00004.safetensors",
90
+ "model.text.blocks.13.mlp.fc1.weight": "model-00002-of-00004.safetensors",
91
+ "model.text.blocks.13.mlp.fc2.weight": "model-00003-of-00004.safetensors",
92
+ "model.text.blocks.13.mlp.router.bias": "model-00002-of-00004.safetensors",
93
+ "model.text.blocks.13.mlp.router.weight": "model-00002-of-00004.safetensors",
94
+ "model.text.blocks.14.attn.proj.bias": "model-00003-of-00004.safetensors",
95
+ "model.text.blocks.14.attn.proj.weight": "model-00003-of-00004.safetensors",
96
+ "model.text.blocks.14.attn.qkv.bias": "model-00003-of-00004.safetensors",
97
+ "model.text.blocks.14.attn.qkv.weight": "model-00003-of-00004.safetensors",
98
+ "model.text.blocks.14.attn.tau.alpha": "model-00003-of-00004.safetensors",
99
+ "model.text.blocks.14.attn.tau.wq": "model-00003-of-00004.safetensors",
100
+ "model.text.blocks.14.attn.tau.wv": "model-00003-of-00004.safetensors",
101
+ "model.text.blocks.14.ln.bias": "model-00003-of-00004.safetensors",
102
+ "model.text.blocks.14.ln.weight": "model-00003-of-00004.safetensors",
103
+ "model.text.blocks.14.mlp.fc1.weight": "model-00003-of-00004.safetensors",
104
+ "model.text.blocks.14.mlp.fc2.weight": "model-00003-of-00004.safetensors",
105
+ "model.text.blocks.14.mlp.router.bias": "model-00003-of-00004.safetensors",
106
+ "model.text.blocks.14.mlp.router.weight": "model-00003-of-00004.safetensors",
107
+ "model.text.blocks.15.attn.proj.bias": "model-00003-of-00004.safetensors",
108
+ "model.text.blocks.15.attn.proj.weight": "model-00003-of-00004.safetensors",
109
+ "model.text.blocks.15.attn.qkv.bias": "model-00003-of-00004.safetensors",
110
+ "model.text.blocks.15.attn.qkv.weight": "model-00003-of-00004.safetensors",
111
+ "model.text.blocks.15.attn.tau.alpha": "model-00003-of-00004.safetensors",
112
+ "model.text.blocks.15.attn.tau.wq": "model-00003-of-00004.safetensors",
113
+ "model.text.blocks.15.attn.tau.wv": "model-00003-of-00004.safetensors",
114
+ "model.text.blocks.15.ln.bias": "model-00003-of-00004.safetensors",
115
+ "model.text.blocks.15.ln.weight": "model-00003-of-00004.safetensors",
116
+ "model.text.blocks.15.mlp.fc1.weight": "model-00003-of-00004.safetensors",
117
+ "model.text.blocks.15.mlp.fc2.weight": "model-00003-of-00004.safetensors",
118
+ "model.text.blocks.15.mlp.router.bias": "model-00003-of-00004.safetensors",
119
+ "model.text.blocks.15.mlp.router.weight": "model-00003-of-00004.safetensors",
120
+ "model.text.blocks.16.attn.proj.bias": "model-00003-of-00004.safetensors",
121
+ "model.text.blocks.16.attn.proj.weight": "model-00003-of-00004.safetensors",
122
+ "model.text.blocks.16.attn.qkv.bias": "model-00003-of-00004.safetensors",
123
+ "model.text.blocks.16.attn.qkv.weight": "model-00003-of-00004.safetensors",
124
+ "model.text.blocks.16.attn.tau.alpha": "model-00003-of-00004.safetensors",
125
+ "model.text.blocks.16.attn.tau.wq": "model-00003-of-00004.safetensors",
126
+ "model.text.blocks.16.attn.tau.wv": "model-00003-of-00004.safetensors",
127
+ "model.text.blocks.16.ln.bias": "model-00003-of-00004.safetensors",
128
+ "model.text.blocks.16.ln.weight": "model-00003-of-00004.safetensors",
129
+ "model.text.blocks.16.mlp.fc1.weight": "model-00003-of-00004.safetensors",
130
+ "model.text.blocks.16.mlp.fc2.weight": "model-00003-of-00004.safetensors",
131
+ "model.text.blocks.16.mlp.router.bias": "model-00003-of-00004.safetensors",
132
+ "model.text.blocks.16.mlp.router.weight": "model-00003-of-00004.safetensors",
133
+ "model.text.blocks.17.attn.proj.bias": "model-00003-of-00004.safetensors",
134
+ "model.text.blocks.17.attn.proj.weight": "model-00003-of-00004.safetensors",
135
+ "model.text.blocks.17.attn.qkv.bias": "model-00003-of-00004.safetensors",
136
+ "model.text.blocks.17.attn.qkv.weight": "model-00003-of-00004.safetensors",
137
+ "model.text.blocks.17.attn.tau.alpha": "model-00003-of-00004.safetensors",
138
+ "model.text.blocks.17.attn.tau.wq": "model-00003-of-00004.safetensors",
139
+ "model.text.blocks.17.attn.tau.wv": "model-00003-of-00004.safetensors",
140
+ "model.text.blocks.17.ln.bias": "model-00003-of-00004.safetensors",
141
+ "model.text.blocks.17.ln.weight": "model-00003-of-00004.safetensors",
142
+ "model.text.blocks.17.mlp.fc1.weight": "model-00003-of-00004.safetensors",
143
+ "model.text.blocks.17.mlp.fc2.weight": "model-00003-of-00004.safetensors",
144
+ "model.text.blocks.17.mlp.router.bias": "model-00003-of-00004.safetensors",
145
+ "model.text.blocks.17.mlp.router.weight": "model-00003-of-00004.safetensors",
146
+ "model.text.blocks.18.attn.proj.bias": "model-00003-of-00004.safetensors",
147
+ "model.text.blocks.18.attn.proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.text.blocks.18.attn.qkv.bias": "model-00003-of-00004.safetensors",
149
+ "model.text.blocks.18.attn.qkv.weight": "model-00003-of-00004.safetensors",
150
+ "model.text.blocks.18.attn.tau.alpha": "model-00003-of-00004.safetensors",
151
+ "model.text.blocks.18.attn.tau.wq": "model-00003-of-00004.safetensors",
152
+ "model.text.blocks.18.attn.tau.wv": "model-00003-of-00004.safetensors",
153
+ "model.text.blocks.18.ln.bias": "model-00003-of-00004.safetensors",
154
+ "model.text.blocks.18.ln.weight": "model-00003-of-00004.safetensors",
155
+ "model.text.blocks.18.mlp.fc1.weight": "model-00003-of-00004.safetensors",
156
+ "model.text.blocks.18.mlp.fc2.weight": "model-00003-of-00004.safetensors",
157
+ "model.text.blocks.18.mlp.router.bias": "model-00003-of-00004.safetensors",
158
+ "model.text.blocks.18.mlp.router.weight": "model-00003-of-00004.safetensors",
159
+ "model.text.blocks.19.attn.proj.bias": "model-00003-of-00004.safetensors",
160
+ "model.text.blocks.19.attn.proj.weight": "model-00003-of-00004.safetensors",
161
+ "model.text.blocks.19.attn.qkv.bias": "model-00003-of-00004.safetensors",
162
+ "model.text.blocks.19.attn.qkv.weight": "model-00003-of-00004.safetensors",
163
+ "model.text.blocks.19.attn.tau.alpha": "model-00003-of-00004.safetensors",
164
+ "model.text.blocks.19.attn.tau.wq": "model-00003-of-00004.safetensors",
165
+ "model.text.blocks.19.attn.tau.wv": "model-00003-of-00004.safetensors",
166
+ "model.text.blocks.19.ln.bias": "model-00003-of-00004.safetensors",
167
+ "model.text.blocks.19.ln.weight": "model-00003-of-00004.safetensors",
168
+ "model.text.blocks.19.mlp.fc1.weight": "model-00004-of-00004.safetensors",
169
+ "model.text.blocks.19.mlp.fc2.weight": "model-00004-of-00004.safetensors",
170
+ "model.text.blocks.19.mlp.router.bias": "model-00003-of-00004.safetensors",
171
+ "model.text.blocks.19.mlp.router.weight": "model-00003-of-00004.safetensors",
172
+ "model.text.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
173
+ "model.text.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
174
+ "model.text.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
175
+ "model.text.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
176
+ "model.text.blocks.2.attn.tau.alpha": "model-00001-of-00004.safetensors",
177
+ "model.text.blocks.2.attn.tau.wq": "model-00001-of-00004.safetensors",
178
+ "model.text.blocks.2.attn.tau.wv": "model-00001-of-00004.safetensors",
179
+ "model.text.blocks.2.ln.bias": "model-00001-of-00004.safetensors",
180
+ "model.text.blocks.2.ln.weight": "model-00001-of-00004.safetensors",
181
+ "model.text.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
182
+ "model.text.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
183
+ "model.text.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
184
+ "model.text.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
185
+ "model.text.blocks.20.attn.proj.bias": "model-00004-of-00004.safetensors",
186
+ "model.text.blocks.20.attn.proj.weight": "model-00004-of-00004.safetensors",
187
+ "model.text.blocks.20.attn.qkv.bias": "model-00004-of-00004.safetensors",
188
+ "model.text.blocks.20.attn.qkv.weight": "model-00004-of-00004.safetensors",
189
+ "model.text.blocks.20.attn.tau.alpha": "model-00004-of-00004.safetensors",
190
+ "model.text.blocks.20.attn.tau.wq": "model-00004-of-00004.safetensors",
191
+ "model.text.blocks.20.attn.tau.wv": "model-00004-of-00004.safetensors",
192
+ "model.text.blocks.20.ln.bias": "model-00004-of-00004.safetensors",
193
+ "model.text.blocks.20.ln.weight": "model-00004-of-00004.safetensors",
194
+ "model.text.blocks.20.mlp.fc1.weight": "model-00004-of-00004.safetensors",
195
+ "model.text.blocks.20.mlp.fc2.weight": "model-00004-of-00004.safetensors",
196
+ "model.text.blocks.20.mlp.router.bias": "model-00004-of-00004.safetensors",
197
+ "model.text.blocks.20.mlp.router.weight": "model-00004-of-00004.safetensors",
198
+ "model.text.blocks.21.attn.proj.bias": "model-00004-of-00004.safetensors",
199
+ "model.text.blocks.21.attn.proj.weight": "model-00004-of-00004.safetensors",
200
+ "model.text.blocks.21.attn.qkv.bias": "model-00004-of-00004.safetensors",
201
+ "model.text.blocks.21.attn.qkv.weight": "model-00004-of-00004.safetensors",
202
+ "model.text.blocks.21.attn.tau.alpha": "model-00004-of-00004.safetensors",
203
+ "model.text.blocks.21.attn.tau.wq": "model-00004-of-00004.safetensors",
204
+ "model.text.blocks.21.attn.tau.wv": "model-00004-of-00004.safetensors",
205
+ "model.text.blocks.21.ln.bias": "model-00004-of-00004.safetensors",
206
+ "model.text.blocks.21.ln.weight": "model-00004-of-00004.safetensors",
207
+ "model.text.blocks.21.mlp.fc1.weight": "model-00004-of-00004.safetensors",
208
+ "model.text.blocks.21.mlp.fc2.weight": "model-00004-of-00004.safetensors",
209
+ "model.text.blocks.21.mlp.router.bias": "model-00004-of-00004.safetensors",
210
+ "model.text.blocks.21.mlp.router.weight": "model-00004-of-00004.safetensors",
211
+ "model.text.blocks.22.attn.proj.bias": "model-00004-of-00004.safetensors",
212
+ "model.text.blocks.22.attn.proj.weight": "model-00004-of-00004.safetensors",
213
+ "model.text.blocks.22.attn.qkv.bias": "model-00004-of-00004.safetensors",
214
+ "model.text.blocks.22.attn.qkv.weight": "model-00004-of-00004.safetensors",
215
+ "model.text.blocks.22.attn.tau.alpha": "model-00004-of-00004.safetensors",
216
+ "model.text.blocks.22.attn.tau.wq": "model-00004-of-00004.safetensors",
217
+ "model.text.blocks.22.attn.tau.wv": "model-00004-of-00004.safetensors",
218
+ "model.text.blocks.22.ln.bias": "model-00004-of-00004.safetensors",
219
+ "model.text.blocks.22.ln.weight": "model-00004-of-00004.safetensors",
220
+ "model.text.blocks.22.mlp.fc1.weight": "model-00004-of-00004.safetensors",
221
+ "model.text.blocks.22.mlp.fc2.weight": "model-00004-of-00004.safetensors",
222
+ "model.text.blocks.22.mlp.router.bias": "model-00004-of-00004.safetensors",
223
+ "model.text.blocks.22.mlp.router.weight": "model-00004-of-00004.safetensors",
224
+ "model.text.blocks.23.attn.proj.bias": "model-00004-of-00004.safetensors",
225
+ "model.text.blocks.23.attn.proj.weight": "model-00004-of-00004.safetensors",
226
+ "model.text.blocks.23.attn.qkv.bias": "model-00004-of-00004.safetensors",
227
+ "model.text.blocks.23.attn.qkv.weight": "model-00004-of-00004.safetensors",
228
+ "model.text.blocks.23.attn.tau.alpha": "model-00004-of-00004.safetensors",
229
+ "model.text.blocks.23.attn.tau.wq": "model-00004-of-00004.safetensors",
230
+ "model.text.blocks.23.attn.tau.wv": "model-00004-of-00004.safetensors",
231
+ "model.text.blocks.23.ln.bias": "model-00004-of-00004.safetensors",
232
+ "model.text.blocks.23.ln.weight": "model-00004-of-00004.safetensors",
233
+ "model.text.blocks.23.mlp.fc1.weight": "model-00004-of-00004.safetensors",
234
+ "model.text.blocks.23.mlp.fc2.weight": "model-00004-of-00004.safetensors",
235
+ "model.text.blocks.23.mlp.router.bias": "model-00004-of-00004.safetensors",
236
+ "model.text.blocks.23.mlp.router.weight": "model-00004-of-00004.safetensors",
237
+ "model.text.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
238
+ "model.text.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
239
+ "model.text.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
240
+ "model.text.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
241
+ "model.text.blocks.3.attn.tau.alpha": "model-00001-of-00004.safetensors",
242
+ "model.text.blocks.3.attn.tau.wq": "model-00001-of-00004.safetensors",
243
+ "model.text.blocks.3.attn.tau.wv": "model-00001-of-00004.safetensors",
244
+ "model.text.blocks.3.ln.bias": "model-00001-of-00004.safetensors",
245
+ "model.text.blocks.3.ln.weight": "model-00001-of-00004.safetensors",
246
+ "model.text.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
247
+ "model.text.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
248
+ "model.text.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
249
+ "model.text.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
250
+ "model.text.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
251
+ "model.text.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
252
+ "model.text.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
253
+ "model.text.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
254
+ "model.text.blocks.4.attn.tau.alpha": "model-00001-of-00004.safetensors",
255
+ "model.text.blocks.4.attn.tau.wq": "model-00001-of-00004.safetensors",
256
+ "model.text.blocks.4.attn.tau.wv": "model-00001-of-00004.safetensors",
257
+ "model.text.blocks.4.ln.bias": "model-00001-of-00004.safetensors",
258
+ "model.text.blocks.4.ln.weight": "model-00001-of-00004.safetensors",
259
+ "model.text.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
260
+ "model.text.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
261
+ "model.text.blocks.4.mlp.router.bias": "model-00001-of-00004.safetensors",
262
+ "model.text.blocks.4.mlp.router.weight": "model-00001-of-00004.safetensors",
263
+ "model.text.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
264
+ "model.text.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
265
+ "model.text.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
266
+ "model.text.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
267
+ "model.text.blocks.5.attn.tau.alpha": "model-00001-of-00004.safetensors",
268
+ "model.text.blocks.5.attn.tau.wq": "model-00001-of-00004.safetensors",
269
+ "model.text.blocks.5.attn.tau.wv": "model-00001-of-00004.safetensors",
270
+ "model.text.blocks.5.ln.bias": "model-00001-of-00004.safetensors",
271
+ "model.text.blocks.5.ln.weight": "model-00001-of-00004.safetensors",
272
+ "model.text.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
273
+ "model.text.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
274
+ "model.text.blocks.5.mlp.router.bias": "model-00001-of-00004.safetensors",
275
+ "model.text.blocks.5.mlp.router.weight": "model-00001-of-00004.safetensors",
276
+ "model.text.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
277
+ "model.text.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
278
+ "model.text.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
279
+ "model.text.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
280
+ "model.text.blocks.6.attn.tau.alpha": "model-00001-of-00004.safetensors",
281
+ "model.text.blocks.6.attn.tau.wq": "model-00001-of-00004.safetensors",
282
+ "model.text.blocks.6.attn.tau.wv": "model-00001-of-00004.safetensors",
283
+ "model.text.blocks.6.ln.bias": "model-00001-of-00004.safetensors",
284
+ "model.text.blocks.6.ln.weight": "model-00001-of-00004.safetensors",
285
+ "model.text.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
286
+ "model.text.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
287
+ "model.text.blocks.6.mlp.router.bias": "model-00001-of-00004.safetensors",
288
+ "model.text.blocks.6.mlp.router.weight": "model-00001-of-00004.safetensors",
289
+ "model.text.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
290
+ "model.text.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
291
+ "model.text.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
292
+ "model.text.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
293
+ "model.text.blocks.7.attn.tau.alpha": "model-00001-of-00004.safetensors",
294
+ "model.text.blocks.7.attn.tau.wq": "model-00001-of-00004.safetensors",
295
+ "model.text.blocks.7.attn.tau.wv": "model-00001-of-00004.safetensors",
296
+ "model.text.blocks.7.ln.bias": "model-00001-of-00004.safetensors",
297
+ "model.text.blocks.7.ln.weight": "model-00001-of-00004.safetensors",
298
+ "model.text.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
299
+ "model.text.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
300
+ "model.text.blocks.7.mlp.router.bias": "model-00001-of-00004.safetensors",
301
+ "model.text.blocks.7.mlp.router.weight": "model-00001-of-00004.safetensors",
302
+ "model.text.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
303
+ "model.text.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
304
+ "model.text.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
305
+ "model.text.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
306
+ "model.text.blocks.8.attn.tau.alpha": "model-00001-of-00004.safetensors",
307
+ "model.text.blocks.8.attn.tau.wq": "model-00001-of-00004.safetensors",
308
+ "model.text.blocks.8.attn.tau.wv": "model-00001-of-00004.safetensors",
309
+ "model.text.blocks.8.ln.bias": "model-00001-of-00004.safetensors",
310
+ "model.text.blocks.8.ln.weight": "model-00001-of-00004.safetensors",
311
+ "model.text.blocks.8.mlp.fc1.weight": "model-00002-of-00004.safetensors",
312
+ "model.text.blocks.8.mlp.fc2.weight": "model-00002-of-00004.safetensors",
313
+ "model.text.blocks.8.mlp.router.bias": "model-00001-of-00004.safetensors",
314
+ "model.text.blocks.8.mlp.router.weight": "model-00001-of-00004.safetensors",
315
+ "model.text.blocks.9.attn.proj.bias": "model-00002-of-00004.safetensors",
316
+ "model.text.blocks.9.attn.proj.weight": "model-00002-of-00004.safetensors",
317
+ "model.text.blocks.9.attn.qkv.bias": "model-00002-of-00004.safetensors",
318
+ "model.text.blocks.9.attn.qkv.weight": "model-00002-of-00004.safetensors",
319
+ "model.text.blocks.9.attn.tau.alpha": "model-00002-of-00004.safetensors",
320
+ "model.text.blocks.9.attn.tau.wq": "model-00002-of-00004.safetensors",
321
+ "model.text.blocks.9.attn.tau.wv": "model-00002-of-00004.safetensors",
322
+ "model.text.blocks.9.ln.bias": "model-00002-of-00004.safetensors",
323
+ "model.text.blocks.9.ln.weight": "model-00002-of-00004.safetensors",
324
+ "model.text.blocks.9.mlp.fc1.weight": "model-00002-of-00004.safetensors",
325
+ "model.text.blocks.9.mlp.fc2.weight": "model-00002-of-00004.safetensors",
326
+ "model.text.blocks.9.mlp.router.bias": "model-00002-of-00004.safetensors",
327
+ "model.text.blocks.9.mlp.router.weight": "model-00002-of-00004.safetensors",
328
+ "model.text.lm_head.bias": "model-00004-of-00004.safetensors",
329
+ "model.text.lm_head.weight": "model-00004-of-00004.safetensors",
330
+ "model.text.post_ln.bias": "model-00004-of-00004.safetensors",
331
+ "model.text.post_ln.weight": "model-00004-of-00004.safetensors",
332
+ "model.text.wte": "model-00001-of-00004.safetensors",
333
+ "model.vision.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
334
+ "model.vision.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
335
+ "model.vision.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
336
+ "model.vision.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
337
+ "model.vision.blocks.0.ln1.bias": "model-00001-of-00004.safetensors",
338
+ "model.vision.blocks.0.ln1.weight": "model-00001-of-00004.safetensors",
339
+ "model.vision.blocks.0.ln2.bias": "model-00001-of-00004.safetensors",
340
+ "model.vision.blocks.0.ln2.weight": "model-00001-of-00004.safetensors",
341
+ "model.vision.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
342
+ "model.vision.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
343
+ "model.vision.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
344
+ "model.vision.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
345
+ "model.vision.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
346
+ "model.vision.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
347
+ "model.vision.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
348
+ "model.vision.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
349
+ "model.vision.blocks.1.ln1.bias": "model-00001-of-00004.safetensors",
350
+ "model.vision.blocks.1.ln1.weight": "model-00001-of-00004.safetensors",
351
+ "model.vision.blocks.1.ln2.bias": "model-00001-of-00004.safetensors",
352
+ "model.vision.blocks.1.ln2.weight": "model-00001-of-00004.safetensors",
353
+ "model.vision.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
354
+ "model.vision.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
355
+ "model.vision.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
356
+ "model.vision.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
357
+ "model.vision.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
358
+ "model.vision.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
359
+ "model.vision.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
360
+ "model.vision.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
361
+ "model.vision.blocks.10.ln1.bias": "model-00001-of-00004.safetensors",
362
+ "model.vision.blocks.10.ln1.weight": "model-00001-of-00004.safetensors",
363
+ "model.vision.blocks.10.ln2.bias": "model-00001-of-00004.safetensors",
364
+ "model.vision.blocks.10.ln2.weight": "model-00001-of-00004.safetensors",
365
+ "model.vision.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
366
+ "model.vision.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
367
+ "model.vision.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
368
+ "model.vision.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
369
+ "model.vision.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
370
+ "model.vision.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
371
+ "model.vision.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
372
+ "model.vision.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
373
+ "model.vision.blocks.11.ln1.bias": "model-00001-of-00004.safetensors",
374
+ "model.vision.blocks.11.ln1.weight": "model-00001-of-00004.safetensors",
375
+ "model.vision.blocks.11.ln2.bias": "model-00001-of-00004.safetensors",
376
+ "model.vision.blocks.11.ln2.weight": "model-00001-of-00004.safetensors",
377
+ "model.vision.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
378
+ "model.vision.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
379
+ "model.vision.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
380
+ "model.vision.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
381
+ "model.vision.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
382
+ "model.vision.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
383
+ "model.vision.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
384
+ "model.vision.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
385
+ "model.vision.blocks.12.ln1.bias": "model-00001-of-00004.safetensors",
386
+ "model.vision.blocks.12.ln1.weight": "model-00001-of-00004.safetensors",
387
+ "model.vision.blocks.12.ln2.bias": "model-00001-of-00004.safetensors",
388
+ "model.vision.blocks.12.ln2.weight": "model-00001-of-00004.safetensors",
389
+ "model.vision.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
390
+ "model.vision.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
391
+ "model.vision.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
392
+ "model.vision.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
393
+ "model.vision.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
394
+ "model.vision.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
395
+ "model.vision.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
396
+ "model.vision.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
397
+ "model.vision.blocks.13.ln1.bias": "model-00001-of-00004.safetensors",
398
+ "model.vision.blocks.13.ln1.weight": "model-00001-of-00004.safetensors",
399
+ "model.vision.blocks.13.ln2.bias": "model-00001-of-00004.safetensors",
400
+ "model.vision.blocks.13.ln2.weight": "model-00001-of-00004.safetensors",
401
+ "model.vision.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
402
+ "model.vision.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
403
+ "model.vision.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
404
+ "model.vision.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
405
+ "model.vision.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
406
+ "model.vision.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
407
+ "model.vision.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
408
+ "model.vision.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
409
+ "model.vision.blocks.14.ln1.bias": "model-00001-of-00004.safetensors",
410
+ "model.vision.blocks.14.ln1.weight": "model-00001-of-00004.safetensors",
411
+ "model.vision.blocks.14.ln2.bias": "model-00001-of-00004.safetensors",
412
+ "model.vision.blocks.14.ln2.weight": "model-00001-of-00004.safetensors",
413
+ "model.vision.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
414
+ "model.vision.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
415
+ "model.vision.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
416
+ "model.vision.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
417
+ "model.vision.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
418
+ "model.vision.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
419
+ "model.vision.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
420
+ "model.vision.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
421
+ "model.vision.blocks.15.ln1.bias": "model-00001-of-00004.safetensors",
422
+ "model.vision.blocks.15.ln1.weight": "model-00001-of-00004.safetensors",
423
+ "model.vision.blocks.15.ln2.bias": "model-00001-of-00004.safetensors",
424
+ "model.vision.blocks.15.ln2.weight": "model-00001-of-00004.safetensors",
425
+ "model.vision.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
426
+ "model.vision.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
427
+ "model.vision.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
428
+ "model.vision.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
429
+ "model.vision.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
430
+ "model.vision.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
431
+ "model.vision.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
432
+ "model.vision.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
433
+ "model.vision.blocks.16.ln1.bias": "model-00001-of-00004.safetensors",
434
+ "model.vision.blocks.16.ln1.weight": "model-00001-of-00004.safetensors",
435
+ "model.vision.blocks.16.ln2.bias": "model-00001-of-00004.safetensors",
436
+ "model.vision.blocks.16.ln2.weight": "model-00001-of-00004.safetensors",
437
+ "model.vision.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
438
+ "model.vision.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
439
+ "model.vision.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
440
+ "model.vision.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
441
+ "model.vision.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
442
+ "model.vision.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
443
+ "model.vision.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
444
+ "model.vision.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
445
+ "model.vision.blocks.17.ln1.bias": "model-00001-of-00004.safetensors",
446
+ "model.vision.blocks.17.ln1.weight": "model-00001-of-00004.safetensors",
447
+ "model.vision.blocks.17.ln2.bias": "model-00001-of-00004.safetensors",
448
+ "model.vision.blocks.17.ln2.weight": "model-00001-of-00004.safetensors",
449
+ "model.vision.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
450
+ "model.vision.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
451
+ "model.vision.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
452
+ "model.vision.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
453
+ "model.vision.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
454
+ "model.vision.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
455
+ "model.vision.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
456
+ "model.vision.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
457
+ "model.vision.blocks.18.ln1.bias": "model-00001-of-00004.safetensors",
458
+ "model.vision.blocks.18.ln1.weight": "model-00001-of-00004.safetensors",
459
+ "model.vision.blocks.18.ln2.bias": "model-00001-of-00004.safetensors",
460
+ "model.vision.blocks.18.ln2.weight": "model-00001-of-00004.safetensors",
461
+ "model.vision.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
462
+ "model.vision.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
463
+ "model.vision.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
464
+ "model.vision.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
465
+ "model.vision.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
466
+ "model.vision.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
467
+ "model.vision.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
468
+ "model.vision.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
469
+ "model.vision.blocks.19.ln1.bias": "model-00001-of-00004.safetensors",
470
+ "model.vision.blocks.19.ln1.weight": "model-00001-of-00004.safetensors",
471
+ "model.vision.blocks.19.ln2.bias": "model-00001-of-00004.safetensors",
472
+ "model.vision.blocks.19.ln2.weight": "model-00001-of-00004.safetensors",
473
+ "model.vision.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
474
+ "model.vision.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
475
+ "model.vision.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
476
+ "model.vision.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
477
+ "model.vision.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
478
+ "model.vision.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
479
+ "model.vision.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
480
+ "model.vision.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
481
+ "model.vision.blocks.2.ln1.bias": "model-00001-of-00004.safetensors",
482
+ "model.vision.blocks.2.ln1.weight": "model-00001-of-00004.safetensors",
483
+ "model.vision.blocks.2.ln2.bias": "model-00001-of-00004.safetensors",
484
+ "model.vision.blocks.2.ln2.weight": "model-00001-of-00004.safetensors",
485
+ "model.vision.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
486
+ "model.vision.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
487
+ "model.vision.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
488
+ "model.vision.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
489
+ "model.vision.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
490
+ "model.vision.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
491
+ "model.vision.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
492
+ "model.vision.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
493
+ "model.vision.blocks.20.ln1.bias": "model-00001-of-00004.safetensors",
494
+ "model.vision.blocks.20.ln1.weight": "model-00001-of-00004.safetensors",
495
+ "model.vision.blocks.20.ln2.bias": "model-00001-of-00004.safetensors",
496
+ "model.vision.blocks.20.ln2.weight": "model-00001-of-00004.safetensors",
497
+ "model.vision.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
498
+ "model.vision.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
499
+ "model.vision.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
500
+ "model.vision.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
501
+ "model.vision.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
502
+ "model.vision.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
503
+ "model.vision.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
504
+ "model.vision.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
505
+ "model.vision.blocks.21.ln1.bias": "model-00001-of-00004.safetensors",
506
+ "model.vision.blocks.21.ln1.weight": "model-00001-of-00004.safetensors",
507
+ "model.vision.blocks.21.ln2.bias": "model-00001-of-00004.safetensors",
508
+ "model.vision.blocks.21.ln2.weight": "model-00001-of-00004.safetensors",
509
+ "model.vision.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
510
+ "model.vision.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
511
+ "model.vision.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
512
+ "model.vision.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
513
+ "model.vision.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
514
+ "model.vision.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
515
+ "model.vision.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
516
+ "model.vision.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
517
+ "model.vision.blocks.22.ln1.bias": "model-00001-of-00004.safetensors",
518
+ "model.vision.blocks.22.ln1.weight": "model-00001-of-00004.safetensors",
519
+ "model.vision.blocks.22.ln2.bias": "model-00001-of-00004.safetensors",
520
+ "model.vision.blocks.22.ln2.weight": "model-00001-of-00004.safetensors",
521
+ "model.vision.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
522
+ "model.vision.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
523
+ "model.vision.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
524
+ "model.vision.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
525
+ "model.vision.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
526
+ "model.vision.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
527
+ "model.vision.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
528
+ "model.vision.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
529
+ "model.vision.blocks.23.ln1.bias": "model-00001-of-00004.safetensors",
530
+ "model.vision.blocks.23.ln1.weight": "model-00001-of-00004.safetensors",
531
+ "model.vision.blocks.23.ln2.bias": "model-00001-of-00004.safetensors",
532
+ "model.vision.blocks.23.ln2.weight": "model-00001-of-00004.safetensors",
533
+ "model.vision.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
534
+ "model.vision.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
535
+ "model.vision.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
536
+ "model.vision.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
537
+ "model.vision.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
538
+ "model.vision.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
539
+ "model.vision.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
540
+ "model.vision.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
541
+ "model.vision.blocks.24.ln1.bias": "model-00001-of-00004.safetensors",
542
+ "model.vision.blocks.24.ln1.weight": "model-00001-of-00004.safetensors",
543
+ "model.vision.blocks.24.ln2.bias": "model-00001-of-00004.safetensors",
544
+ "model.vision.blocks.24.ln2.weight": "model-00001-of-00004.safetensors",
545
+ "model.vision.blocks.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
546
+ "model.vision.blocks.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
547
+ "model.vision.blocks.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
548
+ "model.vision.blocks.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
549
+ "model.vision.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
550
+ "model.vision.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
551
+ "model.vision.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
552
+ "model.vision.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
553
+ "model.vision.blocks.25.ln1.bias": "model-00001-of-00004.safetensors",
554
+ "model.vision.blocks.25.ln1.weight": "model-00001-of-00004.safetensors",
555
+ "model.vision.blocks.25.ln2.bias": "model-00001-of-00004.safetensors",
556
+ "model.vision.blocks.25.ln2.weight": "model-00001-of-00004.safetensors",
557
+ "model.vision.blocks.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
558
+ "model.vision.blocks.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
559
+ "model.vision.blocks.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
560
+ "model.vision.blocks.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
561
+ "model.vision.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
562
+ "model.vision.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
563
+ "model.vision.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
564
+ "model.vision.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
565
+ "model.vision.blocks.26.ln1.bias": "model-00001-of-00004.safetensors",
566
+ "model.vision.blocks.26.ln1.weight": "model-00001-of-00004.safetensors",
567
+ "model.vision.blocks.26.ln2.bias": "model-00001-of-00004.safetensors",
568
+ "model.vision.blocks.26.ln2.weight": "model-00001-of-00004.safetensors",
569
+ "model.vision.blocks.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
570
+ "model.vision.blocks.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
571
+ "model.vision.blocks.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
572
+ "model.vision.blocks.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
573
+ "model.vision.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
574
+ "model.vision.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
575
+ "model.vision.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
576
+ "model.vision.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
577
+ "model.vision.blocks.3.ln1.bias": "model-00001-of-00004.safetensors",
578
+ "model.vision.blocks.3.ln1.weight": "model-00001-of-00004.safetensors",
579
+ "model.vision.blocks.3.ln2.bias": "model-00001-of-00004.safetensors",
580
+ "model.vision.blocks.3.ln2.weight": "model-00001-of-00004.safetensors",
581
+ "model.vision.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
582
+ "model.vision.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
583
+ "model.vision.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
584
+ "model.vision.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
585
+ "model.vision.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
586
+ "model.vision.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
587
+ "model.vision.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
588
+ "model.vision.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
589
+ "model.vision.blocks.4.ln1.bias": "model-00001-of-00004.safetensors",
590
+ "model.vision.blocks.4.ln1.weight": "model-00001-of-00004.safetensors",
591
+ "model.vision.blocks.4.ln2.bias": "model-00001-of-00004.safetensors",
592
+ "model.vision.blocks.4.ln2.weight": "model-00001-of-00004.safetensors",
593
+ "model.vision.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
594
+ "model.vision.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
595
+ "model.vision.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
596
+ "model.vision.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
597
+ "model.vision.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
598
+ "model.vision.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
599
+ "model.vision.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
600
+ "model.vision.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
601
+ "model.vision.blocks.5.ln1.bias": "model-00001-of-00004.safetensors",
602
+ "model.vision.blocks.5.ln1.weight": "model-00001-of-00004.safetensors",
603
+ "model.vision.blocks.5.ln2.bias": "model-00001-of-00004.safetensors",
604
+ "model.vision.blocks.5.ln2.weight": "model-00001-of-00004.safetensors",
605
+ "model.vision.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
606
+ "model.vision.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
607
+ "model.vision.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
608
+ "model.vision.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
609
+ "model.vision.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
610
+ "model.vision.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
611
+ "model.vision.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
612
+ "model.vision.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
613
+ "model.vision.blocks.6.ln1.bias": "model-00001-of-00004.safetensors",
614
+ "model.vision.blocks.6.ln1.weight": "model-00001-of-00004.safetensors",
615
+ "model.vision.blocks.6.ln2.bias": "model-00001-of-00004.safetensors",
616
+ "model.vision.blocks.6.ln2.weight": "model-00001-of-00004.safetensors",
617
+ "model.vision.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
618
+ "model.vision.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
619
+ "model.vision.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
620
+ "model.vision.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
621
+ "model.vision.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
622
+ "model.vision.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
623
+ "model.vision.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
624
+ "model.vision.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
625
+ "model.vision.blocks.7.ln1.bias": "model-00001-of-00004.safetensors",
626
+ "model.vision.blocks.7.ln1.weight": "model-00001-of-00004.safetensors",
627
+ "model.vision.blocks.7.ln2.bias": "model-00001-of-00004.safetensors",
628
+ "model.vision.blocks.7.ln2.weight": "model-00001-of-00004.safetensors",
629
+ "model.vision.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
630
+ "model.vision.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
631
+ "model.vision.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
632
+ "model.vision.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
633
+ "model.vision.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
634
+ "model.vision.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
635
+ "model.vision.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
636
+ "model.vision.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
637
+ "model.vision.blocks.8.ln1.bias": "model-00001-of-00004.safetensors",
638
+ "model.vision.blocks.8.ln1.weight": "model-00001-of-00004.safetensors",
639
+ "model.vision.blocks.8.ln2.bias": "model-00001-of-00004.safetensors",
640
+ "model.vision.blocks.8.ln2.weight": "model-00001-of-00004.safetensors",
641
+ "model.vision.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
642
+ "model.vision.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
643
+ "model.vision.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
644
+ "model.vision.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
645
+ "model.vision.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
646
+ "model.vision.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
647
+ "model.vision.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
648
+ "model.vision.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
649
+ "model.vision.blocks.9.ln1.bias": "model-00001-of-00004.safetensors",
650
+ "model.vision.blocks.9.ln1.weight": "model-00001-of-00004.safetensors",
651
+ "model.vision.blocks.9.ln2.bias": "model-00001-of-00004.safetensors",
652
+ "model.vision.blocks.9.ln2.weight": "model-00001-of-00004.safetensors",
653
+ "model.vision.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
654
+ "model.vision.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
655
+ "model.vision.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
656
+ "model.vision.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
657
+ "model.vision.patch_emb.bias": "model-00001-of-00004.safetensors",
658
+ "model.vision.patch_emb.weight": "model-00001-of-00004.safetensors",
659
+ "model.vision.pos_emb": "model-00001-of-00004.safetensors",
660
+ "model.vision.post_ln.bias": "model-00001-of-00004.safetensors",
661
+ "model.vision.post_ln.weight": "model-00001-of-00004.safetensors",
662
+ "model.vision.proj_mlp.fc1.bias": "model-00001-of-00004.safetensors",
663
+ "model.vision.proj_mlp.fc1.weight": "model-00001-of-00004.safetensors",
664
+ "model.vision.proj_mlp.fc2.bias": "model-00001-of-00004.safetensors",
665
+ "model.vision.proj_mlp.fc2.weight": "model-00001-of-00004.safetensors"
666
+ }
667
+ }
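The index above maps every tensor name to one of the four shard files. A minimal sketch of how such an index is typically consumed, assuming the standard weight_map layout and that the shards sit next to the index file (this helper is illustrative only and is not part of the upload):

import json
from safetensors.torch import load_file

def load_sharded_state_dict(index_path: str = "model.safetensors.index.json") -> dict:
    with open(index_path) as f:
        index = json.load(f)
    state_dict = {}
    # Each weight_map entry points a tensor name, e.g.
    # "model.text.blocks.17.ln.weight", at the shard that stores it,
    # e.g. "model-00003-of-00004.safetensors".
    for shard in sorted(set(index["weight_map"].values())):
        state_dict.update(load_file(shard))
    return state_dict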
moondream.py ADDED
@@ -0,0 +1,1070 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+
5
+ from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
6
+ from PIL import Image
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
+ from torch.nn.attention.flex_attention import create_block_mask
10
+
11
+ from .config import MoondreamConfig
12
+ from .image_crops import reconstruct_from_crops
13
+ from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
14
+ from .text import build_text_model, text_encoder, lm_head, text_decoder
15
+ from .region import (
16
+ decode_coordinate,
17
+ encode_coordinate,
18
+ decode_size,
19
+ encode_size,
20
+ encode_spatial_refs,
21
+ SpatialRefs,
22
+ )
23
+ from .layers import QuantizedLinear
24
+ from .lora import variant_state_dict
25
+ from .utils import remove_outlier_points
26
+
27
+ ImageEncodingSettings = TypedDict(
28
+ "ImageEncodingSettings",
29
+ {"variant": str},
30
+ total=False,
31
+ )
32
+
33
+ TextSamplingSettings = TypedDict(
34
+ "TextSamplingSettings",
35
+ {
36
+ "max_tokens": int,
37
+ "temperature": float,
38
+ "top_p": float,
39
+ "variant": str,
40
+ },
41
+ total=False,
42
+ )
43
+
44
+ ObjectSamplingSettings = TypedDict(
45
+ "ObjectSamplingSettings",
46
+ {"max_objects": int, "variant": str},
47
+ total=False,
48
+ )
49
+
50
+
51
+ DEFAULT_MAX_TOKENS = 768
52
+ DEFAULT_TEMPERATURE = 0.5
53
+ DEFAULT_TOP_P = 0.9
54
+ DEFAULT_MAX_OBJECTS = 150
55
+
56
+
57
+ @dataclass(frozen=True)
58
+ class EncodedImage:
59
+ pos: int
60
+ caches: List[Tuple[torch.Tensor, torch.Tensor]]
61
+
62
+
63
+ class KVCache(nn.Module):
64
+
65
+ def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
66
+ super().__init__()
67
+ cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
68
+ self.register_buffer(
69
+ "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
70
+ )
71
+ self.register_buffer(
72
+ "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
73
+ )
74
+
75
+ def update(self, pos_ids, k, v):
76
+ kout, vout = self.k_cache, self.v_cache
77
+ kout[:, :, pos_ids, :] = k
78
+ vout[:, :, pos_ids, :] = v
79
+ return kout, vout
80
+
81
+
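# Editorial sketch, not part of the uploaded file: KVCache preallocates
# (1, n_kv_heads, max_context, dim // n_heads) buffers and scatters one new
# key/value row per decode step, so no memory is allocated during generation.
_demo_cache = KVCache(n_heads=1, n_kv_heads=1, max_context=8, dim=4,
                      device="cpu", dtype=torch.float32)
_k = _v = torch.ones(1, 1, 1, 4)
_k_all, _v_all = _demo_cache.update(torch.tensor([3]), _k, _v)
# Row 3 of both buffers now holds the new entry; every other row remains zero.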
82
+ def causal_mask(b, h, q_idx, kv_idx):
83
+ return q_idx >= kv_idx
84
+
85
+
86
+ def get_mask_mod(mask_mod, offset):
87
+ def _mask_mod(b, h, q, kv):
88
+ return mask_mod(b, h, q + offset, kv)
89
+
90
+ return _mask_mod
91
+
92
+
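# Editorial sketch, not part of the uploaded file: during single-token decoding
# the query length is 1, so get_mask_mod rebases the causal predicate onto the
# token's absolute position. With plain ints:
_shifted = get_mask_mod(causal_mask, offset=7)
assert _shifted(0, 0, 0, 5)      # absolute query position 7 may attend to kv 5
assert not _shifted(0, 0, 0, 9)  # but not to kv 9, which lies in the future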
93
+ class MoondreamModel(nn.Module):
94
+
95
+ def __init__(
96
+ self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
97
+ ):
98
+ super().__init__()
99
+ self.config = config
100
+
101
+ self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
102
+ self.vision = build_vision_model(config.vision, dtype)
103
+ self.text = build_text_model(config.text, dtype)
104
+
105
+ # Region Model
106
+ linear_cls = (
107
+ QuantizedLinear if config.region.group_size is not None else nn.Linear
108
+ )
109
+ self.region = nn.ModuleDict(
110
+ {
111
+ "coord_encoder": linear_cls(
112
+ config.region.coord_feat_dim, config.region.dim, dtype=dtype
113
+ ),
114
+ "coord_decoder": linear_cls(
115
+ config.region.dim, config.region.coord_out_dim, dtype=dtype
116
+ ),
117
+ "size_encoder": linear_cls(
118
+ config.region.size_feat_dim, config.region.dim, dtype=dtype
119
+ ),
120
+ "size_decoder": linear_cls(
121
+ config.region.dim, config.region.size_out_dim, dtype=dtype
122
+ ),
123
+ }
124
+ )
125
+ self.region.coord_features = nn.Parameter(
126
+ torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
127
+ )
128
+ self.region.size_features = nn.Parameter(
129
+ torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
130
+ )
131
+
132
+ attn_mask = torch.tril(
133
+ torch.ones(
134
+ 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
135
+ )
136
+ )
137
+ patch_w = config.vision.crop_size // config.vision.enc_patch_size
138
+ prefix_attn_len = 1 + patch_w**2
139
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
140
+ self.register_buffer("attn_mask", attn_mask, persistent=False)
141
+
142
+ self.use_flex_decoding = True
143
+ self._causal_block_mask = None
144
+ self._point_gen_indices = None
145
+
146
+ # Initialize KV caches.
147
+ if setup_caches:
148
+ self._setup_caches()
149
+
150
+ @property
151
+ def causal_block_mask(self):
152
+ # The things we do to deal with ZeroGPU...
153
+ if self._causal_block_mask is None:
154
+ self._causal_block_mask = create_block_mask(
155
+ causal_mask,
156
+ B=None,
157
+ H=None,
158
+ Q_LEN=self.config.text.max_context,
159
+ KV_LEN=self.config.text.max_context,
160
+ )
161
+ return self._causal_block_mask
162
+
163
+ @property
164
+ def point_gen_indices(self):
165
+ if self._point_gen_indices is None:
166
+ self._point_gen_indices = torch.tensor(
167
+ [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],
168
+ device=self.device,
169
+ )
170
+ return self._point_gen_indices
171
+
172
+ def _setup_caches(self):
173
+ c = self.config.text
174
+ for b in self.text.blocks:
175
+ b.kv_cache = KVCache(
176
+ c.n_heads,
177
+ c.n_kv_heads,
178
+ c.max_context,
179
+ c.dim,
180
+ device=self.device,
181
+ dtype=self.vision.pos_emb.dtype,
182
+ )
183
+
184
+ @property
185
+ def device(self):
186
+ return self.vision.pos_emb.device
187
+
188
+ def _vis_enc(self, x: torch.Tensor):
189
+ return vision_encoder(x, self.vision, self.config.vision)
190
+
191
+ def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
192
+ return vision_projection(g, r, self.vision, self.config.vision)
193
+
194
+ def _prefill(
195
+ self,
196
+ x: torch.Tensor,
197
+ attn_mask: torch.Tensor,
198
+ pos_ids: torch.Tensor,
199
+ lora: Optional[torch.Tensor],
200
+ ):
201
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
202
+
203
+ def _decode_one_tok(
204
+ self,
205
+ x: torch.Tensor,
206
+ attn_mask: torch.Tensor,
207
+ pos_ids: torch.Tensor,
208
+ lora: Optional[torch.Tensor],
209
+ lm_head_indices: Optional[torch.Tensor] = None,
210
+ ):
211
+ if self.use_flex_decoding:
212
+ torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape")
213
+ block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]
214
+ mask = self.causal_block_mask[:, :, block_index]
215
+ mask.seq_lengths = (1, mask.seq_lengths[1])
216
+ mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])
217
+ else:
218
+ mask = None
219
+
220
+ hidden = text_decoder(
221
+ x,
222
+ self.text,
223
+ attn_mask,
224
+ pos_ids,
225
+ self.config.text,
226
+ lora=lora,
227
+ flex_block_mask_slice=mask,
228
+ )
229
+ logits = lm_head(hidden, self.text, indices=lm_head_indices)
230
+ return logits, hidden
231
+
232
+ def compile(self):
233
+ for module in self.modules():
234
+ if isinstance(module, QuantizedLinear):
235
+ module.unpack()
236
+
237
+ # Initialize lazy properties to avoid first-call overhead
238
+ self.causal_block_mask
239
+ self.point_gen_indices
240
+
241
+ # TODO: vision_projection and _prefill are not being compiled
242
+ self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
243
+ self._decode_one_tok = torch.compile(
244
+ self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
245
+ )
246
+
247
+ # Warm up compiled methods with dummy forward passes
248
+ device = self.device
249
+ dtype = self.vision.pos_emb.dtype
250
+ with torch.no_grad():
251
+ # Warmup vision encoder
252
+ dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype)
253
+ self._vis_enc(dummy_crops)
254
+
255
+ # Warmup _decode_one_tok (both normal and point generation modes)
256
+ dummy_emb = torch.randn(
257
+ 1, 1, self.config.text.dim, device=device, dtype=dtype
258
+ )
259
+ dummy_mask = torch.ones(
260
+ 1, 1, self.config.text.max_context, device=device, dtype=torch.bool
261
+ )
262
+ dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long)
263
+ self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None)
264
+ self._decode_one_tok(
265
+ dummy_emb,
266
+ dummy_mask,
267
+ dummy_pos_ids,
268
+ None,
269
+ lm_head_indices=self.point_gen_indices,
270
+ )
271
+
272
+ def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
273
+ all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
274
+
275
+ torch._dynamo.mark_dynamic(all_crops, 0)
276
+
277
+ outputs = self._vis_enc(all_crops)
278
+
279
+ global_features = outputs[0]
280
+ local_features = outputs[1:].view(
281
+ -1,
282
+ self.config.vision.enc_n_layers,
283
+ self.config.vision.enc_n_layers,
284
+ self.config.vision.enc_dim,
285
+ )
286
+
287
+ reconstructed = reconstruct_from_crops(
288
+ local_features,
289
+ tiling,
290
+ patch_size=1,
291
+ overlap_margin=self.config.vision.overlap_margin,
292
+ )
293
+
294
+ return self._vis_proj(global_features, reconstructed)
295
+
296
+ def encode_image(
297
+ self,
298
+ image: Union[Image.Image, EncodedImage],
299
+ settings: Optional[ImageEncodingSettings] = None,
300
+ ) -> EncodedImage:
301
+ if isinstance(image, EncodedImage):
302
+ return image
303
+ elif not isinstance(image, Image.Image):
304
+ raise ValueError("image must be a PIL Image or EncodedImage")
305
+
306
+ lora = (
307
+ variant_state_dict(settings["variant"], device=self.device)
308
+ if settings is not None and "variant" in settings
309
+ else None
310
+ )
311
+
312
+ # Run through text model in addition to the vision encoder, to minimize
313
+ # re-computation if multiple queries are performed on this image.
314
+ with torch.inference_mode():
315
+ img_emb = self._run_vision_encoder(image)
316
+ bos_emb = text_encoder(
317
+ torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
318
+ self.text,
319
+ )
320
+ inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
321
+ mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
322
+ pos_ids = torch.arange(
323
+ inputs_embeds.size(1), dtype=torch.long, device=self.device
324
+ )
325
+ self._prefill(inputs_embeds, mask, pos_ids, lora)
326
+
327
+ return EncodedImage(
328
+ pos=inputs_embeds.size(1),
329
+ caches=[
330
+ (
331
+ b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
332
+ b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
333
+ )
334
+ for b in self.text.blocks
335
+ ],
336
+ )
337
+
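# Editorial usage sketch, not part of the uploaded file: encode_image runs the
# vision encoder and the image prefill once, so the returned EncodedImage can
# be reused across several queries without recomputing them. "model" and
# "example.jpg" below are placeholders for an already-constructed
# MoondreamModel and a local image:
#
#     encoded = model.encode_image(Image.open("example.jpg"))
#     for q in ("What is in the image?", "Where is the dog?"):
#         result = model.query(image=encoded, question=q)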
338
+ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
339
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
340
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
341
+ mask = probs_sum - probs_sort > top_p
342
+ probs_sort[mask] = 0.0
343
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
344
+ next_probs = torch.zeros_like(probs)
345
+ next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
346
+ return next_probs
347
+
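# Editorial note, not part of the uploaded file: _apply_top_p implements
# nucleus (top-p) filtering. Worked example with probs = [0.50, 0.30, 0.15, 0.05]
# and top_p = 0.9:
#     sorted probs   : 0.50   0.30   0.15   0.05
#     cumulative sum : 0.50   0.80   0.95   1.00
#     cumsum - prob  : 0.00   0.50   0.80   0.95   (mass strictly before each token)
#     kept (<= 0.9)  : yes    yes    yes    no
#     renormalized   : 0.526  0.316  0.158  0.0    (divided by the kept mass, 0.95)
# A token is dropped only once the probability mass ahead of it already exceeds
# top_p, so the most likely token always survives.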
348
+ def _prefill_prompt(
349
+ self,
350
+ prompt_tokens: torch.Tensor,
351
+ pos: int,
352
+ temperature: float,
353
+ top_p: float,
354
+ spatial_refs: Optional[SpatialRefs] = None,
355
+ attn_mask: Optional[torch.Tensor] = None,
356
+ lora: Optional[dict] = None,
357
+ ):
358
+ with torch.inference_mode():
359
+ prompt_emb = text_encoder(prompt_tokens, self.text)
360
+
361
+ if spatial_refs:
362
+ encoded_refs = encode_spatial_refs(spatial_refs, self.region)
363
+ prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
364
+ encoded_refs["coords"]
365
+ )
366
+ if encoded_refs["sizes"] is not None:
367
+ prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
368
+ encoded_refs["sizes"]
369
+ )
370
+
371
+ torch._dynamo.mark_dynamic(prompt_emb, 1)
372
+
373
+ if attn_mask is None:
374
+ attn_mask = self.attn_mask
375
+
376
+ mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
377
+ pos_ids = torch.arange(
378
+ pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device
379
+ )
380
+ hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
381
+ logits_BV = lm_head(hidden_BC, self.text)
382
+
383
+ if temperature == 0:
384
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
385
+ else:
386
+ probs = torch.softmax(logits_BV / temperature, dim=-1)
387
+ probs = self._apply_top_p(probs, top_p)
388
+ next_token = torch.multinomial(probs, num_samples=1)
389
+
390
+ pos = pos + prompt_emb.size(1)
391
+ return logits_BV, hidden_BC, next_token, pos
392
+
393
+ def _generate_reasoning(
394
+ self,
395
+ prompt_tokens,
396
+ pos,
397
+ settings: Optional[TextSamplingSettings] = None,
398
+ spatial_refs: Optional[SpatialRefs] = None,
399
+ attn_mask: Optional[torch.Tensor] = None,
400
+ ) -> Tuple[int, str, List[dict]]:
401
+ max_tokens = (
402
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
403
+ if settings
404
+ else DEFAULT_MAX_TOKENS
405
+ )
406
+ temperature = (
407
+ settings.get("temperature", DEFAULT_TEMPERATURE)
408
+ if settings
409
+ else DEFAULT_TEMPERATURE
410
+ )
411
+ lora = (
412
+ variant_state_dict(settings["variant"], device=self.device)
413
+ if settings is not None and "variant" in settings
414
+ else None
415
+ )
416
+
417
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
418
+ eos_id = self.config.tokenizer.answer_id
419
+
420
+ _, last_hidden_BC, next_token, pos = self._prefill_prompt(
421
+ prompt_tokens,
422
+ pos,
423
+ temperature,
424
+ top_p,
425
+ spatial_refs,
426
+ attn_mask=attn_mask,
427
+ lora=lora,
428
+ )
429
+
430
+ text_token_chunks = [[]]
431
+ grounding_chunks = [[]]
432
+
433
+ mask = torch.zeros(
434
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
435
+ )
436
+ mask[:, :, :pos] = 1
437
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
438
+ generated_tokens = 0
439
+
440
+ while (
441
+ next_token_id := next_token.item()
442
+ ) != eos_id and generated_tokens < max_tokens:
443
+ if (
444
+ next_token_id == self.config.tokenizer.start_ground_points_id
445
+ or next_token_id == self.config.tokenizer.end_ground_id
446
+ ):
447
+ text_token_chunks.append([])
448
+ grounding_chunks.append([])
449
+
450
+ text_token_chunks[-1].append(next_token_id)
451
+
452
+ with torch.inference_mode():
453
+ if next_token_id == self.config.tokenizer.coord_id:
454
+ coord_logits = decode_coordinate(last_hidden_BC, self.region)
455
+ coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
456
+ grounding_chunks[-1].append(coord.item())
457
+
458
+ next_emb = encode_coordinate(
459
+ coord.to(dtype=coord_logits.dtype), self.region
460
+ ).unsqueeze(0)
461
+ else:
462
+ next_emb = text_encoder(next_token, self.text)
463
+
464
+ mask[:, :, pos], pos_ids[0] = 1, pos
465
+
466
+ logits_BV, last_hidden_BC = self._decode_one_tok(
467
+ next_emb, mask, pos_ids, lora
468
+ )
469
+ logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
470
+ logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
471
+
472
+ pos += 1
473
+
474
+ if temperature == 0:
475
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
476
+ else:
477
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
478
+ probs = self._apply_top_p(probs, top_p)
479
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
480
+
481
+ generated_tokens += 1
482
+
483
+ text_chunks = [
484
+ self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
485
+ ]
486
+ text = "".join(text_chunks)
487
+
488
+ start_idx = 0
489
+ grounding = []
490
+ for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
491
+ if len(grounding_chunk) > 1:
492
+ points = []
493
+ for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
494
+ points.append((grounding_chunk[i], grounding_chunk[i + 1]))
495
+ grounding.append(
496
+ {
497
+ "start_idx": start_idx,
498
+ "end_idx": start_idx + len(text_chunk),
499
+ "points": points,
500
+ }
501
+ )
502
+ start_idx += len(text_chunk)
503
+
504
+ return pos, text, grounding
505
+
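# Editorial note, not part of the uploaded file: the grounding list built above
# maps character spans of the reasoning text to normalized image coordinates.
# Each coord token is decoded as argmax over the coordinate bins divided by the
# number of bins, so values fall in [0, 1), and consecutive decoded coordinates
# are paired up (presumably x then y). A single entry looks roughly like:
#     {"start_idx": 0, "end_idx": 18, "points": [(0.42, 0.61)]}
# where start_idx/end_idx index into the returned text.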
506
+ def _generate_answer(
507
+ self,
508
+ prompt_tokens: torch.Tensor,
509
+ pos: int,
510
+ settings: Optional[TextSamplingSettings] = None,
511
+ spatial_refs: Optional[SpatialRefs] = None,
512
+ eos_id: Optional[int] = None,
513
+ attn_mask: Optional[torch.Tensor] = None,
514
+ ):
515
+ max_tokens = (
516
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
517
+ if settings
518
+ else DEFAULT_MAX_TOKENS
519
+ )
520
+ temperature = (
521
+ settings.get("temperature", DEFAULT_TEMPERATURE)
522
+ if settings
523
+ else DEFAULT_TEMPERATURE
524
+ )
525
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
526
+ eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
527
+ lora = (
528
+ variant_state_dict(settings["variant"], device=self.device)
529
+ if settings is not None and "variant" in settings
530
+ else None
531
+ )
532
+
533
+ _, _, next_token, pos = self._prefill_prompt(
534
+ prompt_tokens,
535
+ pos,
536
+ temperature,
537
+ top_p,
538
+ spatial_refs,
539
+ attn_mask=attn_mask,
540
+ lora=lora,
541
+ )
542
+
543
+ def generator(next_token, pos):
544
+ mask = torch.zeros(
545
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
546
+ )
547
+ mask[:, :, :pos] = 1
548
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
549
+ generated_tokens = 0
550
+
551
+ # For properly handling token streaming with Unicode
552
+ token_cache = []
553
+ print_len = 0
554
+
555
+ while (
556
+ next_token_id := next_token.item()
557
+ ) != eos_id and generated_tokens < max_tokens:
558
+ # Add token to our cache
559
+ token_cache.append(next_token_id)
560
+
561
+ # Decode all tokens collected so far
562
+ text = self.tokenizer.decode(token_cache)
563
+
564
+ # After a newline, we flush the cache completely
565
+ if text.endswith("\n"):
566
+ printable_text = text[print_len:]
567
+ token_cache = []
568
+ print_len = 0
569
+ if printable_text:
570
+ yield printable_text
571
+ # If the last token is a CJK character, we can safely print it
572
+ elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
573
+ printable_text = text[print_len:]
574
+ print_len += len(printable_text)
575
+ if printable_text:
576
+ yield printable_text
577
+ # Otherwise, only yield up to the last space to avoid cutting words
578
+ else:
579
+ last_space_idx = text.rfind(" ", print_len)
580
+ if last_space_idx >= print_len:
581
+ printable_text = text[print_len : last_space_idx + 1]
582
+ print_len += len(printable_text)
583
+ if printable_text:
584
+ yield printable_text
585
+
586
+ with torch.inference_mode():
587
+ next_emb = text_encoder(next_token, self.text)
588
+ mask[:, :, pos], pos_ids[0] = 1, pos
589
+
590
+ logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
591
+ logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
592
+
593
+ pos += 1
594
+
595
+ if temperature == 0:
596
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
597
+ 1
598
+ ) # (1, 1)
599
+ else:
600
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
601
+ probs = self._apply_top_p(probs, top_p)
602
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
603
+
604
+ generated_tokens += 1
605
+
606
+ # Flush any remaining text in the cache
607
+ if token_cache:
608
+ text = self.tokenizer.decode(token_cache)
609
+ printable_text = text[print_len:]
610
+ if printable_text:
611
+ yield printable_text
612
+
613
+ return generator(next_token, pos)
614
+
615
+ def query(
616
+ self,
617
+ image: Optional[Union[Image.Image, EncodedImage]] = None,
618
+ question: Optional[str] = None,
619
+ reasoning: bool = True,
620
+ spatial_refs: Optional[SpatialRefs] = None,
621
+ stream: bool = False,
622
+ settings: Optional[TextSamplingSettings] = None,
623
+ ):
624
+ if self.config.tokenizer.templates["query"] is None:
625
+ raise NotImplementedError("Model does not support querying.")
626
+
627
+ if question is None:
628
+ raise ValueError("question must be provided.")
629
+
630
+ if spatial_refs and image is None:
631
+ raise ValueError("spatial_refs can only be used with an image.")
632
+
633
+ attn_mask = self.attn_mask
634
+ if image is not None:
635
+ image = self.encode_image(image, settings)
636
+ self.load_encoded_image(image)
637
+ pos = image.pos
638
+ prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
639
+ else:
640
+ self._setup_caches()
641
+ pos = 0
642
+ prompt_toks = [
643
+ self.config.tokenizer.bos_id
644
+ ] + self.config.tokenizer.templates["query"]["prefix"]
645
+ max_context = self.config.text.max_context
646
+ attn_mask = torch.tril(
647
+ torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
648
+ ).to(self.device)
649
+
650
+ spatial_toks = []
651
+ if spatial_refs:
652
+ for ref in spatial_refs:
653
+ coord_id = self.config.tokenizer.coord_id
654
+ size_id = self.config.tokenizer.size_id
655
+ if len(ref) == 2:
656
+ spatial_toks.extend([coord_id, coord_id])
657
+ else:
658
+ spatial_toks.extend([coord_id, coord_id, size_id])
659
+
660
+ prompt_tokens = [
661
+ prompt_toks + spatial_toks + self.tokenizer.encode(question).ids
662
+ ]
663
+
664
+ if reasoning:
665
+ prompt_tokens[0] += [self.config.tokenizer.thinking_id]
666
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
667
+ pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
668
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
669
+ )
670
+ prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
671
+ reasoning_dict = {
672
+ "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
673
+ }
674
+ else:
675
+ prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
676
+ reasoning_dict = {}
677
+
678
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
679
+
680
+ def generator():
681
+ for token in self._generate_answer(
682
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
683
+ ):
684
+ yield token
685
+
686
+ if stream:
687
+ return {**reasoning_dict, "answer": generator()}
688
+ else:
689
+ return {**reasoning_dict, "answer": "".join(list(generator()))}
690
+
691
+ def load_encoded_image(self, encoded_image: EncodedImage):
692
+ for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
693
+ b.kv_cache.k_cache[:, :, : k.size(2), :] = k
694
+ b.kv_cache.v_cache[:, :, : v.size(2), :] = v
695
+
696
+ def caption(
697
+ self,
698
+ image: Union[Image.Image, EncodedImage],
699
+ length: Literal["normal", "short", "long"] = "normal",
700
+ stream: bool = False,
701
+ settings: Optional[TextSamplingSettings] = None,
702
+ ):
703
+ if self.config.tokenizer.templates["caption"] is None:
704
+ raise NotImplementedError("Model does not support captioning.")
705
+ if length not in self.config.tokenizer.templates["caption"]:
706
+ raise ValueError(f"Model does not support caption length '{length}'.")
707
+
708
+ image = self.encode_image(image, settings)
709
+ self.load_encoded_image(image)
710
+
711
+ prompt_tokens = torch.tensor(
712
+ [self.config.tokenizer.templates["caption"][length]], device=self.device
713
+ )
714
+
715
+ def generator():
716
+ for token in self._generate_answer(prompt_tokens, image.pos, settings):
717
+ yield token
718
+
719
+ if stream:
720
+ return {"caption": generator()}
721
+ else:
722
+ return {"caption": "".join(list(generator()))}
723
+
724
+ def _generate_points(
725
+ self,
726
+ hidden: torch.Tensor,
727
+ next_token: torch.Tensor,
728
+ pos: int,
729
+ include_size: bool = True,
730
+ max_objects: int = DEFAULT_MAX_OBJECTS,
731
+ lora: Optional[dict] = None,
732
+ ):
733
+ out = []
734
+ mask = torch.zeros(
735
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
736
+ )
737
+ mask[:, :, :pos] = 1
738
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
739
+
740
+ with torch.inference_mode():
741
+ while (
742
+ next_token.item() != self.config.tokenizer.eos_id
743
+ and len(out) < max_objects
744
+ ):
745
+ x_logits = decode_coordinate(hidden, self.region)
746
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
747
+ next_emb = encode_coordinate(
748
+ x_center.to(dtype=x_logits.dtype), self.region
749
+ ).unsqueeze(0)
750
+
751
+ # Decode y-coordinate
752
+ mask[:, :, pos], pos_ids[0] = 1, pos
753
+ _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
754
+ pos += 1
755
+ y_logits = decode_coordinate(hidden, self.region)
756
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
757
+ next_emb = encode_coordinate(
758
+ y_center.to(dtype=y_logits.dtype), self.region
759
+ ).unsqueeze(0)
760
+
761
+ # Decode size
762
+ if include_size:
763
+ mask[:, :, pos], pos_ids[0] = 1, pos
764
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
765
+ pos += 1
766
+ size_logits = decode_size(hidden, self.region)
767
+
768
+ # Get bin indices from the logits
769
+ w_bin = torch.argmax(size_logits[0], dim=-1)
770
+ h_bin = torch.argmax(size_logits[1], dim=-1)
771
+
772
+ # Convert from bin indices to actual size values using the inverse of the log-scale mapping
773
+ # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
774
+ w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
775
+ h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
776
+
777
+ next_emb = (
778
+ encode_size(
779
+ torch.tensor(
780
+ [w, h], device=self.device, dtype=size_logits.dtype
781
+ ),
782
+ self.region,
783
+ )
784
+ .unsqueeze(0)
785
+ .unsqueeze(0)
786
+ )
787
+
788
+ # Add object
789
+ out.append(
790
+ {
791
+ "x_min": x_center.item() - w.item() / 2,
792
+ "y_min": y_center.item() - h.item() / 2,
793
+ "x_max": x_center.item() + w.item() / 2,
794
+ "y_max": y_center.item() + h.item() / 2,
795
+ }
796
+ )
797
+ else:
798
+ out.append({"x": x_center.item(), "y": y_center.item()})
799
+
800
+ # Decode next token (x-coordinate, or eos)
801
+ mask[:, :, pos], pos_ids[0] = 1, pos
802
+ logits, hidden = self._decode_one_tok(
803
+ next_emb,
804
+ mask,
805
+ pos_ids,
806
+ lora,
807
+ lm_head_indices=self.point_gen_indices,
808
+ )
809
+ pos += 1
810
+ # Map back: index 0 -> coord_id, index 1 -> eos_id
811
+ next_token_idx = torch.argmax(logits, dim=-1)
812
+ next_token = self.point_gen_indices[next_token_idx]
813
+
814
+ return out
815
+
816
+ def detect(
817
+ self,
818
+ image: Union[Image.Image, EncodedImage],
819
+ object: str,
820
+ settings: Optional[ObjectSamplingSettings] = None,
821
+ ):
822
+ if self.config.tokenizer.templates["detect"] is None:
823
+ raise NotImplementedError("Model does not support object detection.")
824
+
825
+ image = self.encode_image(image, settings)
826
+ self.load_encoded_image(image)
827
+
828
+ prompt_tokens = torch.tensor(
829
+ [
830
+ self.config.tokenizer.templates["detect"]["prefix"]
831
+ + self.tokenizer.encode(" " + object).ids
832
+ + self.config.tokenizer.templates["detect"]["suffix"]
833
+ ],
834
+ device=self.device,
835
+ )
836
+
837
+ lora = (
838
+ variant_state_dict(settings["variant"], device=self.device)
839
+ if settings is not None and "variant" in settings
840
+ else None
841
+ )
842
+
843
+ _, hidden, next_token, pos = self._prefill_prompt(
844
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
845
+ )
846
+ hidden = hidden[:, -1:, :]
847
+
848
+ max_objects = (
849
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
850
+ if settings
851
+ else DEFAULT_MAX_OBJECTS
852
+ )
853
+ objects = self._generate_points(
854
+ hidden,
855
+ next_token,
856
+ pos,
857
+ include_size=True,
858
+ max_objects=max_objects,
859
+ lora=lora,
860
+ )
861
+
862
+ return {"objects": objects}
863
+
864
+ def point(
865
+ self,
866
+ image: Union[Image.Image, EncodedImage],
867
+ object: str,
868
+ settings: Optional[ObjectSamplingSettings] = None,
869
+ ):
870
+ if self.config.tokenizer.templates["point"] is None:
871
+ raise NotImplementedError("Model does not support pointing.")
872
+
873
+ image = self.encode_image(image, settings)
874
+ self.load_encoded_image(image)
875
+
876
+ prompt_tokens = torch.tensor(
877
+ [
878
+ self.config.tokenizer.templates["point"]["prefix"]
879
+ + self.tokenizer.encode(" " + object).ids
880
+ + self.config.tokenizer.templates["point"]["suffix"]
881
+ ],
882
+ device=self.device,
883
+ )
884
+
885
+ lora = (
886
+ variant_state_dict(settings["variant"], device=self.device)
887
+ if settings is not None and "variant" in settings
888
+ else None
889
+ )
890
+
891
+ _, hidden, next_token, pos = self._prefill_prompt(
892
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
893
+ )
894
+ hidden = hidden[:, -1:, :]
895
+
896
+ max_objects = (
897
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
898
+ if settings
899
+ else DEFAULT_MAX_OBJECTS
900
+ )
901
+ objects = self._generate_points(
902
+ hidden,
903
+ next_token,
904
+ pos,
905
+ include_size=False,
906
+ max_objects=max_objects,
907
+ lora=lora,
908
+ )
909
+
910
+ return {"points": objects}
911
+
912
+ def _detect_gaze(
913
+ self,
914
+ image: EncodedImage,
915
+ source: Tuple[float, float],
916
+ force_detect: bool = False,
917
+ ):
918
+ with torch.inference_mode():
919
+ before_emb = text_encoder(
920
+ torch.tensor(
921
+ [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
922
+ ),
923
+ self.text,
924
+ )
925
+ after_emb = text_encoder(
926
+ torch.tensor(
927
+ [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
928
+ ),
929
+ self.text,
930
+ )
931
+ x_emb = encode_coordinate(
932
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
933
+ self.region,
934
+ )
935
+ y_emb = encode_coordinate(
936
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
937
+ self.region,
938
+ )
939
+
940
+ prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
941
+
942
+ self.load_encoded_image(image)
943
+
944
+ mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
945
+ pos_ids = torch.arange(
946
+ image.pos,
947
+ image.pos + prompt_emb.size(1),
948
+ dtype=torch.long,
949
+ device=self.device,
950
+ )
951
+ hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
952
+ logits = lm_head(hidden, self.text)
953
+ next_token = torch.argmax(logits, dim=-1)
954
+ pos = image.pos + prompt_emb.size(1)
955
+ hidden = hidden[:, -1:, :]
956
+
957
+ if force_detect:
958
+ next_token = torch.tensor([[0]], device=self.device)
959
+
960
+ if next_token.item() == self.config.tokenizer.eos_id:
961
+ return None
962
+
963
+ gaze = self._generate_points(
964
+ hidden, next_token, pos, include_size=False, max_objects=1
965
+ )
966
+ return gaze[0]
967
+
968
+ def detect_gaze(
969
+ self,
970
+ image: Union[Image.Image, EncodedImage],
971
+ eye: Optional[Tuple[float, float]] = None,
972
+ face: Optional[Dict[str, float]] = None,
973
+ unstable_settings: Dict[str, Any] = {},
974
+ ):
975
+ if "force_detect" in unstable_settings:
976
+ force_detect = unstable_settings["force_detect"]
977
+ else:
978
+ force_detect = False
979
+
980
+ if "prioritize_accuracy" in unstable_settings:
981
+ prioritize_accuracy = unstable_settings["prioritize_accuracy"]
982
+ else:
983
+ prioritize_accuracy = False
984
+
985
+ if not prioritize_accuracy:
986
+ if eye is None:
987
+ raise ValueError("eye must be provided when prioritize_accuracy=False")
988
+ image = self.encode_image(image)
989
+ return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
990
+ else:
991
+ if (
992
+ not isinstance(image, Image.Image)
993
+ and "flip_enc_img" not in unstable_settings
994
+ ):
995
+ raise ValueError(
996
+ "image must be a PIL Image when prioritize_accuracy=True, "
997
+ "or flip_enc_img must be provided"
998
+ )
999
+ if face is None:
1000
+ raise ValueError("face must be provided when prioritize_accuracy=True")
1001
+
1002
+ encoded_image = self.encode_image(image)
1003
+ if (
1004
+ isinstance(image, Image.Image)
1005
+ and "flip_enc_img" not in unstable_settings
1006
+ ):
1007
+ flipped_pil = image.copy()
1008
+ flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
1009
+ encoded_flipped_image = self.encode_image(flipped_pil)
1010
+ else:
1011
+ encoded_flipped_image = unstable_settings["flip_enc_img"]
1012
+
1013
+ N = 10
1014
+
1015
+ detections = [
1016
+ self._detect_gaze(
1017
+ encoded_image,
1018
+ (
1019
+ random.uniform(face["x_min"], face["x_max"]),
1020
+ random.uniform(face["y_min"], face["y_max"]),
1021
+ ),
1022
+ force_detect=force_detect,
1023
+ )
1024
+ for _ in range(N)
1025
+ ]
1026
+ detections = [
1027
+ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
1028
+ ]
1029
+ flipped_detections = [
1030
+ self._detect_gaze(
1031
+ encoded_flipped_image,
1032
+ (
1033
+ 1 - random.uniform(face["x_min"], face["x_max"]),
1034
+ random.uniform(face["y_min"], face["y_max"]),
1035
+ ),
1036
+ force_detect=force_detect,
1037
+ )
1038
+ for _ in range(N)
1039
+ ]
1040
+ detections.extend(
1041
+ [
1042
+ (1 - gaze["x"], gaze["y"])
1043
+ for gaze in flipped_detections
1044
+ if gaze is not None
1045
+ ]
1046
+ )
1047
+
1048
+ if len(detections) < N:
1049
+ return {"gaze": None}
1050
+
1051
+ detections = remove_outlier_points(detections)
1052
+ mean_gaze = (
1053
+ sum(gaze[0] for gaze in detections) / len(detections),
1054
+ sum(gaze[1] for gaze in detections) / len(detections),
1055
+ )
1056
+
1057
+ return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
1058
+
1059
+
1060
+ def _is_cjk_char(cp):
1061
+ """Checks whether CP is the codepoint of a CJK character."""
1062
+ # This defines a "chinese character" as anything in the CJK Unicode block:
1063
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
1064
+ if (
1065
+ (cp >= 0x4E00 and cp <= 0x9FFF)
1066
+ or (cp >= 0x3400 and cp <= 0x4DBF)
1067
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
1068
+ ):
1069
+ return True
1070
+ return False
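
Taken together, the methods above expose the model's skills (query, caption, point, detect, detect_gaze). A minimal driver might look like the sketch below; the constructor, config, and weight-loading entry points are defined elsewhere in the repository, so the names MoondreamModel, MoondreamConfig, and load_weights_into_model are assumptions used purely for illustration. The skill-method calls match the signatures above, and all returned coordinates are normalized to [0, 1].

    # Hypothetical driver script; class/loader names are assumptions, not confirmed by this diff.
    from PIL import Image

    from config import MoondreamConfig            # assumed name
    from moondream import MoondreamModel          # assumed name
    from weights import load_weights_into_model   # assumed name

    model = MoondreamModel(MoondreamConfig())
    load_weights_into_model("model.safetensors", model)

    image = Image.open("example.jpg")

    # Grounded visual question answering; reasoning=True also returns grounded reasoning text.
    out = model.query(image, "What is the person holding?", reasoning=True)
    print(out["reasoning"]["text"])
    print(out["answer"])

    # Open-vocabulary pointing and detection.
    points = model.point(image, "person")["points"]      # [{"x": ..., "y": ...}, ...]
    objects = model.detect(image, "bottle")["objects"]   # [{"x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...}, ...]
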
open_vocab_detect.png ADDED

Git LFS Details

  • SHA256: ea98c762d53ff5b619b51c538fda3e9dce2c895d32bcb38675591e79b55e8054
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB
point_count.png ADDED

Git LFS Details

  • SHA256: fa1d46a37291c46a006abd3579a241fe0fa5995d735883e90f5d5210a843b397
  • Pointer size: 132 Bytes
  • Size of remote file: 2 MB
region.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from typing import List, Tuple, Union
6
+
7
+ SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
8
+
9
+
10
+ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
11
+ """
12
+ Applies Fourier feature mapping to input tensor x using frequency matrix w. This
13
+ projects inputs through sinusoidal functions to create higher dimensional features
14
+ that help mitigate spectral bias - the tendency of neural networks to learn
15
+ low-frequency functions more easily than high-frequency ones. By explicitly
16
+ mapping inputs to higher frequencies through sin/cos transformations, we enable
17
+ better learning of fine details and higher frequency patterns.
18
+
19
+ Args:
20
+ x: Input tensor to transform
21
+ w: Matrix of frequencies for the Fourier features transformation
22
+
23
+ Returns:
24
+ Concatenated cosine and sine transformed features as a tensor
25
+ """
26
+ f = 2 * math.pi * x @ w
27
+ return torch.cat([f.cos(), f.sin()], dim=-1)
28
+
29
+
30
+ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
31
+ """
32
+ Takes as input a tensor containing a single float coordinate value (x or y)
33
+ and encodes it into hidden states for input to the text model.
34
+
35
+ Args:
36
+ coord: Tensor with single float coordinate value
37
+
38
+ Returns:
39
+ Encoded hidden states tensor for input to text model
40
+ """
41
+ return w.coord_encoder(fourier_features(coord, w.coord_features))
42
+
43
+
44
+ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
45
+ """
46
+ Takes as input the last hidden state from the text model and outputs a single logit
47
+ representing either an x or y coordinate prediction.
48
+
49
+ Args:
50
+ hidden_state: The final hidden state tensor from the text model.
51
+
52
+ Returns:
53
+ A single logit representing the predicted coordinate value (x or y)
54
+ """
55
+ return w.coord_decoder(hidden_state)
56
+
57
+
58
+ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
59
+ """
60
+ Takes a tensor containing width and height values and encodes them into
61
+ hidden states for input to the text model.
62
+
63
+ Args:
64
+ size: Tensor with two floats for width and height
65
+
66
+ Returns:
67
+ Encoded hidden states tensor for input to text model
68
+ """
69
+ return w.size_encoder(fourier_features(size, w.size_features))
70
+
71
+
72
+ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
73
+ """
74
+ Takes as input the last hidden state from the text model and outputs logits
75
+ for 1024 bins representing width and height in log-scale.
76
+
77
+ The bins are distributed according to the formula:
78
+ bin = (log2(size) + 10.0) / 10.0 * 1023.0
79
+ where size values are clamped to be at least 1/1024.
80
+
81
+ To convert from bin back to size:
82
+ size = 2^((bin / 1023.0) * 10.0 - 10.0)
83
+
84
+ Args:
85
+ hidden_state: The final hidden state tensor from the text model.
86
+
87
+ Returns:
88
+ A tensor containing logits for 1024 bins for width and height.
89
+ Shape is (2, 1024) where the first dimension corresponds to width and height.
90
+ """
91
+ return w.size_decoder(hidden_state).view(2, -1)
92
+
93
+
94
+ def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> dict:
95
+ """
96
+ Takes a list of spatial references (points or regions) and encodes them into
97
+ hidden states for input to the text model.
98
+
99
+ Args:
100
+ spatial_refs: List of spatial references (points or boxes)
101
+ - Points are represented as normalized (x, y) tuples
102
+ - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
103
+
104
+ Returns:
105
+ {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
106
+ """
107
+ coords, sizes = [], []
108
+ for ref in spatial_refs:
109
+ if len(ref) == 2:
110
+ coords.append(ref[0])
111
+ coords.append(ref[1])
112
+ else:
113
+ x_c = (ref[0] + ref[2]) / 2
114
+ y_c = (ref[1] + ref[3]) / 2
115
+ width = ref[2] - ref[0]
116
+ height = ref[3] - ref[1]
117
+ coords.append(x_c)
118
+ coords.append(y_c)
119
+ sizes.append([width, height])
120
+
121
+ coords = torch.tensor(
122
+ coords, device=w.coord_features.device, dtype=w.coord_features.dtype
123
+ ).view(-1, 1)
124
+ coords = encode_coordinate(coords, w)
125
+
126
+ if sizes:
127
+ sizes = torch.tensor(
128
+ sizes, device=w.size_features.device, dtype=w.size_features.dtype
129
+ )
130
+ sizes = encode_size(sizes, w)
131
+ else:
132
+ sizes = None
133
+
134
+ return {"coords": coords, "sizes": sizes}
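
The size head above predicts width and height over 1024 log-scale bins; decode_size's docstring gives the bin mapping, and _generate_points in moondream.py applies the inverse. A small worked example of that round trip (plain Python, with illustrative helper names that are not part of the repository):

    import math

    # bin = (log2(size) + 10.0) / 10.0 * 1023.0, with size clamped to at least 1/1024
    def size_to_bin(size: float) -> int:
        size = max(size, 1.0 / 1024.0)
        return round((math.log2(size) + 10.0) / 10.0 * 1023.0)

    # size = 2 ** ((bin / 1023.0) * 10.0 - 10.0), the inverse used when decoding boxes
    def bin_to_size(b: int) -> float:
        return 2.0 ** ((b / 1023.0) * 10.0 - 10.0)

    assert size_to_bin(1.0) == 1023                           # full-image extent lands in the last bin
    assert abs(bin_to_size(size_to_bin(0.25)) - 0.25) < 1e-3  # quantization error stays small
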
rope.py ADDED
@@ -0,0 +1,47 @@
1
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
2
+
3
+ import torch
4
+
5
+
6
+ def precompute_freqs_cis(
7
+ dim: int,
8
+ end: int,
9
+ theta: float = 1500000.0,
10
+ dtype: torch.dtype = torch.float32,
11
+ ) -> torch.Tensor:
12
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
13
+ t = torch.arange(end, dtype=dtype).unsqueeze(1)
14
+ freqs = t * freqs.unsqueeze(0)
15
+ freqs = torch.exp(1j * freqs)
16
+ return torch.stack([freqs.real, freqs.imag], dim=-1)
17
+
18
+
19
+ def apply_rotary_emb(
20
+ x: torch.Tensor,
21
+ freqs_cis: torch.Tensor,
22
+ position_ids: torch.Tensor,
23
+ num_heads: int,
24
+ rot_dim: int = 32,
25
+ interleave: bool = False,
26
+ ) -> torch.Tensor:
27
+ assert rot_dim == freqs_cis.shape[-2] * 2
28
+ assert num_heads == x.shape[1]
29
+
30
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
31
+
32
+ if interleave:
33
+ xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
34
+ xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
35
+ else:
36
+ d_q = x_rot.shape[-1] // 2
37
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
38
+
39
+ freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
40
+ freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
41
+
42
+ # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
43
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
44
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
45
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
46
+
47
+ return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
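
A quick shape check of the two helpers above, written as an illustrative snippet (it assumes rope.py is importable as a plain module; inside the package it is imported relatively from text.py):

    import torch
    from rope import precompute_freqs_cis, apply_rotary_emb   # assumed import path

    rot_dim, max_context, n_heads, head_dim = 32, 128, 4, 64
    freqs_cis = precompute_freqs_cis(rot_dim, max_context)
    assert freqs_cis.shape == (max_context, rot_dim // 2, 2)   # cos/sin stacked in the last dim

    x = torch.randn(1, n_heads, 16, head_dim)                  # (batch, heads, seq, head_dim)
    position_ids = torch.arange(16)
    out = apply_rotary_emb(x, freqs_cis, position_ids, n_heads, rot_dim=rot_dim)
    assert out.shape == x.shape                                # only the first rot_dim channels are rotated
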
structured_outputs.png ADDED

Git LFS Details

  • SHA256: 655bc4f2897175ca479b40308c3a0ae8e368b943b92bfbe73075821906a83b0f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
text.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.nn import functional as F
5
+ from torch.nn.attention.flex_attention import flex_attention
6
+ from typing import Optional
7
+
8
+ from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
9
+ from .rope import apply_rotary_emb, precompute_freqs_cis
10
+ from .config import TextConfig
11
+
12
+
13
+ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
14
+ return F.embedding(input_ids, w.wte)
15
+
16
+
17
+ def attn(
18
+ x: torch.Tensor,
19
+ w: nn.Module,
20
+ freqs_cis: torch.Tensor,
21
+ kv_cache: nn.Module,
22
+ attn_mask: torch.Tensor,
23
+ n_heads: int,
24
+ n_kv_heads: int,
25
+ position_ids: torch.Tensor,
26
+ lora: Optional[dict] = None,
27
+ flex_block_mask_slice=None,
28
+ ):
29
+ bsz, q_len, d_model = x.shape
30
+ head_dim = d_model // n_heads
31
+
32
+ qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
33
+ if lora is not None:
34
+ qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
35
+ q_dim = n_heads * head_dim
36
+ kv_dim = n_kv_heads * head_dim
37
+ q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
38
+
39
+ q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
40
+ k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
41
+ v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
42
+
43
+ if hasattr(w, "tau") and w.tau is not None:
44
+ tok_feat = F.gelu(qkv_out)
45
+ tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1)
46
+ tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1)
47
+ pos = position_ids.to(q.dtype) + 1
48
+ tau_pos = 1 + (
49
+ torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5
50
+ ) # (H,S)
51
+ tau_q = (tok_q + tau_pos[None]).unsqueeze(-1) # (B,H,S,1)
52
+ tau_v = (tok_v + tau_pos[None]).unsqueeze(-1)
53
+ q = q * tau_q
54
+ v = v * tau_v
55
+
56
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
57
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
58
+
59
+ if kv_cache is not None:
60
+ k, v = kv_cache.update(position_ids, k, v)
61
+
62
+ if flex_block_mask_slice is not None:
63
+ torch._assert(n_heads == n_kv_heads, "gqa not supported yet")
64
+ out = flex_attention(q, k, v, block_mask=flex_block_mask_slice)
65
+ else:
66
+ out = F.scaled_dot_product_attention(
67
+ q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
68
+ )
69
+
70
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
71
+
72
+ out0 = w.proj(out)
73
+ if lora is not None:
74
+ out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
75
+ out = out0 + out1
76
+ else:
77
+ out = out0
78
+
79
+ return out
80
+
81
+
82
+ def text_decoder(
83
+ x: torch.Tensor,
84
+ w: nn.Module,
85
+ attn_mask: torch.Tensor,
86
+ position_ids: torch.Tensor,
87
+ config: TextConfig,
88
+ lora: Optional[dict] = None,
89
+ flex_block_mask_slice=None,
90
+ ):
91
+ for i, block in enumerate(w.blocks):
92
+ if lora is not None:
93
+ layer_lora = lora["text"]["blocks"][str(i)]
94
+ mlp_lora = layer_lora["mlp"]
95
+ attn_lora = layer_lora["attn"]
96
+ else:
97
+ mlp_lora = None
98
+ attn_lora = None
99
+
100
+ l_in = layer_norm(x, block.ln)
101
+ l_attn = attn(
102
+ l_in,
103
+ block.attn,
104
+ freqs_cis=w.freqs_cis,
105
+ kv_cache=block.kv_cache,
106
+ attn_mask=attn_mask,
107
+ n_heads=config.n_heads,
108
+ n_kv_heads=config.n_kv_heads,
109
+ position_ids=position_ids,
110
+ lora=attn_lora,
111
+ flex_block_mask_slice=flex_block_mask_slice,
112
+ )
113
+
114
+ if config.moe is not None and i >= config.moe.start_layer:
115
+ l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)
116
+ else:
117
+ l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
118
+
119
+ x = x + l_attn + l_mlp
120
+
121
+ return x
122
+
123
+
124
+ def lm_head(
125
+ hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None
126
+ ):
127
+ hidden_BC = hidden_BTC[:, -1, :]
128
+ hidden_BC = layer_norm(hidden_BC, w.post_ln)
129
+ if indices is not None:
130
+ # Only compute logits for specified token indices
131
+ logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices]
132
+ else:
133
+ logits = w.lm_head(hidden_BC)
134
+ return logits
135
+
136
+
137
+ def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
138
+ return nn.ModuleDict(
139
+ {
140
+ "fc1": linear_cls(d_model, d_ffn, dtype=dtype),
141
+ "fc2": linear_cls(d_ffn, d_model, dtype=dtype),
142
+ }
143
+ )
144
+
145
+
146
+ def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
147
+ # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
148
+ return nn.ModuleDict(
149
+ {
150
+ "router": nn.Linear(d_model, n_experts, dtype=dtype),
151
+ "fc1": nn.ParameterDict(
152
+ {
153
+ "weight": nn.Parameter(
154
+ torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype)
155
+ )
156
+ }
157
+ ),
158
+ "fc2": nn.ParameterDict(
159
+ {
160
+ "weight": nn.Parameter(
161
+ torch.empty(n_experts, d_model, d_ffn, dtype=dtype)
162
+ )
163
+ }
164
+ ),
165
+ }
166
+ )
167
+
168
+
169
+ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
170
+ qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
171
+ linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
172
+
173
+ text = nn.ModuleDict(
174
+ {
175
+ "blocks": nn.ModuleList(
176
+ [
177
+ nn.ModuleDict(
178
+ {
179
+ "ln": nn.LayerNorm(config.dim, dtype=dtype),
180
+ "attn": nn.ModuleDict(
181
+ {
182
+ "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
183
+ "proj": linear_cls(
184
+ config.dim, config.dim, dtype=dtype
185
+ ),
186
+ "tau": nn.ParameterDict(
187
+ {
188
+ "wq": nn.Parameter(
189
+ torch.empty(
190
+ config.n_heads, qkv_dim, dtype=dtype
191
+ )
192
+ ),
193
+ "wv": nn.Parameter(
194
+ torch.empty(
195
+ config.n_heads, qkv_dim, dtype=dtype
196
+ )
197
+ ),
198
+ "alpha": nn.Parameter(
199
+ torch.empty(config.n_heads, dtype=dtype)
200
+ ),
201
+ }
202
+ ),
203
+ }
204
+ ),
205
+ "mlp": (
206
+ build_moe_mlp(
207
+ config.dim,
208
+ config.moe.expert_inner_dim,
209
+ config.moe.num_experts,
210
+ dtype,
211
+ )
212
+ if config.moe is not None
213
+ and layer_idx >= config.moe.start_layer
214
+ else build_dense_mlp(
215
+ config.dim, config.ff_dim, dtype, linear_cls
216
+ )
217
+ ),
218
+ }
219
+ )
220
+ for layer_idx in range(config.n_layers)
221
+ ]
222
+ ),
223
+ "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
224
+ "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
225
+ }
226
+ )
227
+ text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
228
+ text.register_buffer(
229
+ "freqs_cis",
230
+ precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
231
+ persistent=False,
232
+ )
233
+
234
+ return text
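
The decode loops in moondream.py above sample the next token from temperature-scaled softmax probabilities over these logits, after a nucleus filter (self._apply_top_p). As a reference, a minimal top-p filter in the conventional formulation looks roughly like the sketch below; this is an illustration, not necessarily the repository's exact implementation.

    import torch

    def apply_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
        # probs: (1, vocab) softmax probabilities. Keep the smallest prefix of the
        # sorted distribution whose cumulative mass reaches top_p, then renormalize.
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        drop = cumulative - sorted_probs > top_p               # mass covered before this token
        sorted_probs = sorted_probs.masked_fill(drop, 0.0)
        filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
        return filtered / filtered.sum(dim=-1, keepdim=True)

    probs = torch.softmax(torch.randn(1, 8), dim=-1)
    next_token = torch.multinomial(apply_top_p(probs, top_p=0.9), num_samples=1)
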
utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+
4
+ def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
5
+ """
6
+ Robust outlier detection for list of (x,y) tuples.
7
+ Only requires numpy.
8
+
9
+ Args:
10
+ points_tuples: list of (x,y) tuples
11
+ k_nearest: number of neighbors to consider
12
+ threshold: multiplier for median distance
13
+
14
+ Returns:
15
+ list: filtered list of (x,y) tuples with outliers removed
17
+ """
18
+ points = np.array(points_tuples)
19
+ n_points = len(points)
20
+
21
+ # Calculate pairwise distances manually
22
+ dist_matrix = np.zeros((n_points, n_points))
23
+ for i in range(n_points):
24
+ for j in range(i + 1, n_points):
25
+ # Euclidean distance between points i and j
26
+ dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
27
+ dist_matrix[i, j] = dist
28
+ dist_matrix[j, i] = dist
29
+
30
+ # Get k nearest neighbors' distances
31
+ k = min(k_nearest, n_points - 1)
32
+ neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
33
+ avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
34
+
35
+ # Calculate mask using median distance
36
+ median_dist = np.median(avg_neighbor_dist)
37
+ mask = avg_neighbor_dist <= threshold * median_dist
38
+
39
+ # Return filtered tuples and mask
40
+ filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
41
+ return filtered_tuples
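
An illustrative use of remove_outlier_points on the kind of data detect_gaze produces: a tight cluster of gaze estimates plus one stray detection. The stray point's average distance to its nearest neighbours far exceeds the median, so it is dropped.

    from utils import remove_outlier_points   # assuming utils.py is importable as a plain module

    points = [(0.50 + 0.01 * i, 0.40) for i in range(9)] + [(0.95, 0.95)]
    kept = remove_outlier_points(points, k_nearest=2, threshold=2.0)
    assert (0.95, 0.95) not in kept and len(kept) == 9
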
vision.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from typing import Union, Tuple
7
+ from PIL import Image
8
+
9
+ from .layers import attn, layer_norm, mlp
10
+ from .image_crops import overlap_crop_image
11
+ from .config import VisionConfig
12
+
13
+ if torch.backends.mps.is_available():
14
+ # Non-divisible input sizes are not implemented on MPS device yet.
15
+ # https://github.com/pytorch/pytorch/issues/96056
16
+ def adaptive_avg_pool2d(input, output_size):
17
+ return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
18
+
19
+ else:
20
+ adaptive_avg_pool2d = F.adaptive_avg_pool2d
21
+
22
+ DeviceLike = Union[str, torch.device, int]
23
+
24
+
25
+ def prepare_crops(
26
+ image: Image.Image, config: VisionConfig, device: DeviceLike
27
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
28
+ np_image = np.array(image.convert("RGB"))
29
+ overlap_crops = overlap_crop_image(
30
+ np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
31
+ )
32
+ all_crops = overlap_crops["crops"]
33
+ all_crops = np.transpose(all_crops, (0, 3, 1, 2))
34
+ all_crops = (
35
+ torch.from_numpy(all_crops)
36
+ .to(device=device, dtype=torch.bfloat16)
37
+ .div_(255.0)
38
+ .sub_(0.5)
39
+ .div_(0.5)
40
+ )
41
+ return all_crops, overlap_crops["tiling"]
42
+
43
+
44
+ def create_patches(x, patch_size):
45
+ # Original shape: [B, C, H, W]
46
+ B, C, H, W = x.shape
47
+ P1 = P2 = patch_size
48
+
49
+ # Step 1: Split H and W dimensions into patches
50
+ # [B, C, H/P1, P1, W/P2, P2]
51
+ x = x.reshape(B, C, H // P1, P1, W // P2, P2)
52
+
53
+ # Step 2: Rearrange dimensions to match target shape
54
+ # [B, H/P1, W/P2, C, P1, P2]
55
+ x = x.permute(0, 2, 4, 1, 3, 5)
56
+
57
+ # Step 3: Combine dimensions to get final shape
58
+ # [B, (H/P1)*(W/P2), C*P1*P2]
59
+ x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
60
+
61
+ return x
62
+
63
+
64
+ def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
65
+ x = create_patches(input_BCHW, config.enc_patch_size)
66
+
67
+ x = w.patch_emb(x)
68
+ x = x + w.pos_emb
69
+ for block in w.blocks:
70
+ x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
71
+ x = x + mlp(layer_norm(x, block.ln2), block.mlp)
72
+ x = layer_norm(x, w.post_ln)
73
+
74
+ return x
75
+
76
+
77
+ def vision_projection(
78
+ global_features: torch.Tensor,
79
+ reconstructed: torch.Tensor,
80
+ w: nn.Module,
81
+ config: VisionConfig,
82
+ ):
83
+ reconstructed = reconstructed.permute(2, 0, 1)
84
+ reconstructed = adaptive_avg_pool2d(
85
+ reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
86
+ )
87
+ reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
88
+ final_features = torch.cat([global_features, reconstructed], dim=-1)
89
+ return mlp(final_features, w.proj_mlp)
90
+
91
+
92
+ def build_vision_model(config: VisionConfig, dtype: torch.dtype):
93
+ patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
94
+ grid_size = config.crop_size // config.enc_patch_size
95
+ num_patches = grid_size * grid_size
96
+
97
+ vision = nn.ModuleDict(
98
+ {
99
+ "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
100
+ "blocks": nn.ModuleList(
101
+ [
102
+ nn.ModuleDict(
103
+ {
104
+ "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
105
+ "attn": nn.ModuleDict(
106
+ {
107
+ "qkv": nn.Linear(
108
+ config.enc_dim, 3 * config.enc_dim, dtype=dtype
109
+ ),
110
+ "proj": nn.Linear(
111
+ config.enc_dim, config.enc_dim, dtype=dtype
112
+ ),
113
+ }
114
+ ),
115
+ "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
116
+ "mlp": nn.ModuleDict(
117
+ {
118
+ "fc1": nn.Linear(
119
+ config.enc_dim, config.enc_ff_dim, dtype=dtype
120
+ ),
121
+ "fc2": nn.Linear(
122
+ config.enc_ff_dim, config.enc_dim, dtype=dtype
123
+ ),
124
+ }
125
+ ),
126
+ }
127
+ )
128
+ for _ in range(config.enc_n_layers)
129
+ ]
130
+ ),
131
+ "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
132
+ "proj_mlp": nn.ModuleDict(
133
+ {
134
+ "fc1": nn.Linear(
135
+ config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
136
+ ),
137
+ "fc2": nn.Linear(
138
+ config.proj_inner_dim, config.proj_out_dim, dtype=dtype
139
+ ),
140
+ }
141
+ ),
142
+ }
143
+ )
144
+ vision.pos_emb = nn.Parameter(
145
+ torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
146
+ )
147
+ return vision
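
A shape sketch for create_patches, using an illustrative 378-pixel crop with 14-pixel patches (the actual values come from VisionConfig in the repo's config module and may differ). A crop of that size yields 27 * 27 = 729 patch tokens, matching the 729 rows that vision_projection reshapes the pooled features into.

    import torch
    from vision import create_patches   # assuming vision.py is importable as a plain module

    crop_size, patch_size = 378, 14      # illustrative assumptions
    x = torch.randn(2, 3, crop_size, crop_size)   # (B, C, H, W)
    patches = create_patches(x, patch_size)
    assert patches.shape == (2, (crop_size // patch_size) ** 2, 3 * patch_size * patch_size)
    assert patches.shape[1] == 729
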
visual_reasoning.png ADDED

Git LFS Details

  • SHA256: 43f9b5eb351a66dd8a05957813beb5e18367f6b3ad232b2bde8545d90132759d
  • Pointer size: 131 Bytes
  • Size of remote file: 434 kB