vikhyatk committed
Commit 20e6e81 · verified · 1 Parent(s): 40c276b

Upload HfMoondream

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "architectures": [
3
+ "HfMoondream"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "hf_moondream.HfConfig",
7
+ "AutoModelForCausalLM": "hf_moondream.HfMoondream"
8
+ },
9
+ "config": {},
10
+ "model_type": "moondream1",
11
+ "torch_dtype": "bfloat16",
12
+ "transformers_version": "4.51.1"
13
+ }
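The README's "How to Get Started with the Model" section above is still a placeholder. A minimal loading sketch can be grounded in what this commit adds: the `auto_map` in `config.json` routes `AutoConfig`/`AutoModelForCausalLM` to `hf_moondream.HfConfig`/`HfMoondream` (so `trust_remote_code=True` is needed), and `hf_moondream.py` exposes `query`, `caption`, `detect`, and `point`. The repo id and image path below are placeholders.

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

repo_id = "<this-repo-id>"  # placeholder: the Hub id this commit was pushed to

# trust_remote_code=True lets the auto_map in config.json resolve
# hf_moondream.HfMoondream as the model class.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

image = Image.open("example.jpg")  # placeholder image

# query() returns a dict with an "answer" key (see answer_question in hf_moondream.py).
print(model.query(image, "What is in this image?")["answer"])
```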
config.py ADDED
@@ -0,0 +1,102 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class TextMoeConfig:
7
+ num_experts: int = 64
8
+ start_layer: int = 4
9
+ experts_per_token: int = 8
10
+ expert_inner_dim: int = 1024
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class TextConfig:
15
+ dim: int = 2048
16
+ ff_dim: int = 8192
17
+ n_layers: int = 24
18
+ vocab_size: int = 51200
19
+ max_context: int = 4096
20
+ n_heads: int = 32
21
+ n_kv_heads: int = 32
22
+ prefix_attn: int = 730
23
+ group_size: Optional[int] = None
24
+ moe: Optional[TextMoeConfig] = TextMoeConfig()
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class VisionConfig:
29
+ enc_dim: int = 1152
30
+ enc_patch_size: int = 14
31
+ enc_n_layers: int = 27
32
+ enc_ff_dim: int = 4304
33
+ enc_n_heads: int = 16
34
+ proj_out_dim: int = 2048
35
+ crop_size: int = 378
36
+ in_channels: int = 3
37
+ max_crops: int = 12
38
+ overlap_margin: int = 4
39
+ proj_inner_dim: int = 8192
40
+
41
+
42
+ @dataclass(frozen=True)
43
+ class RegionConfig:
44
+ dim: int = 2048
45
+ coord_feat_dim: int = 256
46
+ coord_out_dim: int = 1024
47
+ size_feat_dim: int = 512
48
+ size_out_dim: int = 2048
49
+ group_size: Optional[int] = None
50
+
51
+
52
+ @dataclass(frozen=True)
53
+ class TokenizerConfig:
54
+ bos_id: int = 0
55
+ eos_id: int = 0
56
+ answer_id: int = 3
57
+ thinking_id: int = 4
58
+ coord_id: int = 5
59
+ size_id: int = 6
60
+ start_ground_points_id: int = 7
61
+ end_ground_id: int = 9
62
+ templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
63
+ default_factory=lambda: {
64
+ "caption": {
65
+ "short": [1, 32708, 2, 12492, 3],
66
+ "normal": [1, 32708, 2, 6382, 3],
67
+ "long": [1, 32708, 2, 4059, 3],
68
+ },
69
+ "query": {"prefix": [1, 15381, 2], "suffix": [3]},
70
+ "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
71
+ "point": {"prefix": [1, 2581, 2], "suffix": [3]},
72
+ }
73
+ )
74
+
75
+
76
+ @dataclass(frozen=True)
77
+ class MoondreamConfig:
78
+ text: TextConfig = TextConfig()
79
+ vision: VisionConfig = VisionConfig()
80
+ region: RegionConfig = RegionConfig()
81
+ tokenizer: TokenizerConfig = TokenizerConfig()
82
+
83
+ @classmethod
84
+ def from_dict(cls, config_dict: dict):
85
+ text_config = TextConfig(**config_dict.get("text", {}))
86
+ vision_config = VisionConfig(**config_dict.get("vision", {}))
87
+ region_config = RegionConfig(**config_dict.get("region", {}))
88
+ tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
89
+ return cls(
90
+ text=text_config,
91
+ vision=vision_config,
92
+ region=region_config,
93
+ tokenizer=tokenizer_config,
94
+ )
95
+
96
+ def to_dict(self):
97
+ return {
98
+ "text": self.text.__dict__,
99
+ "vision": self.vision.__dict__,
100
+ "region": self.region.__dict__,
101
+ "tokenizer": self.tokenizer.__dict__,
102
+ }
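A quick sketch of how this config round-trips, assuming `config.py` above is importable as `config`; `HfMoondream.__init__` in the next file feeds the `config` dict from `config.json` into `MoondreamConfig.from_dict` in exactly this way. The override values below are illustrative.

```python
from config import MoondreamConfig, TextConfig  # assumed import path

# An empty dict yields the stock defaults (24-layer, 2048-dim text tower with MoE).
cfg = MoondreamConfig.from_dict({})
assert cfg.text == TextConfig()

# Per-section overrides; unknown keys raise TypeError, so typos surface early.
cfg = MoondreamConfig.from_dict({"text": {"n_layers": 12, "max_context": 2048}})
print(cfg.text.n_layers, cfg.text.max_context)  # 12 2048

# to_dict() returns each section's __dict__; note that text["moe"] stays a
# TextMoeConfig instance rather than being flattened to a plain dict.
print(cfg.to_dict()["text"]["n_layers"])  # 12
```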
hf_moondream.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import PreTrainedModel, PretrainedConfig
5
+ from typing import Union
6
+
7
+ from .config import MoondreamConfig
8
+ from .moondream import MoondreamModel
9
+
10
+ # Files sometimes don't get loaded without these...
11
+ from .image_crops import *
12
+ from .vision import *
13
+ from .text import *
14
+ from .region import *
15
+ from .utils import *
16
+
17
+
18
+ def extract_question(text):
19
+ prefix = "<image>\n\nQuestion: "
20
+ suffix = "\n\nAnswer:"
21
+
22
+ if text.startswith(prefix) and text.endswith(suffix):
23
+ return text[len(prefix) : -len(suffix)]
24
+ else:
25
+ return None
26
+
27
+
28
+ class HfConfig(PretrainedConfig):
29
+ _auto_class = "AutoConfig"
30
+ model_type = "moondream1"
31
+
32
+ def __init__(self, **kwargs):
33
+ super().__init__(**kwargs)
34
+ self.config = {}
35
+
36
+
37
+ class HfMoondream(PreTrainedModel):
38
+ _auto_class = "AutoModelForCausalLM"
39
+ config_class = HfConfig
40
+
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.model = MoondreamModel(
44
+ MoondreamConfig.from_dict(config.config), setup_caches=False
45
+ )
46
+ self._is_kv_cache_setup = False
47
+
48
+ def _setup_caches(self):
49
+ if not self._is_kv_cache_setup:
50
+ self.model._setup_caches()
51
+ self._is_kv_cache_setup = True
52
+
53
+ @property
54
+ def encode_image(self):
55
+ self._setup_caches()
56
+ return self.model.encode_image
57
+
58
+ @property
59
+ def query(self):
60
+ self._setup_caches()
61
+ return self.model.query
62
+
63
+ @property
64
+ def caption(self):
65
+ self._setup_caches()
66
+ return self.model.caption
67
+
68
+ @property
69
+ def detect(self):
70
+ self._setup_caches()
71
+ return self.model.detect
72
+
73
+ @property
74
+ def point(self):
75
+ self._setup_caches()
76
+ return self.model.point
77
+
78
+ @property
79
+ def detect_gaze(self):
80
+ self._setup_caches()
81
+ return self.model.detect_gaze
82
+
83
+ def answer_question(
84
+ self,
85
+ image_embeds,
86
+ question,
87
+ tokenizer=None,
88
+ chat_history="",
89
+ result_queue=None,
90
+ max_new_tokens=256,
91
+ **kwargs
92
+ ):
93
+ answer = self.query(image_embeds, question)["answer"].strip()
94
+
95
+ if result_queue is not None:
96
+ result_queue.put(answer)
97
+ return answer
98
+
99
+ def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
100
+ answers = []
101
+ for image, prompt in zip(images, prompts):
102
+ answers.append(self.query(image, prompt)["answer"].strip())
103
+ return answers
104
+
105
+ def _unsupported_exception(self):
106
+ raise NotImplementedError(
107
+ "This method is not supported in the latest version of moondream. "
108
+ "Consider upgrading to the updated API spec, or alternately pin "
109
+ "Consider upgrading to the updated API spec, or alternatively pin "
110
+ )
111
+
112
+ def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
113
+ """
114
+ Function definition remains unchanged for backwards compatibility.
115
+ Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
116
+ """
117
+ prompt_extracted = extract_question(prompt)
118
+ if prompt_extracted is not None:
119
+ answer = self.model.query(
120
+ image=image_embeds, question=prompt_extracted, stream=False
121
+ )["answer"]
122
+ else:
123
+ image_embeds = self.encode_image(image_embeds)
124
+ prompt_tokens = torch.tensor(
125
+ [self.model.tokenizer.encode(prompt).ids],
126
+ device=self.device,
127
+ )
128
+
129
+ def generator():
130
+ for token in self.model._generate_answer(
131
+ prompt_tokens,
132
+ image_embeds.kv_cache,
133
+ image_embeds.pos,
134
+ max_new_tokens,
135
+ ):
136
+ yield token
137
+
138
+ answer = "".join(list(generator()))
139
+
140
+ return [answer]
141
+
142
+ def get_input_embeddings(self) -> nn.Embedding:
143
+ """
144
+ Lazily wrap the raw parameter `self.model.text.wte` in a real
145
+ `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
146
+ **shares** the weight tensor—no copy is made.
147
+ """
148
+ if not hasattr(self, "_input_embeddings"):
149
+ self._input_embeddings = nn.Embedding.from_pretrained(
150
+ self.model.text.wte, # tensor created in text.py
151
+ freeze=True, # set to False if you need it trainable
152
+ )
153
+ return self._input_embeddings
154
+
155
+ def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
156
+ """
157
+ Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
158
+ embeddings and keeps everything tied to `self.model.text.wte`.
159
+ """
160
+ # 1. point the low-level parameter to the new weight matrix
161
+ self.model.text.wte = value.weight
162
+ # 2. keep a reference for get_input_embeddings()
163
+ self._input_embeddings = value
164
+
165
+ def input_embeds(
166
+ self,
167
+ input_ids: Union[torch.LongTensor, list, tuple],
168
+ *,
169
+ device: torch.device | None = None
170
+ ) -> torch.FloatTensor:
171
+ """
172
+ Back-compat wrapper that turns token IDs into embeddings.
173
+
174
+ Example:
175
+ ids = torch.tensor([[1, 2, 3]])
176
+ embeds = model.input_embeds(ids) # (1, 3, hidden_dim)
177
+ """
178
+ if not torch.is_tensor(input_ids):
179
+ input_ids = torch.as_tensor(input_ids)
180
+ if device is not None:
181
+ input_ids = input_ids.to(device)
182
+
183
+ return self.get_input_embeddings()(input_ids)
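A sketch of the two paths through the backwards-compatible `generate()` above, assuming `model` is an already-loaded `HfMoondream` and `example.jpg` is a placeholder image. Prompts in the legacy `<image>\n\nQuestion: ...\n\nAnswer:` format are unwrapped by `extract_question` and routed through `query`; anything else is tokenized and decoded directly.

```python
from PIL import Image

image = Image.open("example.jpg")  # placeholder

# Legacy prompt format: extract_question() strips the wrapper and the call
# goes through model.query() under the hood.
legacy_prompt = "<image>\n\nQuestion: What color is the car?\n\nAnswer:"
print(model.generate(image, legacy_prompt, tokenizer=None)[0])

# Any other prompt falls through to raw token generation against the
# encoded image's KV cache.
print(model.generate(image, "Describe the scene.", tokenizer=None, max_new_tokens=64)[0])
```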
image_crops.py ADDED
@@ -0,0 +1,231 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+ from typing import TypedDict
6
+
7
+ try:
8
+ import pyvips
9
+
10
+ HAS_VIPS = True
11
+ except Exception:
12
+ from PIL import Image
13
+
14
+ HAS_VIPS = False
15
+
16
+
17
+ def select_tiling(
18
+ height: int, width: int, crop_size: int, max_crops: int
19
+ ) -> tuple[int, int]:
20
+ """
21
+ Determine the optimal number of tiles to cover an image with overlapping crops.
22
+ """
23
+ if height <= crop_size or width <= crop_size:
24
+ return (1, 1)
25
+
26
+ # Minimum required tiles in each dimension
27
+ min_h = math.ceil(height / crop_size)
28
+ min_w = math.ceil(width / crop_size)
29
+
30
+ # If minimum required tiles exceed max_crops, return proportional distribution
31
+ if min_h * min_w > max_crops:
32
+ ratio = math.sqrt(max_crops / (min_h * min_w))
33
+ return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
34
+
35
+ # Perfect aspect-ratio tiles that satisfy max_crops
36
+ h_tiles = math.floor(math.sqrt(max_crops * height / width))
37
+ w_tiles = math.floor(math.sqrt(max_crops * width / height))
38
+
39
+ # Ensure we meet minimum tile requirements
40
+ h_tiles = max(h_tiles, min_h)
41
+ w_tiles = max(w_tiles, min_w)
42
+
43
+ # If we exceeded max_crops, scale down the larger dimension
44
+ if h_tiles * w_tiles > max_crops:
45
+ if w_tiles > h_tiles:
46
+ w_tiles = math.floor(max_crops / h_tiles)
47
+ else:
48
+ h_tiles = math.floor(max_crops / w_tiles)
49
+
50
+ return (max(1, h_tiles), max(1, w_tiles))
51
+
52
+
53
+ class OverlapCropOutput(TypedDict):
54
+ crops: np.ndarray
55
+ tiling: tuple[int, int]
56
+
57
+
58
+ def overlap_crop_image(
59
+ image: np.ndarray,
60
+ overlap_margin: int,
61
+ max_crops: int,
62
+ base_size: tuple[int, int] = (378, 378),
63
+ patch_size: int = 14,
64
+ ) -> OverlapCropOutput:
65
+ """
66
+ Process an image using an overlap-and-resize cropping strategy with margin handling.
67
+
68
+ This function takes an input image and creates multiple overlapping crops with
69
+ consistent margins. It produces:
70
+ 1. A single global crop resized to base_size
71
+ 2. Multiple overlapping local crops that maintain high resolution details
72
+ 3. A patch ordering matrix that tracks correspondence between crops
73
+
74
+ The overlap strategy ensures:
75
+ - Smooth transitions between adjacent crops
76
+ - No loss of information at crop boundaries
77
+ - Proper handling of features that cross crop boundaries
78
+ - Consistent patch indexing across the full image
79
+
80
+ Args:
81
+ image (np.ndarray): Input image as numpy array with shape (H,W,C)
82
+ base_size (tuple[int,int]): Target size for crops, default (378,378)
83
+ patch_size (int): Size of patches in pixels, default 14
84
+ overlap_margin (int): Margin size in patch units, default 4
85
+ max_crops (int): Maximum number of crops allowed, default 12
86
+
87
+ Returns:
88
+ OverlapCropOutput: Dictionary containing:
89
+ - crops: A numpy array containing the global crop of the full image (index 0)
90
+ followed by the overlapping cropped regions (indices 1+)
91
+ - tiling: Tuple of (height,width) tile counts
92
+ """
93
+ original_h, original_w = image.shape[:2]
94
+
95
+ # Convert margin from patch units to pixels
96
+ margin_pixels = patch_size * overlap_margin
97
+ total_margin_pixels = margin_pixels * 2 # Both sides
98
+
99
+ # Calculate crop parameters
100
+ crop_patches = base_size[0] // patch_size # patches per crop dimension
101
+ crop_window_patches = crop_patches - (2 * overlap_margin) # usable patches
102
+ crop_window_size = crop_window_patches * patch_size # usable size in pixels
103
+
104
+ # Determine tiling
105
+ tiling = select_tiling(
106
+ original_h - total_margin_pixels,
107
+ original_w - total_margin_pixels,
108
+ crop_window_size,
109
+ max_crops,
110
+ )
111
+
112
+ # Pre-allocate crops.
113
+ n_crops = tiling[0] * tiling[1] + 1 # 1 = global crop
114
+ crops = np.zeros(
115
+ (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
116
+ )
117
+
118
+ # Resize image to fit tiling
119
+ target_size = (
120
+ tiling[0] * crop_window_size + total_margin_pixels,
121
+ tiling[1] * crop_window_size + total_margin_pixels,
122
+ )
123
+
124
+ if HAS_VIPS:
125
+ # Convert to vips for resizing
126
+ vips_image = pyvips.Image.new_from_array(image)
127
+ scale_x = target_size[1] / image.shape[1]
128
+ scale_y = target_size[0] / image.shape[0]
129
+ resized = vips_image.resize(scale_x, vscale=scale_y)
130
+ image = resized.numpy()
131
+
132
+ # Create global crop
133
+ scale_x = base_size[1] / vips_image.width
134
+ scale_y = base_size[0] / vips_image.height
135
+ global_vips = vips_image.resize(scale_x, vscale=scale_y)
136
+ crops[0] = global_vips.numpy()
137
+ else:
138
+ # Fallback to PIL
139
+ pil_img = Image.fromarray(image)
140
+ resized = pil_img.resize(
141
+ (int(target_size[1]), int(target_size[0])),
142
+ resample=Image.Resampling.LANCZOS,
143
+ )
144
+ image = np.asarray(resized)
145
+
146
+ # Create global crop
147
+ global_pil = pil_img.resize(
148
+ (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
149
+ )
150
+ crops[0] = np.asarray(global_pil)
151
+
152
+ for i in range(tiling[0]):
153
+ for j in range(tiling[1]):
154
+ # Calculate crop coordinates
155
+ y0 = i * crop_window_size
156
+ x0 = j * crop_window_size
157
+
158
+ # Extract crop with padding if needed
159
+ y_end = min(y0 + base_size[0], image.shape[0])
160
+ x_end = min(x0 + base_size[1], image.shape[1])
161
+
162
+ crop_region = image[y0:y_end, x0:x_end]
163
+ crops[
164
+ 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
165
+ ] = crop_region
166
+
167
+ return {"crops": crops, "tiling": tiling}
168
+
169
+
170
+ def reconstruct_from_crops(
171
+ crops: torch.Tensor,
172
+ tiling: tuple[int, int],
173
+ overlap_margin: int,
174
+ patch_size: int = 14,
175
+ ) -> torch.Tensor:
176
+ """
177
+ Reconstruct the original image from overlapping crops into a single seamless image.
178
+
179
+ Takes a list of overlapping image crops along with their positional metadata and
180
+ reconstructs them into a single coherent image by carefully stitching together
181
+ non-overlapping regions. Expects the crops as PyTorch tensors.
182
+
183
+ Args:
184
+ crops: Image crops as PyTorch tensors with shape
185
+ (H,W,C)
186
+ tiling: Tuple of (height,width) indicating crop grid layout
187
+ patch_size: Size in pixels of each patch, default 14
188
+ overlap_margin: Number of overlapping patches on each edge, default 4
189
+
190
+ Returns:
191
+ Reconstructed image as a PyTorch tensor
192
+ with shape (H,W,C), where H,W are the dimensions of the resized image the crops were taken from
193
+ """
194
+ tiling_h, tiling_w = tiling
195
+ crop_height, crop_width = crops[0].shape[:2]
196
+ margin_pixels = overlap_margin * patch_size
197
+
198
+ # Calculate output size (only adding margins once)
199
+ output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
200
+ output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
201
+
202
+ reconstructed = torch.zeros(
203
+ (output_h, output_w, crops[0].shape[2]),
204
+ device=crops[0].device,
205
+ dtype=crops[0].dtype,
206
+ )
207
+
208
+ for i, crop in enumerate(crops):
209
+ tile_y = i // tiling_w
210
+ tile_x = i % tiling_w
211
+
212
+ # For each tile, determine which part to keep
213
+ # Keep left margin only for first column
214
+ x_start = 0 if tile_x == 0 else margin_pixels
215
+ # Keep right margin only for last column
216
+ x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
217
+ # Keep top margin only for first row
218
+ y_start = 0 if tile_y == 0 else margin_pixels
219
+ # Keep bottom margin only for last row
220
+ y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
221
+
222
+ # Calculate where this piece belongs in the output
223
+ out_x = tile_x * (crop_width - 2 * margin_pixels)
224
+ out_y = tile_y * (crop_height - 2 * margin_pixels)
225
+
226
+ # Place the piece
227
+ reconstructed[
228
+ out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
229
+ ] = crop[y_start:y_end, x_start:x_end]
230
+
231
+ return reconstructed
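A round-trip sketch of the crop pipeline above, with a toy image size (an assumption). `overlap_crop_image` returns the global crop at index 0 followed by the local tiles; `reconstruct_from_crops` takes only the local tiles and stitches them back at the resolution the image was resized to before cropping.

```python
import numpy as np
import torch

from image_crops import overlap_crop_image, reconstruct_from_crops  # assumed import path

image = np.random.randint(0, 256, size=(600, 800, 3), dtype=np.uint8)  # toy RGB image

out = overlap_crop_image(image, overlap_margin=4, max_crops=12)
crops, tiling = out["crops"], out["tiling"]
print(tiling, crops.shape)  # (2, 4) and (9, 378, 378, 3): one global crop + 8 local tiles

# Stitch the local tiles (index 1 onward) back together. The result is at the
# resized-for-cropping resolution (644 x 1176 here), not the original 600 x 800.
local = torch.from_numpy(crops[1:])
stitched = reconstruct_from_crops(local, tiling, overlap_margin=4)
print(stitched.shape)
```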
layers.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Literal, Optional
7
+
8
+ try:
9
+ from torchao import quantize_
10
+ from torchao.quantization import int4_weight_only
11
+ except ImportError:
12
+
13
+ def quantize_(model, quant_mode):
14
+ raise ImportError(
15
+ "torchao is not installed. Please install it with `pip install torchao`."
16
+ )
17
+
18
+ def int4_weight_only(group_size):
19
+ raise ImportError(
20
+ "torchao is not installed. Please install it with `pip install torchao`."
21
+ )
22
+
23
+
24
+ def gelu_approx(x):
25
+ return F.gelu(x, approximate="tanh")
26
+
27
+
28
+ @dataclass
29
+ class LinearWeights:
30
+ weight: torch.Tensor
31
+ bias: torch.Tensor
32
+
33
+
34
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
35
+ return F.linear(x, w.weight, w.bias)
36
+
37
+
38
+ def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
39
+ _step = W_q.shape[0]
40
+ W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
41
+ W_r[:_step] = (W_q & 0b11110000) >> 4
42
+ W_r[_step:] = W_q & 0b00001111
43
+ W_r.sub_(zero).mul_(scale)
44
+ return W_r.reshape(orig_shape)
45
+
46
+
47
+ class QuantizedLinear(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features: int,
51
+ out_features: int,
52
+ dtype: torch.dtype,
53
+ ):
54
+ # TODO: Take group_size as an input instead of hardcoding it here.
55
+ super().__init__()
56
+ self.in_features = in_features
57
+ self.out_features = out_features
58
+ self.weight = nn.ParameterDict(
59
+ {
60
+ "packed": nn.Parameter(
61
+ torch.empty(
62
+ out_features * in_features // (128 * 2), 128, dtype=torch.uint8
63
+ ),
64
+ requires_grad=False,
65
+ ),
66
+ "scale": nn.Parameter(
67
+ torch.empty(out_features * in_features // 128, 1),
68
+ requires_grad=False,
69
+ ),
70
+ "zero_point": nn.Parameter(
71
+ torch.empty(out_features * in_features // 128, 1),
72
+ requires_grad=False,
73
+ ),
74
+ }
75
+ )
76
+ self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
77
+ self.unpacked = False
78
+
79
+ def unpack(self):
80
+ if self.unpacked:
81
+ return
82
+
83
+ self.weight = nn.Parameter(
84
+ dequantize_tensor(
85
+ self.weight["packed"],
86
+ self.weight["scale"],
87
+ self.weight["zero_point"],
88
+ (self.out_features, self.in_features),
89
+ torch.bfloat16,
90
+ )
91
+ )
92
+ with torch.device("meta"):
93
+ self.linear = nn.Linear(
94
+ self.in_features, self.out_features, dtype=torch.bfloat16
95
+ )
96
+ self.linear.weight = self.weight
97
+ self.linear.bias = nn.Parameter(
98
+ self.bias.to(torch.bfloat16), requires_grad=False
99
+ )
100
+
101
+ del self.weight, self.bias
102
+ quantize_(self, int4_weight_only(group_size=128))
103
+ self.unpacked = True
104
+ torch.cuda.empty_cache()
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ if not self.unpacked:
108
+ self.unpack()
109
+ return self.linear(x)
110
+
111
+
112
+ @dataclass
113
+ class LayerNormWeights:
114
+ weight: torch.Tensor
115
+ bias: torch.Tensor
116
+
117
+
118
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
119
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
120
+
121
+
122
+ @dataclass
123
+ class MLPWeights:
124
+ fc1: LinearWeights
125
+ fc2: LinearWeights
126
+ act: Literal["gelu_approx"] = "gelu_approx"
127
+
128
+
129
+ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
130
+ x0 = w.fc1(x)
131
+ if lora is not None:
132
+ x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
133
+ x = x0 + x1
134
+ else:
135
+ x = x0
136
+
137
+ x = gelu_approx(x)
138
+
139
+ x0 = w.fc2(x)
140
+ if lora is not None:
141
+ x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
142
+ x = x0 + x1
143
+ else:
144
+ x = x0
145
+
146
+ return x
147
+
148
+
149
+ def moe_mlp(
150
+ x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
151
+ ) -> torch.Tensor:
152
+ B, T, C = x.shape
153
+ x = x.reshape(-1, C)
154
+
155
+ # Router computation
156
+ router_logits = mlp_module.router(x)
157
+ topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
158
+ topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
159
+ num_tokens, top_k = topk_idxs.shape
160
+
161
+ if T == 1:
162
+ w1_weight = mlp_module.fc1.weight
163
+ w2_weight = mlp_module.fc2.weight
164
+
165
+ # Flatten to process all token-expert pairs at once
166
+ flat_idxs = topk_idxs.view(-1) # [T*A]
167
+ flat_weights = topk_weights.view(-1) # [T*A]
168
+
169
+ # Select expert weights
170
+ w1_selected = w1_weight[flat_idxs] # [T*A, H, D]
171
+ w2_selected = w2_weight[flat_idxs] # [T*A, D, H]
172
+
173
+ # Expand input for all token-expert pairs
174
+ x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
175
+
176
+ # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
177
+ x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
178
+ -1
179
+ ) # [T*A, H]
180
+ x1, g = x1_full.chunk(2, dim=-1)
181
+ x1 = F.gelu(x1) * (g + 1)
182
+
183
+ # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
184
+ expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
185
+
186
+ # Apply weights and reshape
187
+ weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
188
+ weighted_outs = weighted_outs.view(num_tokens, top_k, C) # [T, A, D]
189
+
190
+ # Sum over experts
191
+ mlp_out = weighted_outs.sum(dim=1) # [T, D]
192
+ mlp_out = mlp_out.view(B, T, C)
193
+
194
+ return mlp_out
195
+ else:
196
+ out = x.new_zeros(x.size())
197
+
198
+ for expert_id in range(mlp_module.fc1.weight.shape[0]):
199
+ token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
200
+ if token_pos.numel() == 0:
201
+ continue
202
+
203
+ x_tok = x.index_select(0, token_pos)
204
+ gate_tok = topk_weights[token_pos, which_k]
205
+
206
+ h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
207
+ h, g = h_full.chunk(2, dim=-1)
208
+ h = F.gelu(h) * (g + 1)
209
+ y = F.linear(h, mlp_module.fc2.weight[expert_id])
210
+
211
+ y.mul_(gate_tok.unsqueeze(-1))
212
+ out.index_add_(0, token_pos, y)
213
+
214
+ return out.view(B, T, C)
215
+
216
+
217
+ @dataclass
218
+ class AttentionWeights:
219
+ qkv: LinearWeights
220
+ proj: LinearWeights
221
+
222
+
223
+ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
224
+ bsz, q_len, d_model = x.shape
225
+ head_dim = d_model // n_heads
226
+
227
+ q, k, v = [
228
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
229
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
230
+ ]
231
+ out = F.scaled_dot_product_attention(q, k, v)
232
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
233
+ out = linear(out, w.proj)
234
+ return out
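A round-trip sketch of the 4-bit packing layout that `QuantizedLinear.unpack()` above consumes: values are grouped 128 per `scale`/`zero_point` row, and two 4-bit codes share a byte, with the first half of the groups in the high nibble and the second half in the low nibble. The toy sizes and the import path are assumptions.

```python
import torch

from layers import dequantize_tensor  # assumed import path

out_features, in_features = 4, 128  # toy sizes; the group count (4 here) must be even
W = torch.randn(out_features, in_features)

# Per-group asymmetric 4-bit quantization (groups of 128 values).
flat = W.reshape(-1, 128)
w_min = flat.min(dim=1, keepdim=True).values
w_max = flat.max(dim=1, keepdim=True).values
scale = (w_max - w_min) / 15.0
zero = (-w_min / scale).round()
q = (flat / scale + zero).round().clamp(0, 15).to(torch.uint8)

# Pack two codes per byte: first half of the groups -> high nibble, second half -> low nibble.
step = q.shape[0] // 2
packed = (q[:step] << 4) | q[step:]

W_hat = dequantize_tensor(packed, scale, zero, (out_features, in_features), torch.float32)
print((W_hat - W).abs().max())  # error on the order of the quantization step
```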
lora.py ADDED
@@ -0,0 +1,82 @@
1
+ import functools
2
+ import os
3
+ import shutil
4
+ import torch
5
+
6
+ from pathlib import Path
7
+ from urllib.request import Request, urlopen
8
+ from typing import Optional
9
+
10
+
11
+ def variant_cache_dir():
12
+ hf_hub_cache = os.environ.get("HF_HUB_CACHE")
13
+ if hf_hub_cache is not None:
14
+ return Path(hf_hub_cache) / "md_variants"
15
+
16
+ hf_home = os.environ.get("HF_HOME")
17
+ if hf_home is not None:
18
+ return Path(hf_home) / "hub" / "md_variants"
19
+
20
+ return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
21
+
22
+
23
+ def cached_variant_path(variant_id: str):
24
+ variant, *rest = variant_id.split("/", 1)
25
+ step = rest[0] if rest else "final"
26
+
27
+ cache_dir = variant_cache_dir() / variant
28
+ os.makedirs(cache_dir, exist_ok=True)
29
+ dest = cache_dir / f"{step}.pt"
30
+ if dest.exists():
31
+ return dest
32
+
33
+ md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
34
+
35
+ headers = {"User-Agent": "moondream-torch"}
36
+ api_key = os.getenv("MOONDREAM_API_KEY")
37
+ if api_key is not None:
38
+ headers["X-Moondream-Auth"] = api_key
39
+
40
+ req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
41
+ with urlopen(req) as r, open(dest, "wb") as f:
42
+ shutil.copyfileobj(r, f)
43
+ return dest
44
+
45
+
46
+ def nest(flat):
47
+ tree = {}
48
+ for k, v in flat.items():
49
+ parts = k.split(".")
50
+ d = tree
51
+ for p in parts[:-1]:
52
+ d = d.setdefault(p, {})
53
+ d[parts[-1]] = v
54
+ return tree
55
+
56
+
57
+ @functools.lru_cache(maxsize=5)
58
+ def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
59
+ if variant_id is None:
60
+ return None
61
+
62
+ state_dict = torch.load(
63
+ cached_variant_path(variant_id), map_location=device, weights_only=True
64
+ )
65
+
66
+ # TODO: Move these into the training code that saves checkpoints...
67
+ rename_rules = [
68
+ ("text_model.transformer.h", "text.blocks"),
69
+ (".mixer", ".attn"),
70
+ (".out_proj", ".proj"),
71
+ (".Wqkv", ".qkv"),
72
+ (".parametrizations.weight.0", ""),
73
+ ]
74
+ new_state_dict = {}
75
+ for key, tensor in state_dict.items():
76
+ new_key = key
77
+ for old, new in rename_rules:
78
+ if old in new_key:
79
+ new_key = new_key.replace(old, new)
80
+ new_state_dict[new_key] = tensor
81
+
82
+ return nest(new_state_dict)
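A small illustration of `nest()` above, which turns flat dotted checkpoint keys into the nested dict that the renamed variant tensors are grouped into; the keys below are illustrative examples in the post-rename format.

```python
import torch

from lora import nest  # assumed import path

flat = {
    "text.blocks.0.attn.qkv.weight": torch.zeros(2),
    "text.blocks.0.attn.proj.weight": torch.zeros(2),
    "region.coord_encoder.weight": torch.zeros(2),
}
tree = nest(flat)
print(list(tree))                                 # ['text', 'region']
print(list(tree["text"]["blocks"]["0"]["attn"]))  # ['qkv', 'proj']
print(tree["region"]["coord_encoder"]["weight"].shape)  # torch.Size([2])
```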
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9414434ab3afb560b37bbd5d3972ae944679e7773a60ece538e4231d2cf142f
3
+ size 4907406296
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0433bb359387b93502680ac120913f46e0d6d62940f74ef75759a085edcad86
3
+ size 4736548872
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82b15aaadff6efa4013788ccaa321d496993fe41240305c7b8dd8e8cfbc4fa69
3
+ size 4502742464
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fde1839f3766d227b30cfa07521cc0126bfacac08e91f135defdc7624405977f
3
+ size 4390620392
model.safetensors.index.json ADDED
@@ -0,0 +1,667 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18537245664
4
+ },
5
+ "weight_map": {
6
+ "model.region.coord_decoder.bias": "model-00004-of-00004.safetensors",
7
+ "model.region.coord_decoder.weight": "model-00004-of-00004.safetensors",
8
+ "model.region.coord_encoder.bias": "model-00004-of-00004.safetensors",
9
+ "model.region.coord_encoder.weight": "model-00004-of-00004.safetensors",
10
+ "model.region.coord_features": "model-00004-of-00004.safetensors",
11
+ "model.region.size_decoder.bias": "model-00004-of-00004.safetensors",
12
+ "model.region.size_decoder.weight": "model-00004-of-00004.safetensors",
13
+ "model.region.size_encoder.bias": "model-00004-of-00004.safetensors",
14
+ "model.region.size_encoder.weight": "model-00004-of-00004.safetensors",
15
+ "model.region.size_features": "model-00004-of-00004.safetensors",
16
+ "model.text.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
17
+ "model.text.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.text.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
19
+ "model.text.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
20
+ "model.text.blocks.0.attn.tau.alpha": "model-00001-of-00004.safetensors",
21
+ "model.text.blocks.0.attn.tau.wq": "model-00001-of-00004.safetensors",
22
+ "model.text.blocks.0.attn.tau.wv": "model-00001-of-00004.safetensors",
23
+ "model.text.blocks.0.ln.bias": "model-00001-of-00004.safetensors",
24
+ "model.text.blocks.0.ln.weight": "model-00001-of-00004.safetensors",
25
+ "model.text.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
26
+ "model.text.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
27
+ "model.text.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
28
+ "model.text.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
29
+ "model.text.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
30
+ "model.text.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.text.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
32
+ "model.text.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
33
+ "model.text.blocks.1.attn.tau.alpha": "model-00001-of-00004.safetensors",
34
+ "model.text.blocks.1.attn.tau.wq": "model-00001-of-00004.safetensors",
35
+ "model.text.blocks.1.attn.tau.wv": "model-00001-of-00004.safetensors",
36
+ "model.text.blocks.1.ln.bias": "model-00001-of-00004.safetensors",
37
+ "model.text.blocks.1.ln.weight": "model-00001-of-00004.safetensors",
38
+ "model.text.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
39
+ "model.text.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
40
+ "model.text.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
41
+ "model.text.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
42
+ "model.text.blocks.10.attn.proj.bias": "model-00002-of-00004.safetensors",
43
+ "model.text.blocks.10.attn.proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.text.blocks.10.attn.qkv.bias": "model-00002-of-00004.safetensors",
45
+ "model.text.blocks.10.attn.qkv.weight": "model-00002-of-00004.safetensors",
46
+ "model.text.blocks.10.attn.tau.alpha": "model-00002-of-00004.safetensors",
47
+ "model.text.blocks.10.attn.tau.wq": "model-00002-of-00004.safetensors",
48
+ "model.text.blocks.10.attn.tau.wv": "model-00002-of-00004.safetensors",
49
+ "model.text.blocks.10.ln.bias": "model-00002-of-00004.safetensors",
50
+ "model.text.blocks.10.ln.weight": "model-00002-of-00004.safetensors",
51
+ "model.text.blocks.10.mlp.fc1.weight": "model-00002-of-00004.safetensors",
52
+ "model.text.blocks.10.mlp.fc2.weight": "model-00002-of-00004.safetensors",
53
+ "model.text.blocks.10.mlp.router.bias": "model-00002-of-00004.safetensors",
54
+ "model.text.blocks.10.mlp.router.weight": "model-00002-of-00004.safetensors",
55
+ "model.text.blocks.11.attn.proj.bias": "model-00002-of-00004.safetensors",
56
+ "model.text.blocks.11.attn.proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.text.blocks.11.attn.qkv.bias": "model-00002-of-00004.safetensors",
58
+ "model.text.blocks.11.attn.qkv.weight": "model-00002-of-00004.safetensors",
59
+ "model.text.blocks.11.attn.tau.alpha": "model-00002-of-00004.safetensors",
60
+ "model.text.blocks.11.attn.tau.wq": "model-00002-of-00004.safetensors",
61
+ "model.text.blocks.11.attn.tau.wv": "model-00002-of-00004.safetensors",
62
+ "model.text.blocks.11.ln.bias": "model-00002-of-00004.safetensors",
63
+ "model.text.blocks.11.ln.weight": "model-00002-of-00004.safetensors",
64
+ "model.text.blocks.11.mlp.fc1.weight": "model-00002-of-00004.safetensors",
65
+ "model.text.blocks.11.mlp.fc2.weight": "model-00002-of-00004.safetensors",
66
+ "model.text.blocks.11.mlp.router.bias": "model-00002-of-00004.safetensors",
67
+ "model.text.blocks.11.mlp.router.weight": "model-00002-of-00004.safetensors",
68
+ "model.text.blocks.12.attn.proj.bias": "model-00002-of-00004.safetensors",
69
+ "model.text.blocks.12.attn.proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.text.blocks.12.attn.qkv.bias": "model-00002-of-00004.safetensors",
71
+ "model.text.blocks.12.attn.qkv.weight": "model-00002-of-00004.safetensors",
72
+ "model.text.blocks.12.attn.tau.alpha": "model-00002-of-00004.safetensors",
73
+ "model.text.blocks.12.attn.tau.wq": "model-00002-of-00004.safetensors",
74
+ "model.text.blocks.12.attn.tau.wv": "model-00002-of-00004.safetensors",
75
+ "model.text.blocks.12.ln.bias": "model-00002-of-00004.safetensors",
76
+ "model.text.blocks.12.ln.weight": "model-00002-of-00004.safetensors",
77
+ "model.text.blocks.12.mlp.fc1.weight": "model-00002-of-00004.safetensors",
78
+ "model.text.blocks.12.mlp.fc2.weight": "model-00002-of-00004.safetensors",
79
+ "model.text.blocks.12.mlp.router.bias": "model-00002-of-00004.safetensors",
80
+ "model.text.blocks.12.mlp.router.weight": "model-00002-of-00004.safetensors",
81
+ "model.text.blocks.13.attn.proj.bias": "model-00002-of-00004.safetensors",
82
+ "model.text.blocks.13.attn.proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.text.blocks.13.attn.qkv.bias": "model-00002-of-00004.safetensors",
84
+ "model.text.blocks.13.attn.qkv.weight": "model-00002-of-00004.safetensors",
85
+ "model.text.blocks.13.attn.tau.alpha": "model-00002-of-00004.safetensors",
86
+ "model.text.blocks.13.attn.tau.wq": "model-00002-of-00004.safetensors",
87
+ "model.text.blocks.13.attn.tau.wv": "model-00002-of-00004.safetensors",
88
+ "model.text.blocks.13.ln.bias": "model-00002-of-00004.safetensors",
89
+ "model.text.blocks.13.ln.weight": "model-00002-of-00004.safetensors",
90
+ "model.text.blocks.13.mlp.fc1.weight": "model-00002-of-00004.safetensors",
91
+ "model.text.blocks.13.mlp.fc2.weight": "model-00003-of-00004.safetensors",
92
+ "model.text.blocks.13.mlp.router.bias": "model-00002-of-00004.safetensors",
93
+ "model.text.blocks.13.mlp.router.weight": "model-00002-of-00004.safetensors",
94
+ "model.text.blocks.14.attn.proj.bias": "model-00003-of-00004.safetensors",
95
+ "model.text.blocks.14.attn.proj.weight": "model-00003-of-00004.safetensors",
96
+ "model.text.blocks.14.attn.qkv.bias": "model-00003-of-00004.safetensors",
97
+ "model.text.blocks.14.attn.qkv.weight": "model-00003-of-00004.safetensors",
98
+ "model.text.blocks.14.attn.tau.alpha": "model-00003-of-00004.safetensors",
99
+ "model.text.blocks.14.attn.tau.wq": "model-00003-of-00004.safetensors",
100
+ "model.text.blocks.14.attn.tau.wv": "model-00003-of-00004.safetensors",
101
+ "model.text.blocks.14.ln.bias": "model-00003-of-00004.safetensors",
102
+ "model.text.blocks.14.ln.weight": "model-00003-of-00004.safetensors",
103
+ "model.text.blocks.14.mlp.fc1.weight": "model-00003-of-00004.safetensors",
104
+ "model.text.blocks.14.mlp.fc2.weight": "model-00003-of-00004.safetensors",
105
+ "model.text.blocks.14.mlp.router.bias": "model-00003-of-00004.safetensors",
106
+ "model.text.blocks.14.mlp.router.weight": "model-00003-of-00004.safetensors",
107
+ "model.text.blocks.15.attn.proj.bias": "model-00003-of-00004.safetensors",
108
+ "model.text.blocks.15.attn.proj.weight": "model-00003-of-00004.safetensors",
109
+ "model.text.blocks.15.attn.qkv.bias": "model-00003-of-00004.safetensors",
110
+ "model.text.blocks.15.attn.qkv.weight": "model-00003-of-00004.safetensors",
111
+ "model.text.blocks.15.attn.tau.alpha": "model-00003-of-00004.safetensors",
112
+ "model.text.blocks.15.attn.tau.wq": "model-00003-of-00004.safetensors",
113
+ "model.text.blocks.15.attn.tau.wv": "model-00003-of-00004.safetensors",
114
+ "model.text.blocks.15.ln.bias": "model-00003-of-00004.safetensors",
115
+ "model.text.blocks.15.ln.weight": "model-00003-of-00004.safetensors",
116
+ "model.text.blocks.15.mlp.fc1.weight": "model-00003-of-00004.safetensors",
117
+ "model.text.blocks.15.mlp.fc2.weight": "model-00003-of-00004.safetensors",
118
+ "model.text.blocks.15.mlp.router.bias": "model-00003-of-00004.safetensors",
119
+ "model.text.blocks.15.mlp.router.weight": "model-00003-of-00004.safetensors",
120
+ "model.text.blocks.16.attn.proj.bias": "model-00003-of-00004.safetensors",
121
+ "model.text.blocks.16.attn.proj.weight": "model-00003-of-00004.safetensors",
122
+ "model.text.blocks.16.attn.qkv.bias": "model-00003-of-00004.safetensors",
123
+ "model.text.blocks.16.attn.qkv.weight": "model-00003-of-00004.safetensors",
124
+ "model.text.blocks.16.attn.tau.alpha": "model-00003-of-00004.safetensors",
125
+ "model.text.blocks.16.attn.tau.wq": "model-00003-of-00004.safetensors",
126
+ "model.text.blocks.16.attn.tau.wv": "model-00003-of-00004.safetensors",
127
+ "model.text.blocks.16.ln.bias": "model-00003-of-00004.safetensors",
128
+ "model.text.blocks.16.ln.weight": "model-00003-of-00004.safetensors",
129
+ "model.text.blocks.16.mlp.fc1.weight": "model-00003-of-00004.safetensors",
130
+ "model.text.blocks.16.mlp.fc2.weight": "model-00003-of-00004.safetensors",
131
+ "model.text.blocks.16.mlp.router.bias": "model-00003-of-00004.safetensors",
132
+ "model.text.blocks.16.mlp.router.weight": "model-00003-of-00004.safetensors",
133
+ "model.text.blocks.17.attn.proj.bias": "model-00003-of-00004.safetensors",
134
+ "model.text.blocks.17.attn.proj.weight": "model-00003-of-00004.safetensors",
135
+ "model.text.blocks.17.attn.qkv.bias": "model-00003-of-00004.safetensors",
136
+ "model.text.blocks.17.attn.qkv.weight": "model-00003-of-00004.safetensors",
137
+ "model.text.blocks.17.attn.tau.alpha": "model-00003-of-00004.safetensors",
138
+ "model.text.blocks.17.attn.tau.wq": "model-00003-of-00004.safetensors",
139
+ "model.text.blocks.17.attn.tau.wv": "model-00003-of-00004.safetensors",
140
+ "model.text.blocks.17.ln.bias": "model-00003-of-00004.safetensors",
141
+ "model.text.blocks.17.ln.weight": "model-00003-of-00004.safetensors",
142
+ "model.text.blocks.17.mlp.fc1.weight": "model-00003-of-00004.safetensors",
143
+ "model.text.blocks.17.mlp.fc2.weight": "model-00003-of-00004.safetensors",
144
+ "model.text.blocks.17.mlp.router.bias": "model-00003-of-00004.safetensors",
145
+ "model.text.blocks.17.mlp.router.weight": "model-00003-of-00004.safetensors",
146
+ "model.text.blocks.18.attn.proj.bias": "model-00003-of-00004.safetensors",
147
+ "model.text.blocks.18.attn.proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.text.blocks.18.attn.qkv.bias": "model-00003-of-00004.safetensors",
149
+ "model.text.blocks.18.attn.qkv.weight": "model-00003-of-00004.safetensors",
150
+ "model.text.blocks.18.attn.tau.alpha": "model-00003-of-00004.safetensors",
151
+ "model.text.blocks.18.attn.tau.wq": "model-00003-of-00004.safetensors",
152
+ "model.text.blocks.18.attn.tau.wv": "model-00003-of-00004.safetensors",
153
+ "model.text.blocks.18.ln.bias": "model-00003-of-00004.safetensors",
154
+ "model.text.blocks.18.ln.weight": "model-00003-of-00004.safetensors",
155
+ "model.text.blocks.18.mlp.fc1.weight": "model-00003-of-00004.safetensors",
156
+ "model.text.blocks.18.mlp.fc2.weight": "model-00003-of-00004.safetensors",
157
+ "model.text.blocks.18.mlp.router.bias": "model-00003-of-00004.safetensors",
158
+ "model.text.blocks.18.mlp.router.weight": "model-00003-of-00004.safetensors",
159
+ "model.text.blocks.19.attn.proj.bias": "model-00003-of-00004.safetensors",
160
+ "model.text.blocks.19.attn.proj.weight": "model-00003-of-00004.safetensors",
161
+ "model.text.blocks.19.attn.qkv.bias": "model-00003-of-00004.safetensors",
162
+ "model.text.blocks.19.attn.qkv.weight": "model-00003-of-00004.safetensors",
163
+ "model.text.blocks.19.attn.tau.alpha": "model-00003-of-00004.safetensors",
164
+ "model.text.blocks.19.attn.tau.wq": "model-00003-of-00004.safetensors",
165
+ "model.text.blocks.19.attn.tau.wv": "model-00003-of-00004.safetensors",
166
+ "model.text.blocks.19.ln.bias": "model-00003-of-00004.safetensors",
167
+ "model.text.blocks.19.ln.weight": "model-00003-of-00004.safetensors",
168
+ "model.text.blocks.19.mlp.fc1.weight": "model-00004-of-00004.safetensors",
169
+ "model.text.blocks.19.mlp.fc2.weight": "model-00004-of-00004.safetensors",
170
+ "model.text.blocks.19.mlp.router.bias": "model-00003-of-00004.safetensors",
171
+ "model.text.blocks.19.mlp.router.weight": "model-00003-of-00004.safetensors",
172
+ "model.text.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
173
+ "model.text.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
174
+ "model.text.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
175
+ "model.text.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
176
+ "model.text.blocks.2.attn.tau.alpha": "model-00001-of-00004.safetensors",
177
+ "model.text.blocks.2.attn.tau.wq": "model-00001-of-00004.safetensors",
178
+ "model.text.blocks.2.attn.tau.wv": "model-00001-of-00004.safetensors",
179
+ "model.text.blocks.2.ln.bias": "model-00001-of-00004.safetensors",
180
+ "model.text.blocks.2.ln.weight": "model-00001-of-00004.safetensors",
181
+ "model.text.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
182
+ "model.text.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
183
+ "model.text.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
184
+ "model.text.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
185
+ "model.text.blocks.20.attn.proj.bias": "model-00004-of-00004.safetensors",
186
+ "model.text.blocks.20.attn.proj.weight": "model-00004-of-00004.safetensors",
187
+ "model.text.blocks.20.attn.qkv.bias": "model-00004-of-00004.safetensors",
188
+ "model.text.blocks.20.attn.qkv.weight": "model-00004-of-00004.safetensors",
189
+ "model.text.blocks.20.attn.tau.alpha": "model-00004-of-00004.safetensors",
190
+ "model.text.blocks.20.attn.tau.wq": "model-00004-of-00004.safetensors",
191
+ "model.text.blocks.20.attn.tau.wv": "model-00004-of-00004.safetensors",
192
+ "model.text.blocks.20.ln.bias": "model-00004-of-00004.safetensors",
193
+ "model.text.blocks.20.ln.weight": "model-00004-of-00004.safetensors",
194
+ "model.text.blocks.20.mlp.fc1.weight": "model-00004-of-00004.safetensors",
195
+ "model.text.blocks.20.mlp.fc2.weight": "model-00004-of-00004.safetensors",
196
+ "model.text.blocks.20.mlp.router.bias": "model-00004-of-00004.safetensors",
197
+ "model.text.blocks.20.mlp.router.weight": "model-00004-of-00004.safetensors",
198
+ "model.text.blocks.21.attn.proj.bias": "model-00004-of-00004.safetensors",
199
+ "model.text.blocks.21.attn.proj.weight": "model-00004-of-00004.safetensors",
200
+ "model.text.blocks.21.attn.qkv.bias": "model-00004-of-00004.safetensors",
201
+ "model.text.blocks.21.attn.qkv.weight": "model-00004-of-00004.safetensors",
202
+ "model.text.blocks.21.attn.tau.alpha": "model-00004-of-00004.safetensors",
203
+ "model.text.blocks.21.attn.tau.wq": "model-00004-of-00004.safetensors",
204
+ "model.text.blocks.21.attn.tau.wv": "model-00004-of-00004.safetensors",
205
+ "model.text.blocks.21.ln.bias": "model-00004-of-00004.safetensors",
206
+ "model.text.blocks.21.ln.weight": "model-00004-of-00004.safetensors",
207
+ "model.text.blocks.21.mlp.fc1.weight": "model-00004-of-00004.safetensors",
208
+ "model.text.blocks.21.mlp.fc2.weight": "model-00004-of-00004.safetensors",
209
+ "model.text.blocks.21.mlp.router.bias": "model-00004-of-00004.safetensors",
210
+ "model.text.blocks.21.mlp.router.weight": "model-00004-of-00004.safetensors",
211
+ "model.text.blocks.22.attn.proj.bias": "model-00004-of-00004.safetensors",
212
+ "model.text.blocks.22.attn.proj.weight": "model-00004-of-00004.safetensors",
213
+ "model.text.blocks.22.attn.qkv.bias": "model-00004-of-00004.safetensors",
214
+ "model.text.blocks.22.attn.qkv.weight": "model-00004-of-00004.safetensors",
215
+ "model.text.blocks.22.attn.tau.alpha": "model-00004-of-00004.safetensors",
216
+ "model.text.blocks.22.attn.tau.wq": "model-00004-of-00004.safetensors",
217
+ "model.text.blocks.22.attn.tau.wv": "model-00004-of-00004.safetensors",
218
+ "model.text.blocks.22.ln.bias": "model-00004-of-00004.safetensors",
219
+ "model.text.blocks.22.ln.weight": "model-00004-of-00004.safetensors",
220
+ "model.text.blocks.22.mlp.fc1.weight": "model-00004-of-00004.safetensors",
221
+ "model.text.blocks.22.mlp.fc2.weight": "model-00004-of-00004.safetensors",
222
+ "model.text.blocks.22.mlp.router.bias": "model-00004-of-00004.safetensors",
223
+ "model.text.blocks.22.mlp.router.weight": "model-00004-of-00004.safetensors",
224
+ "model.text.blocks.23.attn.proj.bias": "model-00004-of-00004.safetensors",
225
+ "model.text.blocks.23.attn.proj.weight": "model-00004-of-00004.safetensors",
226
+ "model.text.blocks.23.attn.qkv.bias": "model-00004-of-00004.safetensors",
227
+ "model.text.blocks.23.attn.qkv.weight": "model-00004-of-00004.safetensors",
228
+ "model.text.blocks.23.attn.tau.alpha": "model-00004-of-00004.safetensors",
229
+ "model.text.blocks.23.attn.tau.wq": "model-00004-of-00004.safetensors",
230
+ "model.text.blocks.23.attn.tau.wv": "model-00004-of-00004.safetensors",
231
+ "model.text.blocks.23.ln.bias": "model-00004-of-00004.safetensors",
232
+ "model.text.blocks.23.ln.weight": "model-00004-of-00004.safetensors",
233
+ "model.text.blocks.23.mlp.fc1.weight": "model-00004-of-00004.safetensors",
234
+ "model.text.blocks.23.mlp.fc2.weight": "model-00004-of-00004.safetensors",
235
+ "model.text.blocks.23.mlp.router.bias": "model-00004-of-00004.safetensors",
236
+ "model.text.blocks.23.mlp.router.weight": "model-00004-of-00004.safetensors",
237
+ "model.text.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
238
+ "model.text.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
239
+ "model.text.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
240
+ "model.text.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
241
+ "model.text.blocks.3.attn.tau.alpha": "model-00001-of-00004.safetensors",
242
+ "model.text.blocks.3.attn.tau.wq": "model-00001-of-00004.safetensors",
243
+ "model.text.blocks.3.attn.tau.wv": "model-00001-of-00004.safetensors",
244
+ "model.text.blocks.3.ln.bias": "model-00001-of-00004.safetensors",
245
+ "model.text.blocks.3.ln.weight": "model-00001-of-00004.safetensors",
246
+ "model.text.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
247
+ "model.text.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
248
+ "model.text.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
249
+ "model.text.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
250
+ "model.text.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
251
+ "model.text.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
252
+ "model.text.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
253
+ "model.text.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
254
+ "model.text.blocks.4.attn.tau.alpha": "model-00001-of-00004.safetensors",
255
+ "model.text.blocks.4.attn.tau.wq": "model-00001-of-00004.safetensors",
256
+ "model.text.blocks.4.attn.tau.wv": "model-00001-of-00004.safetensors",
257
+ "model.text.blocks.4.ln.bias": "model-00001-of-00004.safetensors",
258
+ "model.text.blocks.4.ln.weight": "model-00001-of-00004.safetensors",
259
+ "model.text.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
260
+ "model.text.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
261
+ "model.text.blocks.4.mlp.router.bias": "model-00001-of-00004.safetensors",
262
+ "model.text.blocks.4.mlp.router.weight": "model-00001-of-00004.safetensors",
263
+ "model.text.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
264
+ "model.text.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
265
+ "model.text.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
266
+ "model.text.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
267
+ "model.text.blocks.5.attn.tau.alpha": "model-00001-of-00004.safetensors",
268
+ "model.text.blocks.5.attn.tau.wq": "model-00001-of-00004.safetensors",
269
+ "model.text.blocks.5.attn.tau.wv": "model-00001-of-00004.safetensors",
270
+ "model.text.blocks.5.ln.bias": "model-00001-of-00004.safetensors",
271
+ "model.text.blocks.5.ln.weight": "model-00001-of-00004.safetensors",
272
+ "model.text.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
273
+ "model.text.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
274
+ "model.text.blocks.5.mlp.router.bias": "model-00001-of-00004.safetensors",
275
+ "model.text.blocks.5.mlp.router.weight": "model-00001-of-00004.safetensors",
276
+ "model.text.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
277
+ "model.text.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
278
+ "model.text.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
279
+ "model.text.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
280
+ "model.text.blocks.6.attn.tau.alpha": "model-00001-of-00004.safetensors",
281
+ "model.text.blocks.6.attn.tau.wq": "model-00001-of-00004.safetensors",
282
+ "model.text.blocks.6.attn.tau.wv": "model-00001-of-00004.safetensors",
283
+ "model.text.blocks.6.ln.bias": "model-00001-of-00004.safetensors",
284
+ "model.text.blocks.6.ln.weight": "model-00001-of-00004.safetensors",
285
+ "model.text.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
286
+ "model.text.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
287
+ "model.text.blocks.6.mlp.router.bias": "model-00001-of-00004.safetensors",
288
+ "model.text.blocks.6.mlp.router.weight": "model-00001-of-00004.safetensors",
289
+ "model.text.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
290
+ "model.text.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
291
+ "model.text.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
292
+ "model.text.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
293
+ "model.text.blocks.7.attn.tau.alpha": "model-00001-of-00004.safetensors",
294
+ "model.text.blocks.7.attn.tau.wq": "model-00001-of-00004.safetensors",
295
+ "model.text.blocks.7.attn.tau.wv": "model-00001-of-00004.safetensors",
296
+ "model.text.blocks.7.ln.bias": "model-00001-of-00004.safetensors",
297
+ "model.text.blocks.7.ln.weight": "model-00001-of-00004.safetensors",
298
+ "model.text.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
299
+ "model.text.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
300
+ "model.text.blocks.7.mlp.router.bias": "model-00001-of-00004.safetensors",
301
+ "model.text.blocks.7.mlp.router.weight": "model-00001-of-00004.safetensors",
302
+ "model.text.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
303
+ "model.text.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
304
+ "model.text.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
305
+ "model.text.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
306
+ "model.text.blocks.8.attn.tau.alpha": "model-00001-of-00004.safetensors",
307
+ "model.text.blocks.8.attn.tau.wq": "model-00001-of-00004.safetensors",
308
+ "model.text.blocks.8.attn.tau.wv": "model-00001-of-00004.safetensors",
309
+ "model.text.blocks.8.ln.bias": "model-00001-of-00004.safetensors",
310
+ "model.text.blocks.8.ln.weight": "model-00001-of-00004.safetensors",
311
+ "model.text.blocks.8.mlp.fc1.weight": "model-00002-of-00004.safetensors",
312
+ "model.text.blocks.8.mlp.fc2.weight": "model-00002-of-00004.safetensors",
313
+ "model.text.blocks.8.mlp.router.bias": "model-00001-of-00004.safetensors",
314
+ "model.text.blocks.8.mlp.router.weight": "model-00001-of-00004.safetensors",
315
+ "model.text.blocks.9.attn.proj.bias": "model-00002-of-00004.safetensors",
316
+ "model.text.blocks.9.attn.proj.weight": "model-00002-of-00004.safetensors",
317
+ "model.text.blocks.9.attn.qkv.bias": "model-00002-of-00004.safetensors",
318
+ "model.text.blocks.9.attn.qkv.weight": "model-00002-of-00004.safetensors",
319
+ "model.text.blocks.9.attn.tau.alpha": "model-00002-of-00004.safetensors",
320
+ "model.text.blocks.9.attn.tau.wq": "model-00002-of-00004.safetensors",
321
+ "model.text.blocks.9.attn.tau.wv": "model-00002-of-00004.safetensors",
322
+ "model.text.blocks.9.ln.bias": "model-00002-of-00004.safetensors",
323
+ "model.text.blocks.9.ln.weight": "model-00002-of-00004.safetensors",
324
+ "model.text.blocks.9.mlp.fc1.weight": "model-00002-of-00004.safetensors",
325
+ "model.text.blocks.9.mlp.fc2.weight": "model-00002-of-00004.safetensors",
326
+ "model.text.blocks.9.mlp.router.bias": "model-00002-of-00004.safetensors",
327
+ "model.text.blocks.9.mlp.router.weight": "model-00002-of-00004.safetensors",
328
+ "model.text.lm_head.bias": "model-00004-of-00004.safetensors",
329
+ "model.text.lm_head.weight": "model-00004-of-00004.safetensors",
330
+ "model.text.post_ln.bias": "model-00004-of-00004.safetensors",
331
+ "model.text.post_ln.weight": "model-00004-of-00004.safetensors",
332
+ "model.text.wte": "model-00001-of-00004.safetensors",
333
+ "model.vision.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
334
+ "model.vision.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
335
+ "model.vision.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
336
+ "model.vision.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
337
+ "model.vision.blocks.0.ln1.bias": "model-00001-of-00004.safetensors",
338
+ "model.vision.blocks.0.ln1.weight": "model-00001-of-00004.safetensors",
339
+ "model.vision.blocks.0.ln2.bias": "model-00001-of-00004.safetensors",
340
+ "model.vision.blocks.0.ln2.weight": "model-00001-of-00004.safetensors",
341
+ "model.vision.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
342
+ "model.vision.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
343
+ "model.vision.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
344
+ "model.vision.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
345
+ "model.vision.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
346
+ "model.vision.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
347
+ "model.vision.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
348
+ "model.vision.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
349
+ "model.vision.blocks.1.ln1.bias": "model-00001-of-00004.safetensors",
350
+ "model.vision.blocks.1.ln1.weight": "model-00001-of-00004.safetensors",
351
+ "model.vision.blocks.1.ln2.bias": "model-00001-of-00004.safetensors",
352
+ "model.vision.blocks.1.ln2.weight": "model-00001-of-00004.safetensors",
353
+ "model.vision.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
354
+ "model.vision.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
355
+ "model.vision.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
356
+ "model.vision.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
357
+ "model.vision.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
358
+ "model.vision.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
359
+ "model.vision.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
360
+ "model.vision.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
361
+ "model.vision.blocks.10.ln1.bias": "model-00001-of-00004.safetensors",
362
+ "model.vision.blocks.10.ln1.weight": "model-00001-of-00004.safetensors",
363
+ "model.vision.blocks.10.ln2.bias": "model-00001-of-00004.safetensors",
364
+ "model.vision.blocks.10.ln2.weight": "model-00001-of-00004.safetensors",
365
+ "model.vision.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
366
+ "model.vision.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
367
+ "model.vision.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
368
+ "model.vision.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
369
+ "model.vision.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
370
+ "model.vision.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
371
+ "model.vision.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
372
+ "model.vision.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
373
+ "model.vision.blocks.11.ln1.bias": "model-00001-of-00004.safetensors",
374
+ "model.vision.blocks.11.ln1.weight": "model-00001-of-00004.safetensors",
375
+ "model.vision.blocks.11.ln2.bias": "model-00001-of-00004.safetensors",
376
+ "model.vision.blocks.11.ln2.weight": "model-00001-of-00004.safetensors",
377
+ "model.vision.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
378
+ "model.vision.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
379
+ "model.vision.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
380
+ "model.vision.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
381
+ "model.vision.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
382
+ "model.vision.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
383
+ "model.vision.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
384
+ "model.vision.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
385
+ "model.vision.blocks.12.ln1.bias": "model-00001-of-00004.safetensors",
386
+ "model.vision.blocks.12.ln1.weight": "model-00001-of-00004.safetensors",
387
+ "model.vision.blocks.12.ln2.bias": "model-00001-of-00004.safetensors",
388
+ "model.vision.blocks.12.ln2.weight": "model-00001-of-00004.safetensors",
389
+ "model.vision.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
390
+ "model.vision.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
391
+ "model.vision.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
392
+ "model.vision.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
393
+ "model.vision.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
394
+ "model.vision.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
395
+ "model.vision.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
396
+ "model.vision.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
397
+ "model.vision.blocks.13.ln1.bias": "model-00001-of-00004.safetensors",
398
+ "model.vision.blocks.13.ln1.weight": "model-00001-of-00004.safetensors",
399
+ "model.vision.blocks.13.ln2.bias": "model-00001-of-00004.safetensors",
400
+ "model.vision.blocks.13.ln2.weight": "model-00001-of-00004.safetensors",
401
+ "model.vision.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
402
+ "model.vision.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
403
+ "model.vision.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
404
+ "model.vision.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
405
+ "model.vision.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
406
+ "model.vision.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
407
+ "model.vision.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
408
+ "model.vision.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
409
+ "model.vision.blocks.14.ln1.bias": "model-00001-of-00004.safetensors",
410
+ "model.vision.blocks.14.ln1.weight": "model-00001-of-00004.safetensors",
411
+ "model.vision.blocks.14.ln2.bias": "model-00001-of-00004.safetensors",
412
+ "model.vision.blocks.14.ln2.weight": "model-00001-of-00004.safetensors",
413
+ "model.vision.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
414
+ "model.vision.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
415
+ "model.vision.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
416
+ "model.vision.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
417
+ "model.vision.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
418
+ "model.vision.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
419
+ "model.vision.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
420
+ "model.vision.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
421
+ "model.vision.blocks.15.ln1.bias": "model-00001-of-00004.safetensors",
422
+ "model.vision.blocks.15.ln1.weight": "model-00001-of-00004.safetensors",
423
+ "model.vision.blocks.15.ln2.bias": "model-00001-of-00004.safetensors",
424
+ "model.vision.blocks.15.ln2.weight": "model-00001-of-00004.safetensors",
425
+ "model.vision.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
426
+ "model.vision.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
427
+ "model.vision.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
428
+ "model.vision.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
429
+ "model.vision.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
430
+ "model.vision.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
431
+ "model.vision.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
432
+ "model.vision.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
433
+ "model.vision.blocks.16.ln1.bias": "model-00001-of-00004.safetensors",
434
+ "model.vision.blocks.16.ln1.weight": "model-00001-of-00004.safetensors",
435
+ "model.vision.blocks.16.ln2.bias": "model-00001-of-00004.safetensors",
436
+ "model.vision.blocks.16.ln2.weight": "model-00001-of-00004.safetensors",
437
+ "model.vision.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
438
+ "model.vision.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
439
+ "model.vision.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
440
+ "model.vision.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
441
+ "model.vision.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
442
+ "model.vision.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
443
+ "model.vision.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
444
+ "model.vision.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
445
+ "model.vision.blocks.17.ln1.bias": "model-00001-of-00004.safetensors",
446
+ "model.vision.blocks.17.ln1.weight": "model-00001-of-00004.safetensors",
447
+ "model.vision.blocks.17.ln2.bias": "model-00001-of-00004.safetensors",
448
+ "model.vision.blocks.17.ln2.weight": "model-00001-of-00004.safetensors",
449
+ "model.vision.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
450
+ "model.vision.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
451
+ "model.vision.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
452
+ "model.vision.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
453
+ "model.vision.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
454
+ "model.vision.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
455
+ "model.vision.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
456
+ "model.vision.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
457
+ "model.vision.blocks.18.ln1.bias": "model-00001-of-00004.safetensors",
458
+ "model.vision.blocks.18.ln1.weight": "model-00001-of-00004.safetensors",
459
+ "model.vision.blocks.18.ln2.bias": "model-00001-of-00004.safetensors",
460
+ "model.vision.blocks.18.ln2.weight": "model-00001-of-00004.safetensors",
461
+ "model.vision.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
462
+ "model.vision.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
463
+ "model.vision.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
464
+ "model.vision.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
465
+ "model.vision.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
466
+ "model.vision.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
467
+ "model.vision.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
468
+ "model.vision.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
469
+ "model.vision.blocks.19.ln1.bias": "model-00001-of-00004.safetensors",
470
+ "model.vision.blocks.19.ln1.weight": "model-00001-of-00004.safetensors",
471
+ "model.vision.blocks.19.ln2.bias": "model-00001-of-00004.safetensors",
472
+ "model.vision.blocks.19.ln2.weight": "model-00001-of-00004.safetensors",
473
+ "model.vision.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
474
+ "model.vision.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
475
+ "model.vision.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
476
+ "model.vision.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
477
+ "model.vision.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
478
+ "model.vision.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
479
+ "model.vision.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
480
+ "model.vision.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
481
+ "model.vision.blocks.2.ln1.bias": "model-00001-of-00004.safetensors",
482
+ "model.vision.blocks.2.ln1.weight": "model-00001-of-00004.safetensors",
483
+ "model.vision.blocks.2.ln2.bias": "model-00001-of-00004.safetensors",
484
+ "model.vision.blocks.2.ln2.weight": "model-00001-of-00004.safetensors",
485
+ "model.vision.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
486
+ "model.vision.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
487
+ "model.vision.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
488
+ "model.vision.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
489
+ "model.vision.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
490
+ "model.vision.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
491
+ "model.vision.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
492
+ "model.vision.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
493
+ "model.vision.blocks.20.ln1.bias": "model-00001-of-00004.safetensors",
494
+ "model.vision.blocks.20.ln1.weight": "model-00001-of-00004.safetensors",
495
+ "model.vision.blocks.20.ln2.bias": "model-00001-of-00004.safetensors",
496
+ "model.vision.blocks.20.ln2.weight": "model-00001-of-00004.safetensors",
497
+ "model.vision.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
498
+ "model.vision.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
499
+ "model.vision.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
500
+ "model.vision.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
501
+ "model.vision.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
502
+ "model.vision.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
503
+ "model.vision.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
504
+ "model.vision.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
505
+ "model.vision.blocks.21.ln1.bias": "model-00001-of-00004.safetensors",
506
+ "model.vision.blocks.21.ln1.weight": "model-00001-of-00004.safetensors",
507
+ "model.vision.blocks.21.ln2.bias": "model-00001-of-00004.safetensors",
508
+ "model.vision.blocks.21.ln2.weight": "model-00001-of-00004.safetensors",
509
+ "model.vision.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
510
+ "model.vision.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
511
+ "model.vision.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
512
+ "model.vision.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
513
+ "model.vision.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
514
+ "model.vision.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
515
+ "model.vision.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
516
+ "model.vision.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
517
+ "model.vision.blocks.22.ln1.bias": "model-00001-of-00004.safetensors",
518
+ "model.vision.blocks.22.ln1.weight": "model-00001-of-00004.safetensors",
519
+ "model.vision.blocks.22.ln2.bias": "model-00001-of-00004.safetensors",
520
+ "model.vision.blocks.22.ln2.weight": "model-00001-of-00004.safetensors",
521
+ "model.vision.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
522
+ "model.vision.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
523
+ "model.vision.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
524
+ "model.vision.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
525
+ "model.vision.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
526
+ "model.vision.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
527
+ "model.vision.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
528
+ "model.vision.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
529
+ "model.vision.blocks.23.ln1.bias": "model-00001-of-00004.safetensors",
530
+ "model.vision.blocks.23.ln1.weight": "model-00001-of-00004.safetensors",
531
+ "model.vision.blocks.23.ln2.bias": "model-00001-of-00004.safetensors",
532
+ "model.vision.blocks.23.ln2.weight": "model-00001-of-00004.safetensors",
533
+ "model.vision.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
534
+ "model.vision.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
535
+ "model.vision.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
536
+ "model.vision.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
537
+ "model.vision.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
538
+ "model.vision.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
539
+ "model.vision.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
540
+ "model.vision.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
541
+ "model.vision.blocks.24.ln1.bias": "model-00001-of-00004.safetensors",
542
+ "model.vision.blocks.24.ln1.weight": "model-00001-of-00004.safetensors",
543
+ "model.vision.blocks.24.ln2.bias": "model-00001-of-00004.safetensors",
544
+ "model.vision.blocks.24.ln2.weight": "model-00001-of-00004.safetensors",
545
+ "model.vision.blocks.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
546
+ "model.vision.blocks.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
547
+ "model.vision.blocks.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
548
+ "model.vision.blocks.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
549
+ "model.vision.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
550
+ "model.vision.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
551
+ "model.vision.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
552
+ "model.vision.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
553
+ "model.vision.blocks.25.ln1.bias": "model-00001-of-00004.safetensors",
554
+ "model.vision.blocks.25.ln1.weight": "model-00001-of-00004.safetensors",
555
+ "model.vision.blocks.25.ln2.bias": "model-00001-of-00004.safetensors",
556
+ "model.vision.blocks.25.ln2.weight": "model-00001-of-00004.safetensors",
557
+ "model.vision.blocks.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
558
+ "model.vision.blocks.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
559
+ "model.vision.blocks.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
560
+ "model.vision.blocks.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
561
+ "model.vision.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
562
+ "model.vision.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
563
+ "model.vision.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
564
+ "model.vision.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
565
+ "model.vision.blocks.26.ln1.bias": "model-00001-of-00004.safetensors",
566
+ "model.vision.blocks.26.ln1.weight": "model-00001-of-00004.safetensors",
567
+ "model.vision.blocks.26.ln2.bias": "model-00001-of-00004.safetensors",
568
+ "model.vision.blocks.26.ln2.weight": "model-00001-of-00004.safetensors",
569
+ "model.vision.blocks.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
570
+ "model.vision.blocks.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
571
+ "model.vision.blocks.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
572
+ "model.vision.blocks.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
573
+ "model.vision.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
574
+ "model.vision.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
575
+ "model.vision.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
576
+ "model.vision.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
577
+ "model.vision.blocks.3.ln1.bias": "model-00001-of-00004.safetensors",
578
+ "model.vision.blocks.3.ln1.weight": "model-00001-of-00004.safetensors",
579
+ "model.vision.blocks.3.ln2.bias": "model-00001-of-00004.safetensors",
580
+ "model.vision.blocks.3.ln2.weight": "model-00001-of-00004.safetensors",
581
+ "model.vision.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
582
+ "model.vision.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
583
+ "model.vision.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
584
+ "model.vision.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
585
+ "model.vision.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
586
+ "model.vision.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
587
+ "model.vision.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
588
+ "model.vision.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
589
+ "model.vision.blocks.4.ln1.bias": "model-00001-of-00004.safetensors",
590
+ "model.vision.blocks.4.ln1.weight": "model-00001-of-00004.safetensors",
591
+ "model.vision.blocks.4.ln2.bias": "model-00001-of-00004.safetensors",
592
+ "model.vision.blocks.4.ln2.weight": "model-00001-of-00004.safetensors",
593
+ "model.vision.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
594
+ "model.vision.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
595
+ "model.vision.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
596
+ "model.vision.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
597
+ "model.vision.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
598
+ "model.vision.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
599
+ "model.vision.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
600
+ "model.vision.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
601
+ "model.vision.blocks.5.ln1.bias": "model-00001-of-00004.safetensors",
602
+ "model.vision.blocks.5.ln1.weight": "model-00001-of-00004.safetensors",
603
+ "model.vision.blocks.5.ln2.bias": "model-00001-of-00004.safetensors",
604
+ "model.vision.blocks.5.ln2.weight": "model-00001-of-00004.safetensors",
605
+ "model.vision.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
606
+ "model.vision.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
607
+ "model.vision.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
608
+ "model.vision.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
609
+ "model.vision.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
610
+ "model.vision.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
611
+ "model.vision.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
612
+ "model.vision.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
613
+ "model.vision.blocks.6.ln1.bias": "model-00001-of-00004.safetensors",
614
+ "model.vision.blocks.6.ln1.weight": "model-00001-of-00004.safetensors",
615
+ "model.vision.blocks.6.ln2.bias": "model-00001-of-00004.safetensors",
616
+ "model.vision.blocks.6.ln2.weight": "model-00001-of-00004.safetensors",
617
+ "model.vision.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
618
+ "model.vision.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
619
+ "model.vision.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
620
+ "model.vision.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
621
+ "model.vision.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
622
+ "model.vision.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
623
+ "model.vision.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
624
+ "model.vision.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
625
+ "model.vision.blocks.7.ln1.bias": "model-00001-of-00004.safetensors",
626
+ "model.vision.blocks.7.ln1.weight": "model-00001-of-00004.safetensors",
627
+ "model.vision.blocks.7.ln2.bias": "model-00001-of-00004.safetensors",
628
+ "model.vision.blocks.7.ln2.weight": "model-00001-of-00004.safetensors",
629
+ "model.vision.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
630
+ "model.vision.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
631
+ "model.vision.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
632
+ "model.vision.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
633
+ "model.vision.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
634
+ "model.vision.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
635
+ "model.vision.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
636
+ "model.vision.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
637
+ "model.vision.blocks.8.ln1.bias": "model-00001-of-00004.safetensors",
638
+ "model.vision.blocks.8.ln1.weight": "model-00001-of-00004.safetensors",
639
+ "model.vision.blocks.8.ln2.bias": "model-00001-of-00004.safetensors",
640
+ "model.vision.blocks.8.ln2.weight": "model-00001-of-00004.safetensors",
641
+ "model.vision.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
642
+ "model.vision.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
643
+ "model.vision.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
644
+ "model.vision.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
645
+ "model.vision.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
646
+ "model.vision.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
647
+ "model.vision.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
648
+ "model.vision.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
649
+ "model.vision.blocks.9.ln1.bias": "model-00001-of-00004.safetensors",
650
+ "model.vision.blocks.9.ln1.weight": "model-00001-of-00004.safetensors",
651
+ "model.vision.blocks.9.ln2.bias": "model-00001-of-00004.safetensors",
652
+ "model.vision.blocks.9.ln2.weight": "model-00001-of-00004.safetensors",
653
+ "model.vision.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
654
+ "model.vision.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
655
+ "model.vision.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
656
+ "model.vision.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
657
+ "model.vision.patch_emb.bias": "model-00001-of-00004.safetensors",
658
+ "model.vision.patch_emb.weight": "model-00001-of-00004.safetensors",
659
+ "model.vision.pos_emb": "model-00001-of-00004.safetensors",
660
+ "model.vision.post_ln.bias": "model-00001-of-00004.safetensors",
661
+ "model.vision.post_ln.weight": "model-00001-of-00004.safetensors",
662
+ "model.vision.proj_mlp.fc1.bias": "model-00001-of-00004.safetensors",
663
+ "model.vision.proj_mlp.fc1.weight": "model-00001-of-00004.safetensors",
664
+ "model.vision.proj_mlp.fc2.bias": "model-00001-of-00004.safetensors",
665
+ "model.vision.proj_mlp.fc2.weight": "model-00001-of-00004.safetensors"
666
+ }
667
+ }
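
The index above is the standard sharded-checkpoint weight map: each parameter name is mapped to the shard file that stores it. As a rough sketch (not part of this commit, and assuming the usual Hugging Face index layout with a top-level `weight_map` key), the four shards can be reassembled into one state dict like this:

    import json
    from safetensors.torch import load_file

    with open("model.safetensors.index.json") as f:
        index = json.load(f)

    # Load each shard once and merge; keys match the map above,
    # e.g. "model.vision.pos_emb" -> "model-00001-of-00004.safetensors".
    state_dict = {}
    for shard in sorted(set(index["weight_map"].values())):
        state_dict.update(load_file(shard))

    assert "model.vision.pos_emb" in state_dict
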
moondream.py ADDED
@@ -0,0 +1,1077 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+
5
+ from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
6
+ from PIL import Image
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
+ from torch.nn.attention.flex_attention import create_block_mask
10
+
11
+ from .config import MoondreamConfig
12
+ from .image_crops import reconstruct_from_crops
13
+ from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
14
+ from .text import build_text_model, text_encoder, lm_head, text_decoder
15
+ from .region import (
16
+ decode_coordinate,
17
+ encode_coordinate,
18
+ decode_size,
19
+ encode_size,
20
+ encode_spatial_refs,
21
+ SpatialRefs,
22
+ )
23
+ from .layers import QuantizedLinear
24
+ from .lora import variant_state_dict
25
+ from .utils import remove_outlier_points
26
+
27
+ ImageEncodingSettings = TypedDict(
28
+ "ImageEncodingSettings",
29
+ {"variant": str},
30
+ total=False,
31
+ )
32
+
33
+ TextSamplingSettings = TypedDict(
34
+ "TextSamplingSettings",
35
+ {
36
+ "max_tokens": int,
37
+ "temperature": float,
38
+ "top_p": float,
39
+ "variant": str,
40
+ },
41
+ total=False,
42
+ )
43
+
44
+ ObjectSamplingSettings = TypedDict(
45
+ "ObjectSamplingSettings",
46
+ {"max_objects": int, "variant": str},
47
+ total=False,
48
+ )
49
+
50
+
51
+ DEFAULT_MAX_TOKENS = 768
52
+ DEFAULT_TEMPERATURE = 0.5
53
+ DEFAULT_TOP_P = 0.9
54
+ DEFAULT_MAX_OBJECTS = 50
55
+
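
The TypedDicts and constants above define the per-call knobs. A minimal illustration of a settings dict (values are arbitrary; `variant` is only needed when a LoRA variant should be loaded via `variant_state_dict`):

    sampling: TextSamplingSettings = {
        "max_tokens": 256,    # falls back to DEFAULT_MAX_TOKENS (768) when omitted
        "temperature": 0.0,   # 0 switches decoding to greedy argmax
        "top_p": 0.9,
    }
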
56
+
57
+ @dataclass(frozen=True)
58
+ class EncodedImage:
59
+ pos: int
60
+ caches: List[Tuple[torch.Tensor, torch.Tensor]]
61
+
62
+
63
+ class KVCache(nn.Module):
64
+
65
+ def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
66
+ super().__init__()
67
+ cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
68
+ self.register_buffer(
69
+ "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
70
+ )
71
+ self.register_buffer(
72
+ "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
73
+ )
74
+
75
+ def update(self, pos_ids, k, v):
76
+ kout, vout = self.k_cache, self.v_cache
77
+ kout[:, :, pos_ids, :] = k
78
+ vout[:, :, pos_ids, :] = v
79
+ return kout, vout
80
+
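
`KVCache` preallocates key/value buffers for the full context and writes new entries in place at the given positions. A toy illustration of the shapes involved (all numbers made up):

    cache = KVCache(n_heads=8, n_kv_heads=8, max_context=16, dim=64,
                    device="cpu", dtype=torch.float32)
    k = torch.randn(1, 8, 1, 8)   # one new key per kv head; head_dim = dim // n_heads = 8
    v = torch.randn(1, 8, 1, 8)
    k_all, v_all = cache.update(torch.tensor([3]), k, v)  # writes position 3 in place
    print(k_all.shape)            # torch.Size([1, 8, 16, 8])
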
81
+
82
+ def causal_mask(b, h, q_idx, kv_idx):
83
+ return q_idx >= kv_idx
84
+
85
+
86
+ def get_mask_mod(mask_mod, offset):
87
+ def _mask_mod(b, h, q, kv):
88
+ return mask_mod(b, h, q + offset, kv)
89
+
90
+ return _mask_mod
91
+
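
`causal_mask` is the standard lower-triangular predicate, and `get_mask_mod` shifts the query index so that, during one-token decoding (where the query index is always 0), the predicate is evaluated at the token's absolute position. A quick sanity check outside the flex-attention machinery:

    shifted = get_mask_mod(causal_mask, offset=50)
    print(shifted(0, 0, 0, 10))   # True:  position 50 may attend to position 10
    print(shifted(0, 0, 0, 60))   # False: it may not attend to a future position
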
92
+
93
+ class MoondreamModel(nn.Module):
94
+
95
+ def __init__(
96
+ self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
97
+ ):
98
+ super().__init__()
99
+ self.config = config
100
+
101
+ self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
102
+ self.vision = build_vision_model(config.vision, dtype)
103
+ self.text = build_text_model(config.text, dtype)
104
+
105
+ # Region Model
106
+ linear_cls = (
107
+ QuantizedLinear if config.region.group_size is not None else nn.Linear
108
+ )
109
+ self.region = nn.ModuleDict(
110
+ {
111
+ "coord_encoder": linear_cls(
112
+ config.region.coord_feat_dim, config.region.dim, dtype=dtype
113
+ ),
114
+ "coord_decoder": linear_cls(
115
+ config.region.dim, config.region.coord_out_dim, dtype=dtype
116
+ ),
117
+ "size_encoder": linear_cls(
118
+ config.region.size_feat_dim, config.region.dim, dtype=dtype
119
+ ),
120
+ "size_decoder": linear_cls(
121
+ config.region.dim, config.region.size_out_dim, dtype=dtype
122
+ ),
123
+ }
124
+ )
125
+ self.region.coord_features = nn.Parameter(
126
+ torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
127
+ )
128
+ self.region.size_features = nn.Parameter(
129
+ torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
130
+ )
131
+
132
+ attn_mask = torch.tril(
133
+ torch.ones(
134
+ 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
135
+ )
136
+ )
137
+ patch_w = config.vision.crop_size // config.vision.enc_patch_size
138
+ prefix_attn_len = 1 + patch_w**2
139
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
140
+ self.register_buffer("attn_mask", attn_mask, persistent=False)
141
+
142
+ self.use_flex_decoding = True
143
+ self._causal_block_mask = None
144
+ self._point_gen_indices = None
145
+
146
+ # Initialize KV caches.
147
+ if setup_caches:
148
+ self._setup_caches()
149
+
150
+ @property
151
+ def causal_block_mask(self):
152
+ # The things we do to deal with ZeroGPU...
153
+ if self._causal_block_mask is None:
154
+ self._causal_block_mask = create_block_mask(
155
+ causal_mask,
156
+ B=None,
157
+ H=None,
158
+ Q_LEN=self.config.text.max_context,
159
+ KV_LEN=self.config.text.max_context,
160
+ )
161
+ return self._causal_block_mask
162
+
163
+ @property
164
+ def point_gen_indices(self):
165
+ if self._point_gen_indices is None:
166
+ self._point_gen_indices = torch.tensor(
167
+ [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],
168
+ device=self.device,
169
+ )
170
+ return self._point_gen_indices
171
+
172
+ def _setup_caches(self):
173
+ c = self.config.text
174
+ for b in self.text.blocks:
175
+ b.kv_cache = KVCache(
176
+ c.n_heads,
177
+ c.n_kv_heads,
178
+ c.max_context,
179
+ c.dim,
180
+ device=self.device,
181
+ dtype=self.vision.pos_emb.dtype,
182
+ )
183
+
184
+ @property
185
+ def device(self):
186
+ return self.vision.pos_emb.device
187
+
188
+ def _vis_enc(self, x: torch.Tensor):
189
+ return vision_encoder(x, self.vision, self.config.vision)
190
+
191
+ def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
192
+ return vision_projection(g, r, self.vision, self.config.vision)
193
+
194
+ def _prefill(
195
+ self,
196
+ x: torch.Tensor,
197
+ attn_mask: torch.Tensor,
198
+ pos_ids: torch.Tensor,
199
+ lora: Optional[torch.Tensor],
200
+ ):
201
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
202
+
203
+ def _decode_one_tok(
204
+ self,
205
+ x: torch.Tensor,
206
+ attn_mask: torch.Tensor,
207
+ pos_ids: torch.Tensor,
208
+ lora: Optional[torch.Tensor],
209
+ lm_head_indices: Optional[torch.Tensor] = None,
210
+ ):
211
+ if self.use_flex_decoding:
212
+ torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape")
213
+ block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]
214
+ mask = self.causal_block_mask[:, :, block_index]
215
+ mask.seq_lengths = (1, mask.seq_lengths[1])
216
+ mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])
217
+ else:
218
+ mask = None
219
+
220
+ hidden = text_decoder(
221
+ x,
222
+ self.text,
223
+ attn_mask,
224
+ pos_ids,
225
+ self.config.text,
226
+ lora=lora,
227
+ flex_block_mask_slice=mask,
228
+ )
229
+ logits = lm_head(hidden, self.text, indices=lm_head_indices)
230
+ return logits, hidden
231
+
232
+ def compile(self):
233
+ for module in self.modules():
234
+ if isinstance(module, QuantizedLinear):
235
+ module.unpack()
236
+
237
+ # Initialize lazy properties to avoid first-call overhead
238
+ self.causal_block_mask
239
+ self.point_gen_indices
240
+
241
+ # TODO: vision_projection and _prefill are not being compiled
242
+ self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
243
+ self._decode_one_tok = torch.compile(
244
+ self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
245
+ )
246
+
247
+ # Warm up compiled methods with dummy forward passes
248
+ device = self.device
249
+ dtype = self.vision.pos_emb.dtype
250
+ with torch.no_grad():
251
+ # Warmup vision encoder
252
+ dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype)
253
+ self._vis_enc(dummy_crops)
254
+
255
+ # Warmup _decode_one_tok (both normal and point generation modes)
256
+ dummy_emb = torch.randn(
257
+ 1, 1, self.config.text.dim, device=device, dtype=dtype
258
+ )
259
+ dummy_mask = torch.ones(
260
+ 1, 1, self.config.text.max_context, device=device, dtype=torch.bool
261
+ )
262
+ dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long)
263
+ self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None)
264
+ self._decode_one_tok(
265
+ dummy_emb,
266
+ dummy_mask,
267
+ dummy_pos_ids,
268
+ None,
269
+ lm_head_indices=self.point_gen_indices,
270
+ )
271
+
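
`compile()` unpacks any `QuantizedLinear` layers, compiles the vision encoder and the single-token decode step with `torch.compile`, and runs dummy forward passes so the compilation cost is paid up front. A sketch of the intended call order (assuming the model has already been constructed and its weights loaded):

    model = MoondreamModel(config)   # config: a MoondreamConfig instance, assumed to exist
    model.compile()                  # one-time warm-up before serving the first request
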
272
+ def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
273
+ all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
274
+
275
+ torch._dynamo.mark_dynamic(all_crops, 0)
276
+
277
+ outputs = self._vis_enc(all_crops)
278
+
279
+ global_features = outputs[0]
280
+ local_features = outputs[1:].view(
281
+ -1,
282
+ self.config.vision.enc_n_layers,
283
+ self.config.vision.enc_n_layers,
284
+ self.config.vision.enc_dim,
285
+ )
286
+
287
+ reconstructed = reconstruct_from_crops(
288
+ local_features,
289
+ tiling,
290
+ patch_size=1,
291
+ overlap_margin=self.config.vision.overlap_margin,
292
+ )
293
+
294
+ return self._vis_proj(global_features, reconstructed)
295
+
296
+ def encode_image(
297
+ self,
298
+ image: Union[Image.Image, EncodedImage],
299
+ settings: Optional[ImageEncodingSettings] = None,
300
+ ) -> EncodedImage:
301
+ if isinstance(image, EncodedImage):
302
+ return image
303
+ elif not isinstance(image, Image.Image):
304
+ raise ValueError("image must be a PIL Image or EncodedImage")
305
+
306
+ lora = (
307
+ variant_state_dict(settings["variant"], device=self.device)
308
+ if settings is not None and "variant" in settings
309
+ else None
310
+ )
311
+
312
+ # Run through text model in addition to the vision encoder, to minimize
313
+ # re-computation if multiple queries are performed on this image.
314
+ with torch.inference_mode():
315
+ img_emb = self._run_vision_encoder(image)
316
+ bos_emb = text_encoder(
317
+ torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
318
+ self.text,
319
+ )
320
+ inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
321
+ mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
322
+ pos_ids = torch.arange(
323
+ inputs_embeds.size(1), dtype=torch.long, device=self.device
324
+ )
325
+ self._prefill(inputs_embeds, mask, pos_ids, lora)
326
+
327
+ return EncodedImage(
328
+ pos=inputs_embeds.size(1),
329
+ caches=[
330
+ (
331
+ b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
332
+ b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
333
+ )
334
+ for b in self.text.blocks
335
+ ],
336
+ )
337
+
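
`encode_image` runs the vision encoder and prefills the text KV caches, returning an `EncodedImage` snapshot so repeated questions about the same image skip that work. A usage sketch (the image path and the `model` instance are assumptions):

    from PIL import Image

    encoded = model.encode_image(Image.open("photo.jpg"))
    first = model.query(image=encoded, question="What is in the photo?")
    second = model.query(image=encoded, question="What color is the car?")
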
338
+ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
339
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
340
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
341
+ mask = probs_sum - probs_sort > top_p
342
+ probs_sort[mask] = 0.0
343
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
344
+ next_probs = torch.zeros_like(probs)
345
+ next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
346
+ return next_probs
347
+
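
A worked example of the nucleus filter above, evaluated by hand: with `top_p=0.7`, only tokens whose preceding cumulative mass stays at or below 0.7 survive, and the survivors are renormalized.

    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    # cumulative sums: 0.5, 0.8, 0.95, 1.0 -> the last two tokens are masked out
    # survivors [0.5, 0.3] renormalize to [0.625, 0.375]
    filtered = model._apply_top_p(probs, top_p=0.7)
    # -> [[0.625, 0.375, 0.0, 0.0]] (up to float rounding)
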
348
+ def _prefill_prompt(
349
+ self,
350
+ prompt_tokens: torch.Tensor,
351
+ pos: int,
352
+ temperature: float,
353
+ top_p: float,
354
+ spatial_refs: Optional[SpatialRefs] = None,
355
+ attn_mask: Optional[torch.Tensor] = None,
356
+ lora: Optional[dict] = None,
357
+ ):
358
+ with torch.inference_mode():
359
+ prompt_emb = text_encoder(prompt_tokens, self.text)
360
+
361
+ if spatial_refs:
362
+ encoded_refs = encode_spatial_refs(spatial_refs, self.region)
363
+ prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
364
+ encoded_refs["coords"]
365
+ )
366
+ if encoded_refs["sizes"] is not None:
367
+ prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
368
+ encoded_refs["sizes"]
369
+ )
370
+
371
+ torch._dynamo.mark_dynamic(prompt_emb, 1)
372
+
373
+ if attn_mask is None:
374
+ attn_mask = self.attn_mask
375
+
376
+ mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
377
+ pos_ids = torch.arange(
378
+ pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device
379
+ )
380
+ hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
381
+ logits_BV = lm_head(hidden_BC, self.text)
382
+
383
+ if temperature == 0:
384
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
385
+ else:
386
+ probs = torch.softmax(logits_BV / temperature, dim=-1)
387
+ probs = self._apply_top_p(probs, top_p)
388
+ next_token = torch.multinomial(probs, num_samples=1)
389
+
390
+ pos = pos + prompt_emb.size(1)
391
+ return logits_BV, hidden_BC, next_token, pos
392
+
393
+ def _generate_reasoning(
394
+ self,
395
+ prompt_tokens,
396
+ pos,
397
+ settings: Optional[TextSamplingSettings] = None,
398
+ spatial_refs: Optional[SpatialRefs] = None,
399
+ attn_mask: Optional[torch.Tensor] = None,
400
+ ) -> Tuple[int, str, List[dict]]:
401
+ max_tokens = (
402
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
403
+ if settings
404
+ else DEFAULT_MAX_TOKENS
405
+ )
406
+ temperature = (
407
+ settings.get("temperature", DEFAULT_TEMPERATURE)
408
+ if settings
409
+ else DEFAULT_TEMPERATURE
410
+ )
411
+ lora = (
412
+ variant_state_dict(settings["variant"], device=self.device)
413
+ if settings is not None and "variant" in settings
414
+ else None
415
+ )
416
+
417
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
418
+ eos_id = self.config.tokenizer.answer_id
419
+
420
+ _, last_hidden_BC, next_token, pos = self._prefill_prompt(
421
+ prompt_tokens,
422
+ pos,
423
+ temperature,
424
+ top_p,
425
+ spatial_refs,
426
+ attn_mask=attn_mask,
427
+ lora=lora,
428
+ )
429
+
430
+ text_token_chunks = [[]]
431
+ grounding_chunks = [[]]
432
+
433
+ mask = torch.zeros(
434
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
435
+ )
436
+ mask[:, :, :pos] = 1
437
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
438
+ generated_tokens = 0
439
+
440
+ while (
441
+ next_token_id := next_token.item()
442
+ ) != eos_id and generated_tokens < max_tokens:
443
+ if (
444
+ next_token_id == self.config.tokenizer.start_ground_points_id
445
+ or next_token_id == self.config.tokenizer.end_ground_id
446
+ ):
447
+ text_token_chunks.append([])
448
+ grounding_chunks.append([])
449
+
450
+ text_token_chunks[-1].append(next_token_id)
451
+
452
+ with torch.inference_mode():
453
+ if next_token_id == self.config.tokenizer.coord_id:
454
+ coord_logits = decode_coordinate(last_hidden_BC, self.region)
455
+ coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
456
+ grounding_chunks[-1].append(coord.item())
457
+
458
+ next_emb = encode_coordinate(
459
+ coord.to(dtype=coord_logits.dtype), self.region
460
+ ).unsqueeze(0)
461
+ else:
462
+ next_emb = text_encoder(next_token, self.text)
463
+
464
+ mask[:, :, pos], pos_ids[0] = 1, pos
465
+
466
+ logits_BV, last_hidden_BC = self._decode_one_tok(
467
+ next_emb, mask, pos_ids, lora
468
+ )
469
+ logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
470
+ logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
471
+
472
+ pos += 1
473
+
474
+ if temperature == 0:
475
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
476
+ else:
477
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
478
+ probs = self._apply_top_p(probs, top_p)
479
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
480
+
481
+ generated_tokens += 1
482
+
483
+ text_chunks = [
484
+ self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
485
+ ]
486
+ text = "".join(text_chunks)
487
+
488
+ start_idx = 0
489
+ grounding = []
490
+ for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
491
+ if len(grounding_chunk) > 1:
492
+ points = []
493
+ for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
494
+ points.append((grounding_chunk[i], grounding_chunk[i + 1]))
495
+ grounding.append(
496
+ {
497
+ "start_idx": start_idx,
498
+ "end_idx": start_idx + len(text_chunk),
499
+ "points": points,
500
+ }
501
+ )
502
+ start_idx += len(text_chunk)
503
+
504
+ return pos, text, grounding
505
+
506
+ def _generate_answer(
507
+ self,
508
+ prompt_tokens: torch.Tensor,
509
+ pos: int,
510
+ settings: Optional[TextSamplingSettings] = None,
511
+ spatial_refs: Optional[SpatialRefs] = None,
512
+ eos_id: Optional[int] = None,
513
+ attn_mask: Optional[torch.Tensor] = None,
514
+ ):
515
+ max_tokens = (
516
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
517
+ if settings
518
+ else DEFAULT_MAX_TOKENS
519
+ )
520
+ temperature = (
521
+ settings.get("temperature", DEFAULT_TEMPERATURE)
522
+ if settings
523
+ else DEFAULT_TEMPERATURE
524
+ )
525
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
526
+ eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
527
+ lora = (
528
+ variant_state_dict(settings["variant"], device=self.device)
529
+ if settings is not None and "variant" in settings
530
+ else None
531
+ )
532
+
533
+ _, _, next_token, pos = self._prefill_prompt(
534
+ prompt_tokens,
535
+ pos,
536
+ temperature,
537
+ top_p,
538
+ spatial_refs,
539
+ attn_mask=attn_mask,
540
+ lora=lora,
541
+ )
542
+
543
+ def generator(next_token, pos):
544
+ mask = torch.zeros(
545
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
546
+ )
547
+ mask[:, :, :pos] = 1
548
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
549
+ generated_tokens = 0
550
+
551
+ # For properly handling token streaming with Unicode
552
+ token_cache = []
553
+ print_len = 0
554
+
555
+ while (
556
+ next_token_id := next_token.item()
557
+ ) != eos_id and generated_tokens < max_tokens:
558
+ # Add token to our cache
559
+ token_cache.append(next_token_id)
560
+
561
+ # Decode all tokens collected so far
562
+ text = self.tokenizer.decode(token_cache)
563
+
564
+ # After a newline, we flush the cache completely
565
+ if text.endswith("\n"):
566
+ printable_text = text[print_len:]
567
+ token_cache = []
568
+ print_len = 0
569
+ if printable_text:
570
+ yield printable_text
571
+ # If the last token is a CJK character, we can safely print it
572
+ elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
573
+ printable_text = text[print_len:]
574
+ print_len += len(printable_text)
575
+ if printable_text:
576
+ yield printable_text
577
+ # Otherwise, only yield up to the last space to avoid cutting words
578
+ else:
579
+ last_space_idx = text.rfind(" ", print_len)
580
+ if last_space_idx >= print_len:
581
+ printable_text = text[print_len : last_space_idx + 1]
582
+ print_len += len(printable_text)
583
+ if printable_text:
584
+ yield printable_text
585
+
586
+ with torch.inference_mode():
587
+ next_emb = text_encoder(next_token, self.text)
588
+ mask[:, :, pos], pos_ids[0] = 1, pos
589
+
590
+ logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
591
+ logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
592
+
593
+ # Suppress EOS for the first token to ensure at least one answer token
594
+ if generated_tokens == 0:
595
+ logits_BV[:, eos_id] = float("-inf")
596
+
597
+ pos += 1
598
+
599
+ if temperature == 0:
600
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
601
+ 1
602
+ ) # (1, 1)
603
+ else:
604
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
605
+ probs = self._apply_top_p(probs, top_p)
606
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
607
+
608
+ generated_tokens += 1
609
+
610
+ # Flush any remaining text in the cache
611
+ if token_cache:
612
+ text = self.tokenizer.decode(token_cache)
613
+ printable_text = text[print_len:]
614
+ if printable_text:
615
+ yield printable_text
616
+
617
+ return generator(next_token, pos)
618
+
619
+ def query(
620
+ self,
621
+ image: Optional[Union[Image.Image, EncodedImage]] = None,
622
+ question: str = None,
623
+ reasoning: bool = False,
624
+ spatial_refs: Optional[SpatialRefs] = None,
625
+ stream: bool = False,
626
+ settings: Optional[TextSamplingSettings] = None,
627
+ ):
628
+ if self.config.tokenizer.templates["query"] is None:
629
+ raise NotImplementedError("Model does not support querying.")
630
+
631
+ if question is None:
632
+ raise ValueError("question must be provided.")
633
+
634
+ if spatial_refs and image is None:
635
+ raise ValueError("spatial_refs can only be used with an image.")
636
+
637
+ attn_mask = self.attn_mask
638
+ if image is not None:
639
+ image = self.encode_image(image, settings)
640
+ self.load_encoded_image(image)
641
+ pos = image.pos
642
+ prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
643
+ else:
644
+ self._setup_caches()
645
+ pos = 0
646
+ prompt_toks = [
647
+ self.config.tokenizer.bos_id
648
+ ] + self.config.tokenizer.templates["query"]["prefix"]
649
+ max_context = self.config.text.max_context
650
+ attn_mask = torch.tril(
651
+ torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
652
+ ).to(self.device)
653
+
654
+ spatial_toks = []
655
+ if spatial_refs:
656
+ for ref in spatial_refs:
657
+ coord_id = self.config.tokenizer.coord_id
658
+ size_id = self.config.tokenizer.size_id
659
+ if len(ref) == 2:
660
+ spatial_toks.extend([coord_id, coord_id])
661
+ else:
662
+ spatial_toks.extend([coord_id, coord_id, size_id])
663
+
664
+ prompt_tokens = [
665
+ prompt_toks
666
+ + spatial_toks
667
+ + self.tokenizer.encode(question).ids
668
+ + self.config.tokenizer.templates["query"]["suffix"]
669
+ ]
670
+
671
+ if reasoning:
672
+ prompt_tokens[0] += [self.config.tokenizer.thinking_id]
673
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
674
+ pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
675
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
676
+ )
677
+ prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
678
+ reasoning_dict = {
679
+ "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
680
+ }
681
+ else:
682
+ prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
683
+ reasoning_dict = {}
684
+
685
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
686
+
687
+ def generator():
688
+ for token in self._generate_answer(
689
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
690
+ ):
691
+ yield token
692
+
693
+ if stream:
694
+ return {**reasoning_dict, "answer": generator()}
695
+ else:
696
+ return {**reasoning_dict, "answer": "".join(list(generator()))}
697
+
698
+ def load_encoded_image(self, encoded_image: EncodedImage):
699
+ for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
700
+ b.kv_cache.k_cache[:, :, : k.size(2), :] = k
701
+ b.kv_cache.v_cache[:, :, : v.size(2), :] = v
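+ # Copies the image's precomputed K/V activations into the first image-length
+ # positions of each block's KV cache, so text decoding can resume at image.pos.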
702
+
703
+ def caption(
704
+ self,
705
+ image: Union[Image.Image, EncodedImage],
706
+ length: Literal["normal", "short", "long"] = "normal",
707
+ stream: bool = False,
708
+ settings: Optional[TextSamplingSettings] = None,
709
+ ):
710
+ if self.config.tokenizer.templates["caption"] is None:
711
+ raise NotImplementedError("Model does not support captioning.")
712
+ if length not in self.config.tokenizer.templates["caption"]:
713
+ raise ValueError(f"Model does not support caption length '{length}'.")
714
+
715
+ image = self.encode_image(image, settings)
716
+ self.load_encoded_image(image)
717
+
718
+ prompt_tokens = torch.tensor(
719
+ [self.config.tokenizer.templates["caption"][length]], device=self.device
720
+ )
721
+
722
+ def generator():
723
+ for token in self._generate_answer(prompt_tokens, image.pos, settings):
724
+ yield token
725
+
726
+ if stream:
727
+ return {"caption": generator()}
728
+ else:
729
+ return {"caption": "".join(list(generator()))}
730
+
731
+ def _generate_points(
732
+ self,
733
+ hidden: torch.Tensor,
734
+ next_token: torch.Tensor,
735
+ pos: int,
736
+ include_size: bool = True,
737
+ max_objects: int = DEFAULT_MAX_OBJECTS,
738
+ lora: Optional[dict] = None,
739
+ ):
740
+ out = []
741
+ mask = torch.zeros(
742
+ 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
743
+ )
744
+ mask[:, :, :pos] = 1
745
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
746
+
747
+ with torch.inference_mode():
748
+ while (
749
+ next_token.item() != self.config.tokenizer.eos_id
750
+ and len(out) < max_objects
751
+ ):
752
+ x_logits = decode_coordinate(hidden, self.region)
753
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
754
+ next_emb = encode_coordinate(
755
+ x_center.to(dtype=x_logits.dtype), self.region
756
+ ).unsqueeze(0)
757
+
758
+ # Decode y-coordinate
759
+ mask[:, :, pos], pos_ids[0] = 1, pos
760
+ _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
761
+ pos += 1
762
+ y_logits = decode_coordinate(hidden, self.region)
763
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
764
+ next_emb = encode_coordinate(
765
+ y_center.to(dtype=y_logits.dtype), self.region
766
+ ).unsqueeze(0)
767
+
768
+ # Decode size
769
+ if include_size:
770
+ mask[:, :, pos], pos_ids[0] = 1, pos
771
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
772
+ pos += 1
773
+ size_logits = decode_size(hidden, self.region)
774
+
775
+ # Get bin indices from the logits
776
+ w_bin = torch.argmax(size_logits[0], dim=-1)
777
+ h_bin = torch.argmax(size_logits[1], dim=-1)
778
+
779
+ # Convert from bin indices to actual size values using the inverse of the log-scale mapping
780
+ # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
781
+ w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
782
+ h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
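+ # Worked example of the inverse mapping: bin 1023 -> 2^0 = 1.0 (full image
+ # extent); bin 0 -> 2^-10 ≈ 0.001, the smallest representable size.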
783
+
784
+ next_emb = (
785
+ encode_size(
786
+ torch.tensor(
787
+ [w, h], device=self.device, dtype=size_logits.dtype
788
+ ),
789
+ self.region,
790
+ )
791
+ .unsqueeze(0)
792
+ .unsqueeze(0)
793
+ )
794
+
795
+ # Add object
796
+ out.append(
797
+ {
798
+ "x_min": x_center.item() - w.item() / 2,
799
+ "y_min": y_center.item() - h.item() / 2,
800
+ "x_max": x_center.item() + w.item() / 2,
801
+ "y_max": y_center.item() + h.item() / 2,
802
+ }
803
+ )
804
+ else:
805
+ out.append({"x": x_center.item(), "y": y_center.item()})
806
+
807
+ # Decode next token (x-coordinate, or eos)
808
+ mask[:, :, pos], pos_ids[0] = 1, pos
809
+ logits, hidden = self._decode_one_tok(
810
+ next_emb,
811
+ mask,
812
+ pos_ids,
813
+ lora,
814
+ lm_head_indices=self.point_gen_indices,
815
+ )
816
+ pos += 1
817
+ # Map back: index 0 -> coord_id, index 1 -> eos_id
818
+ next_token_idx = torch.argmax(logits, dim=-1)
819
+ next_token = self.point_gen_indices[next_token_idx]
820
+
821
+ return out
822
+
823
+ def detect(
824
+ self,
825
+ image: Union[Image.Image, EncodedImage],
826
+ object: str,
827
+ settings: Optional[ObjectSamplingSettings] = None,
828
+ ):
829
+ if self.config.tokenizer.templates["detect"] is None:
830
+ raise NotImplementedError("Model does not support object detection.")
831
+
832
+ image = self.encode_image(image, settings)
833
+ self.load_encoded_image(image)
834
+
835
+ prompt_tokens = torch.tensor(
836
+ [
837
+ self.config.tokenizer.templates["detect"]["prefix"]
838
+ + self.tokenizer.encode(" " + object).ids
839
+ + self.config.tokenizer.templates["detect"]["suffix"]
840
+ ],
841
+ device=self.device,
842
+ )
843
+
844
+ lora = (
845
+ variant_state_dict(settings["variant"], device=self.device)
846
+ if settings is not None and "variant" in settings
847
+ else None
848
+ )
849
+
850
+ _, hidden, next_token, pos = self._prefill_prompt(
851
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
852
+ )
853
+ hidden = hidden[:, -1:, :]
854
+
855
+ max_objects = (
856
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
857
+ if settings
858
+ else DEFAULT_MAX_OBJECTS
859
+ )
860
+ objects = self._generate_points(
861
+ hidden,
862
+ next_token,
863
+ pos,
864
+ include_size=True,
865
+ max_objects=max_objects,
866
+ lora=lora,
867
+ )
868
+
869
+ return {"objects": objects}
870
+
871
+ def point(
872
+ self,
873
+ image: Union[Image.Image, EncodedImage],
874
+ object: str,
875
+ settings: Optional[ObjectSamplingSettings] = None,
876
+ ):
877
+ if self.config.tokenizer.templates["point"] is None:
878
+ raise NotImplementedError("Model does not support pointing.")
879
+
880
+ image = self.encode_image(image, settings)
881
+ self.load_encoded_image(image)
882
+
883
+ prompt_tokens = torch.tensor(
884
+ [
885
+ self.config.tokenizer.templates["point"]["prefix"]
886
+ + self.tokenizer.encode(" " + object).ids
887
+ + self.config.tokenizer.templates["point"]["suffix"]
888
+ ],
889
+ device=self.device,
890
+ )
891
+
892
+ lora = (
893
+ variant_state_dict(settings["variant"], device=self.device)
894
+ if settings is not None and "variant" in settings
895
+ else None
896
+ )
897
+
898
+ _, hidden, next_token, pos = self._prefill_prompt(
899
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
900
+ )
901
+ hidden = hidden[:, -1:, :]
902
+
903
+ max_objects = (
904
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
905
+ if settings
906
+ else DEFAULT_MAX_OBJECTS
907
+ )
908
+ objects = self._generate_points(
909
+ hidden,
910
+ next_token,
911
+ pos,
912
+ include_size=False,
913
+ max_objects=max_objects,
914
+ lora=lora,
915
+ )
916
+
917
+ return {"points": objects}
918
+
919
+ def _detect_gaze(
920
+ self,
921
+ image: EncodedImage,
922
+ source: Tuple[float, float],
923
+ force_detect: bool = False,
924
+ ):
925
+ with torch.inference_mode():
926
+ before_emb = text_encoder(
927
+ torch.tensor(
928
+ [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
929
+ ),
930
+ self.text,
931
+ )
932
+ after_emb = text_encoder(
933
+ torch.tensor(
934
+ [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
935
+ ),
936
+ self.text,
937
+ )
938
+ x_emb = encode_coordinate(
939
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
940
+ self.region,
941
+ )
942
+ y_emb = encode_coordinate(
943
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
944
+ self.region,
945
+ )
946
+
947
+ prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
948
+
949
+ self.load_encoded_image(image)
950
+
951
+ mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
952
+ pos_ids = torch.arange(
953
+ image.pos,
954
+ image.pos + prompt_emb.size(1),
955
+ dtype=torch.long,
956
+ device=self.device,
957
+ )
958
+ hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
959
+ logits = lm_head(hidden, self.text)
960
+ next_token = torch.argmax(logits, dim=-1)
961
+ pos = image.pos + prompt_emb.size(1)
962
+ hidden = hidden[:, -1:, :]
963
+
964
+ if force_detect:
965
+ next_token = torch.tensor([[0]], device=self.device)
966
+
967
+ if next_token.item() == self.config.tokenizer.eos_id:
968
+ return None
969
+
970
+ gaze = self._generate_points(
971
+ hidden, next_token, pos, include_size=False, max_objects=1
972
+ )
973
+ return gaze[0]
974
+
975
+ def detect_gaze(
976
+ self,
977
+ image: Union[Image.Image, EncodedImage],
978
+ eye: Optional[Tuple[float, float]] = None,
979
+ face: Optional[Dict[str, float]] = None,
980
+ unstable_settings: Dict[str, Any] = {},
981
+ ):
982
+ if "force_detect" in unstable_settings:
983
+ force_detect = unstable_settings["force_detect"]
984
+ else:
985
+ force_detect = False
986
+
987
+ if "prioritize_accuracy" in unstable_settings:
988
+ prioritize_accuracy = unstable_settings["prioritize_accuracy"]
989
+ else:
990
+ prioritize_accuracy = False
991
+
992
+ if not prioritize_accuracy:
993
+ if eye is None:
994
+ raise ValueError("eye must be provided when prioritize_accuracy=False")
995
+ image = self.encode_image(image)
996
+ return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
997
+ else:
998
+ if (
999
+ not isinstance(image, Image.Image)
1000
+ and "flip_enc_img" not in unstable_settings
1001
+ ):
1002
+ raise ValueError(
1003
+ "image must be a PIL Image when prioritize_accuracy=True, "
1004
+ "or flip_enc_img must be provided"
1005
+ )
1006
+ if face is None:
1007
+ raise ValueError("face must be provided when prioritize_accuracy=True")
1008
+
1009
+ encoded_image = self.encode_image(image)
1010
+ if (
1011
+ isinstance(image, Image.Image)
1012
+ and "flip_enc_img" not in unstable_settings
1013
+ ):
1014
+ flipped_pil = image.copy()
1015
+ flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
1016
+ encoded_flipped_image = self.encode_image(flipped_pil)
1017
+ else:
1018
+ encoded_flipped_image = unstable_settings["flip_enc_img"]
1019
+
1020
+ N = 10
1021
+
1022
+ detections = [
1023
+ self._detect_gaze(
1024
+ encoded_image,
1025
+ (
1026
+ random.uniform(face["x_min"], face["x_max"]),
1027
+ random.uniform(face["y_min"], face["y_max"]),
1028
+ ),
1029
+ force_detect=force_detect,
1030
+ )
1031
+ for _ in range(N)
1032
+ ]
1033
+ detections = [
1034
+ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
1035
+ ]
1036
+ flipped_detections = [
1037
+ self._detect_gaze(
1038
+ encoded_flipped_image,
1039
+ (
1040
+ 1 - random.uniform(face["x_min"], face["x_max"]),
1041
+ random.uniform(face["y_min"], face["y_max"]),
1042
+ ),
1043
+ force_detect=force_detect,
1044
+ )
1045
+ for _ in range(N)
1046
+ ]
1047
+ detections.extend(
1048
+ [
1049
+ (1 - gaze["x"], gaze["y"])
1050
+ for gaze in flipped_detections
1051
+ if gaze is not None
1052
+ ]
1053
+ )
1054
+
1055
+ if len(detections) < N:
1056
+ return {"gaze": None}
1057
+
1058
+ detections = remove_outlier_points(detections)
1059
+ mean_gaze = (
1060
+ sum(gaze[0] for gaze in detections) / len(detections),
1061
+ sum(gaze[1] for gaze in detections) / len(detections),
1062
+ )
1063
+
1064
+ return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
1065
+
1066
+
1067
+ def _is_cjk_char(cp):
1068
+ """Checks whether CP is the codepoint of a CJK character."""
1069
+ # This defines a "chinese character" as anything in the CJK Unicode block:
1070
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
1071
+ if (
1072
+ (cp >= 0x4E00 and cp <= 0x9FFF)
1073
+ or (cp >= 0x3400 and cp <= 0x4DBF)
1074
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
1075
+ ):
1076
+ return True
1077
+ return False
region.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from typing import List, Tuple, Union
6
+
7
+ SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
8
+
9
+
10
+ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
11
+ """
12
+ Applies Fourier feature mapping to input tensor x using frequency matrix w. This
13
+ projects inputs through sinusoidal functions to create higher dimensional features
14
+ that help mitigate spectral bias - the tendency of neural networks to learn
15
+ low-frequency functions more easily than high-frequency ones. By explicitly
16
+ mapping inputs to higher frequencies through sin/cos transformations, we enable
17
+ better learning of fine details and higher frequency patterns.
18
+
19
+ Args:
20
+ x: Input tensor to transform
21
+ w: Matrix of frequencies for the Fourier features transformation
22
+
23
+ Returns:
24
+ Concatenated cosine and sine transformed features as a tensor
25
+ """
26
+ f = 2 * math.pi * x @ w
27
+ return torch.cat([f.cos(), f.sin()], dim=-1)
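+ # Shape sketch (assuming x is (..., d) and w is (d, f)): x @ w is (..., f), so
+ # the concatenated cos/sin output has shape (..., 2 * f).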
28
+
29
+
30
+ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
31
+ """
32
+ Takes as input a tensor containing a single float coordinate value (x or y)
33
+ and encodes it into hidden states for input to the text model.
34
+
35
+ Args:
36
+ coord: Tensor with single float coordinate value
37
+
38
+ Returns:
39
+ Encoded hidden states tensor for input to text model
40
+ """
41
+ return w.coord_encoder(fourier_features(coord, w.coord_features))
42
+
43
+
44
+ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
45
+ """
46
+ Takes as input the last hidden state from the text model and outputs a single logit
47
+ representing either an x or y coordinate prediction.
48
+
49
+ Args:
50
+ hidden_state: The final hidden state tensor from the text model.
51
+
52
+ Returns:
53
+ A single logit representing the predicted coordinate value (x or y)
54
+ """
55
+ return w.coord_decoder(hidden_state)
56
+
57
+
58
+ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
59
+ """
60
+ Takes a tensor containing width and height values and encodes them into
61
+ hidden states for input to the text model.
62
+
63
+ Args:
64
+ size: Tensor with two floats for width and height
65
+
66
+ Returns:
67
+ Encoded hidden states tensor for input to text model
68
+ """
69
+ return w.size_encoder(fourier_features(size, w.size_features))
70
+
71
+
72
+ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
73
+ """
74
+ Takes as input the last hidden state from the text model and outputs logits
75
+ for 1024 bins representing width and height in log-scale.
76
+
77
+ The bins are distributed according to the formula:
78
+ bin = (log2(size) + 10.0) / 10.0 * 1023.0
79
+ where size values are clamped to be at least 1/1024.
80
+
81
+ To convert from bin back to size:
82
+ size = 2^((bin / 1023.0) * 10.0 - 10.0)
83
+
84
+ Args:
85
+ hidden_state: The final hidden state tensor from the text model.
86
+
87
+ Returns:
88
+ A tensor containing logits for 1024 bins for width and height.
89
+ Shape is (2, 1024) where the first dimension corresponds to width and height.
90
+ """
91
+ return w.size_decoder(hidden_state).view(2, -1)
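+ # Worked example of the bin mapping: size 1.0 -> bin (log2(1) + 10) / 10 * 1023
+ # = 1023; size 0.25 -> (-2 + 10) / 10 * 1023 ≈ 818; size 1/1024 -> bin 0.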
92
+
93
+
94
+ def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
95
+ """
96
+ Takes a list of spatial references (points or regions) and encodes them into
97
+ hidden states for input to the text model.
98
+
99
+ Args:
100
+ spatial_refs: List of spatial references (points or boxes)
101
+ - Points are represented as normalized (x, y) tuples
102
+ - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
103
+
104
+ Returns:
105
+ {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
106
+ """
107
+ coords, sizes = [], []
108
+ for ref in spatial_refs:
109
+ if len(ref) == 2:
110
+ coords.append(ref[0])
111
+ coords.append(ref[1])
112
+ else:
113
+ x_c = (ref[0] + ref[2]) / 2
114
+ y_c = (ref[1] + ref[3]) / 2
115
+ width = ref[2] - ref[0]
116
+ height = ref[3] - ref[1]
117
+ coords.append(x_c)
118
+ coords.append(y_c)
119
+ sizes.append([width, height])
120
+
121
+ coords = torch.tensor(
122
+ coords, device=w.coord_features.device, dtype=w.coord_features.dtype
123
+ ).view(-1, 1)
124
+ coords = encode_coordinate(coords, w)
125
+
126
+ if sizes:
127
+ sizes = torch.tensor(
128
+ sizes, device=w.size_features.device, dtype=w.size_features.dtype
129
+ )
130
+ sizes = encode_size(sizes, w)
131
+ else:
132
+ sizes = None
133
+
134
+ return {"coords": coords, "sizes": sizes}
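+ # Illustrative example: a point (0.5, 0.5) contributes two coordinate encodings
+ # and no size entry; a box (0.1, 0.2, 0.5, 0.6) is encoded as its center
+ # (0.3, 0.4) plus a (width, height) size encoding of (0.4, 0.4).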
rope.py ADDED
@@ -0,0 +1,47 @@
1
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
2
+
3
+ import torch
4
+
5
+
6
+ def precompute_freqs_cis(
7
+ dim: int,
8
+ end: int,
9
+ theta: float = 1500000.0,
10
+ dtype: torch.dtype = torch.float32,
11
+ ) -> torch.Tensor:
12
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
13
+ t = torch.arange(end, dtype=dtype).unsqueeze(1)
14
+ freqs = t * freqs.unsqueeze(0)
15
+ freqs = torch.exp(1j * freqs)
16
+ return torch.stack([freqs.real, freqs.imag], dim=-1)
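+ # Output shape is (end, dim // 2, 2), holding the cos and sin of the rotation
+ # angle for every position; apply_rotary_emb indexes into it with position_ids.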
17
+
18
+
19
+ def apply_rotary_emb(
20
+ x: torch.Tensor,
21
+ freqs_cis: torch.Tensor,
22
+ position_ids: torch.Tensor,
23
+ num_heads: int,
24
+ rot_dim: int = 32,
25
+ interleave: bool = False,
26
+ ) -> torch.Tensor:
27
+ assert rot_dim == freqs_cis.shape[-2] * 2
28
+ assert num_heads == x.shape[1]
29
+
30
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
31
+
32
+ if interleave:
33
+ xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
34
+ xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
35
+ else:
36
+ d_q = x_rot.shape[-1] // 2
37
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
38
+
39
+ freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
40
+ freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
41
+
42
+ # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
43
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
44
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
45
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
46
+
47
+ return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
text.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.nn import functional as F
5
+ from torch.nn.attention.flex_attention import flex_attention
6
+ from typing import Optional
7
+
8
+ from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
9
+ from .rope import apply_rotary_emb, precompute_freqs_cis
10
+ from .config import TextConfig
11
+
12
+
13
+ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
14
+ return F.embedding(input_ids, w.wte)
15
+
16
+
17
+ def attn(
18
+ x: torch.Tensor,
19
+ w: nn.Module,
20
+ freqs_cis: torch.Tensor,
21
+ kv_cache: nn.Module,
22
+ attn_mask: torch.Tensor,
23
+ n_heads: int,
24
+ n_kv_heads: int,
25
+ position_ids: torch.Tensor,
26
+ lora: Optional[dict] = None,
27
+ flex_block_mask_slice=None,
28
+ ):
29
+ bsz, q_len, d_model = x.shape
30
+ head_dim = d_model // n_heads
31
+
32
+ qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
33
+ if lora is not None:
34
+ qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
35
+ q_dim = n_heads * head_dim
36
+ kv_dim = n_kv_heads * head_dim
37
+ q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
38
+
39
+ q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
40
+ k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
41
+ v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
42
+
43
+ if hasattr(w, "tau") and w.tau is not None:
44
+ tok_feat = F.gelu(qkv_out)
45
+ tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1)
46
+ tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1)
47
+ pos = position_ids.to(q.dtype) + 1
48
+ tau_pos = 1 + (
49
+ torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5
50
+ ) # (H,S)
51
+ tau_q = (tok_q + tau_pos[None]).unsqueeze(-1) # (B,H,S,1)
52
+ tau_v = (tok_v + tau_pos[None]).unsqueeze(-1)
53
+ q = q * tau_q
54
+ v = v * tau_v
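+ # Optional "tau" gating: q and v are rescaled per head and per position by a
+ # learned factor combining a token-dependent term (tanh of projected QKV
+ # features) and a position-dependent term (sigmoid of alpha * log(position)).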
55
+
56
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
57
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
58
+
59
+ if kv_cache is not None:
60
+ k, v = kv_cache.update(position_ids, k, v)
61
+
62
+ if flex_block_mask_slice is not None:
63
+ torch._assert(n_heads == n_kv_heads, "gqa not supported yet")
64
+ out = flex_attention(q, k, v, block_mask=flex_block_mask_slice)
65
+ else:
66
+ out = F.scaled_dot_product_attention(
67
+ q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
68
+ )
69
+
70
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
71
+
72
+ out0 = w.proj(out)
73
+ if lora is not None:
74
+ out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
75
+ out = out0 + out1
76
+ else:
77
+ out = out0
78
+
79
+ return out
80
+
81
+
82
+ def text_decoder(
83
+ x: torch.Tensor,
84
+ w: nn.Module,
85
+ attn_mask: torch.Tensor,
86
+ position_ids: torch.Tensor,
87
+ config: TextConfig,
88
+ lora: Optional[dict] = None,
89
+ flex_block_mask_slice=None,
90
+ ):
91
+ for i, block in enumerate(w.blocks):
92
+ if lora is not None:
93
+ layer_lora = lora["text"]["blocks"][str(i)]
94
+ mlp_lora = layer_lora["mlp"]
95
+ attn_lora = layer_lora["attn"]
96
+ else:
97
+ mlp_lora = None
98
+ attn_lora = None
99
+
100
+ l_in = layer_norm(x, block.ln)
101
+ l_attn = attn(
102
+ l_in,
103
+ block.attn,
104
+ freqs_cis=w.freqs_cis,
105
+ kv_cache=block.kv_cache,
106
+ attn_mask=attn_mask,
107
+ n_heads=config.n_heads,
108
+ n_kv_heads=config.n_kv_heads,
109
+ position_ids=position_ids,
110
+ lora=attn_lora,
111
+ flex_block_mask_slice=flex_block_mask_slice,
112
+ )
113
+
114
+ if config.moe is not None and i >= config.moe.start_layer:
115
+ l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)
116
+ else:
117
+ l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
118
+
119
+ x = x + l_attn + l_mlp
120
+
121
+ return x
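+ # Parallel residual layout: attention and MLP both read the same pre-norm
+ # activations (l_in), and their outputs are added to the residual stream
+ # together (x = x + l_attn + l_mlp).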
122
+
123
+
124
+ def lm_head(
125
+ hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None
126
+ ):
127
+ hidden_BC = hidden_BTC[:, -1, :]
128
+ hidden_BC = layer_norm(hidden_BC, w.post_ln)
129
+ if indices is not None:
130
+ # Only compute logits for specified token indices
131
+ logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices]
132
+ else:
133
+ logits = w.lm_head(hidden_BC)
134
+ return logits
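+ # The `indices` path restricts the output head to a handful of token ids (e.g.
+ # the coord/eos ids used during point generation), avoiding a full-vocabulary
+ # logit computation for structured decoding steps.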
135
+
136
+
137
+ def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
138
+ return nn.ModuleDict(
139
+ {
140
+ "fc1": linear_cls(d_model, d_ffn, dtype=dtype),
141
+ "fc2": linear_cls(d_ffn, d_model, dtype=dtype),
142
+ }
143
+ )
144
+
145
+
146
+ def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
147
+ # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
148
+ return nn.ModuleDict(
149
+ {
150
+ "router": nn.Linear(d_model, n_experts, dtype=dtype),
151
+ "fc1": nn.ParameterDict(
152
+ {
153
+ "weight": nn.Parameter(
154
+ torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype)
155
+ )
156
+ }
157
+ ),
158
+ "fc2": nn.ParameterDict(
159
+ {
160
+ "weight": nn.Parameter(
161
+ torch.empty(n_experts, d_model, d_ffn, dtype=dtype)
162
+ )
163
+ }
164
+ ),
165
+ }
166
+ )
167
+
168
+
169
+ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
170
+ qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
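+ # Fused QKV width: the query projection is config.dim wide, while K and V are
+ # each scaled by n_kv_heads / n_heads (grouped-query attention). E.g. with
+ # dim=2048, n_heads=32, n_kv_heads=8 (illustrative values), qkv_dim = 3072.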
171
+ linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
172
+
173
+ text = nn.ModuleDict(
174
+ {
175
+ "blocks": nn.ModuleList(
176
+ [
177
+ nn.ModuleDict(
178
+ {
179
+ "ln": nn.LayerNorm(config.dim, dtype=dtype),
180
+ "attn": nn.ModuleDict(
181
+ {
182
+ "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
183
+ "proj": linear_cls(
184
+ config.dim, config.dim, dtype=dtype
185
+ ),
186
+ "tau": nn.ParameterDict(
187
+ {
188
+ "wq": nn.Parameter(
189
+ torch.empty(
190
+ config.n_heads, qkv_dim, dtype=dtype
191
+ )
192
+ ),
193
+ "wv": nn.Parameter(
194
+ torch.empty(
195
+ config.n_heads, qkv_dim, dtype=dtype
196
+ )
197
+ ),
198
+ "alpha": nn.Parameter(
199
+ torch.empty(config.n_heads, dtype=dtype)
200
+ ),
201
+ }
202
+ ),
203
+ }
204
+ ),
205
+ "mlp": (
206
+ build_moe_mlp(
207
+ config.dim,
208
+ config.moe.expert_inner_dim,
209
+ config.moe.num_experts,
210
+ dtype,
211
+ )
212
+ if config.moe is not None
213
+ and layer_idx >= config.moe.start_layer
214
+ else build_dense_mlp(
215
+ config.dim, config.ff_dim, dtype, linear_cls
216
+ )
217
+ ),
218
+ }
219
+ )
220
+ for layer_idx in range(config.n_layers)
221
+ ]
222
+ ),
223
+ "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
224
+ "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
225
+ }
226
+ )
227
+ text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
228
+ text.register_buffer(
229
+ "freqs_cis",
230
+ precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
231
+ persistent=False,
232
+ )
233
+
234
+ return text
utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+
4
+ def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
5
+ """
6
+ Robust outlier detection for list of (x,y) tuples.
7
+ Only requires numpy.
8
+
9
+ Args:
10
+ points_tuples: list of (x,y) tuples
11
+ k_nearest: number of neighbors to consider
12
+ threshold: multiplier for median distance
13
+
14
+ Returns:
15
+ list: filtered list of (x,y) tuples with outliers removed
16
17
+ """
18
+ points = np.array(points_tuples)
19
+ n_points = len(points)
20
+
21
+ # Calculate pairwise distances manually
22
+ dist_matrix = np.zeros((n_points, n_points))
23
+ for i in range(n_points):
24
+ for j in range(i + 1, n_points):
25
+ # Euclidean distance between points i and j
26
+ dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
27
+ dist_matrix[i, j] = dist
28
+ dist_matrix[j, i] = dist
29
+
30
+ # Get k nearest neighbors' distances
31
+ k = min(k_nearest, n_points - 1)
32
+ neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
33
+ avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
34
+
35
+ # Calculate mask using median distance
36
+ median_dist = np.median(avg_neighbor_dist)
37
+ mask = avg_neighbor_dist <= threshold * median_dist
38
+
39
+ # Return filtered tuples and mask
40
+ filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
41
+ return filtered_tuples
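+ # Illustrative behaviour: a point whose mean distance to its k nearest
+ # neighbours exceeds threshold x the median of those means is dropped, so a
+ # single detection far from the main cluster is filtered out.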
vision.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from typing import Union, Tuple
7
+ from PIL import Image
8
+
9
+ from .layers import attn, layer_norm, mlp
10
+ from .image_crops import overlap_crop_image
11
+ from .config import VisionConfig
12
+
13
+ if torch.backends.mps.is_available():
14
+ # Non-divisible input sizes are not implemented on MPS device yet.
15
+ # https://github.com/pytorch/pytorch/issues/96056
16
+ def adaptive_avg_pool2d(input, output_size):
17
+ return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
18
+
19
+ else:
20
+ adaptive_avg_pool2d = F.adaptive_avg_pool2d
21
+
22
+ DeviceLike = Union[str, torch.device, int]
23
+
24
+
25
+ def prepare_crops(
26
+ image: Image.Image, config: VisionConfig, device: DeviceLike
27
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
28
+ np_image = np.array(image.convert("RGB"))
29
+ overlap_crops = overlap_crop_image(
30
+ np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
31
+ )
32
+ all_crops = overlap_crops["crops"]
33
+ all_crops = np.transpose(all_crops, (0, 3, 1, 2))
34
+ all_crops = (
35
+ torch.from_numpy(all_crops)
36
+ .to(device=device, dtype=torch.bfloat16)
37
+ .div_(255.0)
38
+ .sub_(0.5)
39
+ .div_(0.5)
40
+ )
41
+ return all_crops, overlap_crops["tiling"]
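+ # Pixel values are mapped from [0, 255] to [-1, 1]: x / 255 -> [0, 1], then
+ # (x - 0.5) / 0.5 -> [-1, 1].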
42
+
43
+
44
+ def create_patches(x, patch_size):
45
+ # Original shape: [B, C, H, W]
46
+ B, C, H, W = x.shape
47
+ P1 = P2 = patch_size
48
+
49
+ # Step 1: Split H and W dimensions into patches
50
+ # [B, C, H/P1, P1, W/P2, P2]
51
+ x = x.reshape(B, C, H // P1, P1, W // P2, P2)
52
+
53
+ # Step 2: Rearrange dimensions to match target shape
54
+ # [B, H/P1, W/P2, C, P1, P2]
55
+ x = x.permute(0, 2, 4, 1, 3, 5)
56
+
57
+ # Step 3: Combine dimensions to get final shape
58
+ # [B, (H/P1)*(W/P2), C*P1*P2]
59
+ x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
60
+
61
+ return x
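+ # Example (illustrative sizes): a 378x378 crop with enc_patch_size=14 becomes a
+ # 27x27 grid of 729 patches, each flattened to 3 * 14 * 14 = 588 values.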
62
+
63
+
64
+ def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
65
+ x = create_patches(input_BCHW, config.enc_patch_size)
66
+
67
+ x = w.patch_emb(x)
68
+ x = x + w.pos_emb
69
+ for block in w.blocks:
70
+ x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
71
+ x = x + mlp(layer_norm(x, block.ln2), block.mlp)
72
+ x = layer_norm(x, w.post_ln)
73
+
74
+ return x
75
+
76
+
77
+ def vision_projection(
78
+ global_features: torch.Tensor,
79
+ reconstructed: torch.Tensor,
80
+ w: nn.Module,
81
+ config: VisionConfig,
82
+ ):
83
+ reconstructed = reconstructed.permute(2, 0, 1)
84
+ reconstructed = adaptive_avg_pool2d(
85
+ reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
86
+ )
87
+ reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
88
+ final_features = torch.cat([global_features, reconstructed], dim=-1)
89
+ return mlp(final_features, w.proj_mlp)
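+ # The projection input concatenates global features with the pooled
+ # reconstruction along the channel dimension, giving 2 * enc_dim features per
+ # spatial token, matching proj_mlp.fc1's input width in build_vision_model.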
90
+
91
+
92
+ def build_vision_model(config: VisionConfig, dtype: torch.dtype):
93
+ patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
94
+ grid_size = config.crop_size // config.enc_patch_size
95
+ num_patches = grid_size * grid_size
96
+
97
+ vision = nn.ModuleDict(
98
+ {
99
+ "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
100
+ "blocks": nn.ModuleList(
101
+ [
102
+ nn.ModuleDict(
103
+ {
104
+ "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
105
+ "attn": nn.ModuleDict(
106
+ {
107
+ "qkv": nn.Linear(
108
+ config.enc_dim, 3 * config.enc_dim, dtype=dtype
109
+ ),
110
+ "proj": nn.Linear(
111
+ config.enc_dim, config.enc_dim, dtype=dtype
112
+ ),
113
+ }
114
+ ),
115
+ "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
116
+ "mlp": nn.ModuleDict(
117
+ {
118
+ "fc1": nn.Linear(
119
+ config.enc_dim, config.enc_ff_dim, dtype=dtype
120
+ ),
121
+ "fc2": nn.Linear(
122
+ config.enc_ff_dim, config.enc_dim, dtype=dtype
123
+ ),
124
+ }
125
+ ),
126
+ }
127
+ )
128
+ for _ in range(config.enc_n_layers)
129
+ ]
130
+ ),
131
+ "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
132
+ "proj_mlp": nn.ModuleDict(
133
+ {
134
+ "fc1": nn.Linear(
135
+ config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
136
+ ),
137
+ "fc2": nn.Linear(
138
+ config.proj_inner_dim, config.proj_out_dim, dtype=dtype
139
+ ),
140
+ }
141
+ ),
142
+ }
143
+ )
144
+ vision.pos_emb = nn.Parameter(
145
+ torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
146
+ )
147
+ return vision