Upload 24 files
- .gitattributes +4 -0
- LICENSE.md +25 -0
- README.md +204 -0
- config.json +13 -0
- config.py +102 -0
- hf_moondream.py +183 -0
- image_crops.py +231 -0
- layers.py +234 -0
- lora.py +82 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +667 -0
- moondream.py +1070 -0
- open_vocab_detect.png +3 -0
- point_count.png +3 -0
- region.py +134 -0
- rope.py +47 -0
- structured_outputs.png +3 -0
- text.py +234 -0
- utils.py +41 -0
- vision.py +147 -0
- visual_reasoning.png +3 -0
.gitattributes
CHANGED

@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+structured_outputs.png filter=lfs diff=lfs merge=lfs -text
+open_vocab_detect.png filter=lfs diff=lfs merge=lfs -text
+point_count.png filter=lfs diff=lfs merge=lfs -text
+visual_reasoning.png filter=lfs diff=lfs merge=lfs -text
LICENSE.md
ADDED

@@ -0,0 +1,25 @@
| License | Business Source License (BSL 1.1) |
| --- | --- |
| Licensor | M87 Labs, Inc. |
| Licensed Work | “Moondream 3 (Preview)” including Model Weights and any Derivatives (“Derivatives” include fine-tunes, merges, quantizations, weight deltas, and other weight-level modifications or conversions.) |
| Additional Use Grant | You may use the Licensed Work and Derivatives for any purpose, including commercial use, and you may self-host them for your or your organization’s internal use. You may not provide the Licensed Work, Derivatives, or any service that exposes their functionality to third parties (including via API, website, application, model hub, or dataset/model redistribution) without a separate commercial agreement with the Licensor. |
| Change Date | Two years after the first public release of this version of the Licensed Work |
| Change License | Apache License, Version 2.0 |

The text of the Business Source License 1.1 follows. License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved. “Business Source License” is a trademark of MariaDB Corporation Ab.

## Terms

The Licensor hereby grants you the right to copy, modify, create derivative works, redistribute, and make non-production use of the Licensed Work. The Licensor may make an Additional Use Grant, above, permitting limited production use.

Effective on the Change Date, or the fourth anniversary of the first publicly available distribution of a specific version of the Licensed Work under this License, whichever comes first, the Licensor hereby grants you rights under the terms of the Change License, and the rights granted in the paragraph above terminate.

If your use of the Licensed Work does not comply with the requirements currently in effect as described in this License, you must purchase a commercial license from the Licensor, its affiliated entities, or authorized resellers, or you must refrain from using the Licensed Work.

All copies of the original and modified Licensed Work, and derivative works of the Licensed Work, are subject to this License. This License applies separately for each version of the Licensed Work and the Change Date may vary for each version of the Licensed Work released by Licensor.

You must conspicuously display this License on each original or modified copy of the Licensed Work. If you receive the Licensed Work in original or modified form from a third party, the terms and conditions set forth in this License apply to your use of that work.

Any use of the Licensed Work in violation of this License will automatically terminate your rights under this License for the current and all other versions of the Licensed Work.

This License does not grant you any right in any trademark or logo of Licensor or its affiliates (provided that you may use a trademark or logo of Licensor as expressly required by this License). TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND TITLE.
README.md
ADDED

@@ -0,0 +1,204 @@
---
library_name: transformers
pipeline_tag: image-text-to-text
license: other
---

**Moondream 3 (Preview)** is a vision language model with a mixture-of-experts architecture (9B total parameters, 2B active). This model makes no compromises, delivering state-of-the-art visual reasoning while still retaining our efficient and deployment-friendly ethos.

[✨ Demo](https://moondream.ai/c/playground) · [☁️ Cloud API](https://moondream.ai/c/docs/quickstart) · [📝 Release notes](https://moondream.ai/blog/moondream-3-preview)

![](structured_outputs.png)
![](open_vocab_detect.png)
![](point_count.png)
![](visual_reasoning.png)

## Architecture

1. 24 layers; the first four are dense, the rest have MoE FFNs with 64 experts, 8 activated per token
2. MoE FFNs use a GeGLU architecture with an inner/gate dim of 1024; the model's hidden dim is 2048 (see the sketch after this list)
3. Usable context length increased to 32K, with [a custom efficient SuperBPE tokenizer](https://huggingface.co/moondream/starmie-v1)
4. Multi-headed attention with learned position- and data-dependent temperature scaling
5. SigLIP-based vision encoder, with multi-crop channel concatenation for token-efficient high-resolution image processing
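
To make the MoE FFN concrete, here is a minimal, illustrative sketch of the per-token expert computation described above (top-8 routing over 64 experts, GeGLU with an inner/gate dim of 1024 and a hidden dim of 2048). All tensor names are hypothetical; this is not the repository's implementation — see `moe_mlp` in `layers.py`, included later in this upload, for the actual code.

```python
# Illustrative only: shapes follow the numbers quoted in the list above.
import torch
import torch.nn.functional as F

dim, inner, n_experts, top_k = 2048, 1024, 64, 8
x = torch.randn(3, dim)                      # 3 example tokens
router = torch.randn(n_experts, dim)         # hypothetical router weights
w1 = torch.randn(n_experts, 2 * inner, dim)  # fused value+gate projection per expert
w2 = torch.randn(n_experts, dim, inner)      # down projection per expert

logits = x @ router.t()                              # [tokens, experts]
topk_logits, topk_idx = logits.topk(top_k, dim=-1)
gates = F.softmax(topk_logits, dim=-1)               # mixture weights over chosen experts

out = torch.zeros_like(x)
for t in range(x.shape[0]):
    for k in range(top_k):
        e = topk_idx[t, k]
        h, g = (w1[e] @ x[t]).chunk(2)               # GeGLU: value half and gate half
        out[t] += gates[t, k] * (w2[e] @ (F.gelu(h) * (g + 1)))
```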

For more details, please refer to the [release notes](https://moondream.ai/blog/moondream-3-preview), or try the model out in our [playground demo](https://moondream.ai/c/playground).

The following instructions demonstrate how to run the model locally using Transformers. We also offer a [cloud API](https://moondream.ai/c/docs/quickstart) with a generous free tier that can help you get started more quickly!

## Usage

Load the model and prepare it for inference. We use [FlexAttention for inference](https://pytorch.org/blog/flexattention-for-inference/), so calling `.compile()` is critical for fast decoding. Our `compile` implementation also handles warmup, so you can start making requests directly once it returns.

```python
import torch
from transformers import AutoModelForCausalLM

moondream = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream3-preview",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map={"": "cuda"},
)
moondream.compile()
```

The model comes with four skills, tailored towards different visual understanding tasks.

### Query

The `query` skill can be used to ask open-ended questions about images.

```python
from PIL import Image

# Simple VQA
image = Image.open("photo.jpg")
result = moondream.query(image=image, question="What's in this image?")
print(result["answer"])
```

By default, `query` runs in reasoning mode, allowing the model to "think" about the question before generating an answer. This is helpful for more complicated tasks, but sometimes the task you're running is simple and doesn't benefit from reasoning. To save on inference cost in those cases, you can disable reasoning:

```python
# Without reasoning for simple questions
result = moondream.query(
    image=image,
    question="What color is the sky?",
    reasoning=False
)
print(result["answer"])
```

If you want to stream outputs, pass in `stream=True`. You can control the temperature, top-p, and maximum number of tokens generated by passing in optional settings.

```python
# Streaming with custom settings
settings = {
    "temperature": 0.7,
    "top_p": 0.95,
    "max_tokens": 512
}

result = moondream.query(
    image=image,
    question="Describe what's happening in detail",
    stream=True,
    settings=settings
)

# Stream the answer
for chunk in result["answer"]:
    print(chunk, end="", flush=True)
```

Note that this isn't just for images; Moondream is also a strong general-purpose text model.

```python
# Text-only example (no image)
result = moondream.query(
    question="Explain the concept of machine learning in simple terms"
)
print(result["answer"])
```

### Caption

Whether you want short, normal-sized, or long descriptions of images, the `caption` skill has you covered.

```python
# Different caption lengths
image = Image.open("landscape.jpg")

# Short caption
short = moondream.caption(image, length="short")
print(f"Short: {short['caption']}")

# Normal caption (default)
normal = moondream.caption(image, length="normal")
print(f"Normal: {normal['caption']}")

# Long caption
long = moondream.caption(image, length="long")
print(f"Long: {long['caption']}")
```

It accepts the same streaming and sampling settings (temperature, top-p, and so on) as the `query` skill.

```python
# Streaming caption with custom settings
result = moondream.caption(
    image,
    length="long",
    stream=True,
    settings={"temperature": 0.3}
)

for chunk in result["caption"]:
    print(chunk, end="", flush=True)
```

### Point

The `point` skill identifies specific points (x, y coordinates) for objects in an image.

```python
# Find points for specific objects
image = Image.open("crowd.jpg")
result = moondream.point(image, "person wearing a red shirt")

# Points are normalized coordinates (0-1)
for i, point in enumerate(result["points"]):
    print(f"Point {i+1}: x={point['x']:.3f}, y={point['y']:.3f}")
```
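
Since the returned points are normalized to the 0-1 range, converting them to pixel positions is just a multiplication by the image dimensions. A small sketch, assuming the same `image` and `result` objects as in the example above:

```python
# Convert normalized point coordinates to pixel positions
width, height = image.size  # PIL reports (width, height)
for point in result["points"]:
    px, py = point["x"] * width, point["y"] * height
    print(f"({px:.0f}, {py:.0f}) px")
```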

### Detect

The `detect` skill provides bounding boxes for objects in an image.

```python
# Detect objects with bounding boxes
image = Image.open("street_scene.jpg")
result = moondream.detect(image, "car")

# Bounding boxes are normalized coordinates (0-1)
for i, obj in enumerate(result["objects"]):
    print(f"Object {i+1}: "
          f"x_min={obj['x_min']:.3f}, y_min={obj['y_min']:.3f}, "
          f"x_max={obj['x_max']:.3f}, y_max={obj['y_max']:.3f}")

# Control maximum number of objects
settings = {"max_objects": 10}
result = moondream.detect(image, "person", settings=settings)
```
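
Because the boxes are normalized as well, they can be scaled by the image size to crop (or draw) the detections. A minimal sketch using PIL, assuming the `image` and `result` objects from the example above:

```python
# Crop each detected object out of the original image
width, height = image.size
for i, obj in enumerate(result["objects"]):
    box = (
        int(obj["x_min"] * width),
        int(obj["y_min"] * height),
        int(obj["x_max"] * width),
        int(obj["y_max"] * height),
    )
    image.crop(box).save(f"object_{i}.png")
```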

### Caching image encodings (advanced)

If you're planning to run multiple inferences on the same image, you can pre-encode it once and reuse the encoding for better performance.

```python
# Encode image once
image = Image.open("complex_scene.jpg")
encoded = moondream.encode_image(image)

# Reuse the encoding for multiple queries
questions = [
    "How many people are in this image?",
    "What time of day was this taken?",
    "What's the weather like?"
]

for q in questions:
    result = moondream.query(image=encoded, question=q, reasoning=False)
    print(f"Q: {q}")
    print(f"A: {result['answer']}\n")

# Also works with other skills
caption = moondream.caption(encoded, length="normal")
objects = moondream.detect(encoded, "vehicle")
```

---

Copyright (c) 2025 M87 Labs, Inc.

This distribution includes Model Weights licensed under the [Business Source License 1.1 with an Additional Use Grant (No Third-Party Service)](https://huggingface.co/moondream/moondream3-preview/blob/main/LICENSE.md). Commercial hosting or rehosting requires an agreement with <contact@m87.ai>.
config.json
ADDED

@@ -0,0 +1,13 @@
{
  "architectures": [
    "HfMoondream"
  ],
  "auto_map": {
    "AutoConfig": "hf_moondream.HfConfig",
    "AutoModelForCausalLM": "hf_moondream.HfMoondream"
  },
  "config": {},
  "model_type": "moondream1",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.1"
}
config.py
ADDED

@@ -0,0 +1,102 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass(frozen=True)
class TextMoeConfig:
    num_experts: int = 64
    start_layer: int = 4
    experts_per_token: int = 8
    expert_inner_dim: int = 1024


@dataclass(frozen=True)
class TextConfig:
    dim: int = 2048
    ff_dim: int = 8192
    n_layers: int = 24
    vocab_size: int = 51200
    max_context: int = 4096
    n_heads: int = 32
    n_kv_heads: int = 32
    prefix_attn: int = 730
    group_size: Optional[int] = None
    moe: Optional[TextMoeConfig] = TextMoeConfig()


@dataclass(frozen=True)
class VisionConfig:
    enc_dim: int = 1152
    enc_patch_size: int = 14
    enc_n_layers: int = 27
    enc_ff_dim: int = 4304
    enc_n_heads: int = 16
    proj_out_dim: int = 2048
    crop_size: int = 378
    in_channels: int = 3
    max_crops: int = 12
    overlap_margin: int = 4
    proj_inner_dim: int = 8192


@dataclass(frozen=True)
class RegionConfig:
    dim: int = 2048
    coord_feat_dim: int = 256
    coord_out_dim: int = 1024
    size_feat_dim: int = 512
    size_out_dim: int = 2048
    group_size: Optional[int] = None


@dataclass(frozen=True)
class TokenizerConfig:
    bos_id: int = 0
    eos_id: int = 0
    answer_id: int = 3
    thinking_id: int = 4
    coord_id: int = 5
    size_id: int = 6
    start_ground_points_id: int = 7
    end_ground_id: int = 9
    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
        default_factory=lambda: {
            "caption": {
                "short": [1, 32708, 2, 12492, 3],
                "normal": [1, 32708, 2, 6382, 3],
                "long": [1, 32708, 2, 4059, 3],
            },
            "query": {"prefix": [1, 15381, 2], "suffix": [3]},
            "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
            "point": {"prefix": [1, 2581, 2], "suffix": [3]},
        }
    )


@dataclass(frozen=True)
class MoondreamConfig:
    text: TextConfig = TextConfig()
    vision: VisionConfig = VisionConfig()
    region: RegionConfig = RegionConfig()
    tokenizer: TokenizerConfig = TokenizerConfig()

    @classmethod
    def from_dict(cls, config_dict: dict):
        text_config = TextConfig(**config_dict.get("text", {}))
        vision_config = VisionConfig(**config_dict.get("vision", {}))
        region_config = RegionConfig(**config_dict.get("region", {}))
        tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
        return cls(
            text=text_config,
            vision=vision_config,
            region=region_config,
            tokenizer=tokenizer_config,
        )

    def to_dict(self):
        return {
            "text": self.text.__dict__,
            "vision": self.vision.__dict__,
            "region": self.region.__dict__,
            "tokenizer": self.tokenizer.__dict__,
        }
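
As a quick illustration of how these dataclasses compose, the sketch below (assuming `config.py` is importable from the current directory) builds a config from a partial dict via `from_dict` and round-trips it through `to_dict`; any field not supplied falls back to the defaults above.

```python
from config import MoondreamConfig

# Only override what differs from the defaults; the rest is filled in.
cfg = MoondreamConfig.from_dict({"text": {"max_context": 32768}})
print(cfg.text.max_context)                   # 32768
print(cfg.text.moe.num_experts)               # 64 (default)
print(cfg.to_dict()["vision"]["crop_size"])   # 378
```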
hf_moondream.py
ADDED

@@ -0,0 +1,183 @@
import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig
from typing import Union

from .config import MoondreamConfig
from .moondream import MoondreamModel

# Files sometimes don't get loaded without these...
from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *


def extract_question(text):
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    if text.startswith(prefix) and text.endswith(suffix):
        return text[len(prefix) : -len(suffix)]
    else:
        return None


class HfConfig(PretrainedConfig):
    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config = {}


class HfMoondream(PreTrainedModel):
    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MoondreamModel(
            MoondreamConfig.from_dict(config.config), setup_caches=False
        )
        self._is_kv_cache_setup = False

    def _setup_caches(self):
        if not self._is_kv_cache_setup:
            self.model._setup_caches()
            self._is_kv_cache_setup = True

    @property
    def encode_image(self):
        self._setup_caches()
        return self.model.encode_image

    @property
    def query(self):
        self._setup_caches()
        return self.model.query

    @property
    def caption(self):
        self._setup_caches()
        return self.model.caption

    @property
    def detect(self):
        self._setup_caches()
        return self.model.detect

    @property
    def point(self):
        self._setup_caches()
        return self.model.point

    @property
    def detect_gaze(self):
        self._setup_caches()
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        answers = []
        for image, prompt in zip(images, prompts):
            answers.append(self.query(image, prompt)["answer"].strip())
        return answers

    def _unsupported_exception(self):
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """
        Function definition remains unchanged for backwards compatibility.
        Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            image_embeds = self.encode_image(image_embeds)
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )

            def generator():
                for token in self.model._generate_answer(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                ):
                    yield token

            answer = "".join(list(generator()))

        return [answer]

    def get_input_embeddings(self) -> nn.Embedding:
        """
        Lazily wrap the raw parameter `self.model.text.wte` in a real
        `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
        **shares** the weight tensor—no copy is made.
        """
        if not hasattr(self, "_input_embeddings"):
            self._input_embeddings = nn.Embedding.from_pretrained(
                self.model.text.wte,  # tensor created in text.py
                freeze=True,  # set to False if you need it trainable
            )
        return self._input_embeddings

    def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
        """
        Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
        embeddings and keeps everything tied to `self.model.text.wte`.
        """
        # 1. point the low-level parameter to the new weight matrix
        self.model.text.wte = value.weight
        # 2. keep a reference for get_input_embeddings()
        self._input_embeddings = value

    def input_embeds(
        self,
        input_ids: Union[torch.LongTensor, list, tuple],
        *,
        device: torch.device | None = None
    ) -> torch.FloatTensor:
        """
        Back-compat wrapper that turns token IDs into embeddings.

        Example:
            ids = torch.tensor([[1, 2, 3]])
            embeds = model.input_embeds(ids)  # (1, 3, hidden_dim)
        """
        if not torch.is_tensor(input_ids):
            input_ids = torch.as_tensor(input_ids)
        if device is not None:
            input_ids = input_ids.to(device)

        return self.get_input_embeddings()(input_ids)
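
For reference, the legacy `generate`/`answer_question` path above expects prompts in the old Moondream format, which `extract_question` parses back into a plain question. A hedged sketch of calling it (hypothetical image path, and assuming the model was loaded and compiled as shown in the README):

```python
from PIL import Image

image = Image.open("photo.jpg")  # hypothetical path
prompt = "<image>\n\nQuestion: What's in this image?\n\nAnswer:"

# extract_question() strips the prefix/suffix, so this routes to model.query()
answer = moondream.generate(image, prompt, tokenizer=None)[0]
print(answer)
```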
image_crops.py
ADDED

@@ -0,0 +1,231 @@
import math
import numpy as np
import torch

from typing import TypedDict

try:
    import pyvips

    HAS_VIPS = True
except:
    from PIL import Image

    HAS_VIPS = False


def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """
    Determine the optimal number of tiles to cover an image with overlapping crops.
    """
    if height <= crop_size or width <= crop_size:
        return (1, 1)

    # Minimum required tiles in each dimension
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)

    # If minimum required tiles exceed max_crops, return proportional distribution
    if min_h * min_w > max_crops:
        ratio = math.sqrt(max_crops / (min_h * min_w))
        return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))

    # Perfect aspect-ratio tiles that satisfy max_crops
    h_tiles = math.floor(math.sqrt(max_crops * height / width))
    w_tiles = math.floor(math.sqrt(max_crops * width / height))

    # Ensure we meet minimum tile requirements
    h_tiles = max(h_tiles, min_h)
    w_tiles = max(w_tiles, min_w)

    # If we exceeded max_crops, scale down the larger dimension
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)

    return (max(1, h_tiles), max(1, w_tiles))


class OverlapCropOutput(TypedDict):
    crops: np.ndarray
    tiling: tuple[int, int]


def overlap_crop_image(
    image: np.ndarray,
    overlap_margin: int,
    max_crops: int,
    base_size: tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> OverlapCropOutput:
    """
    Process an image using an overlap-and-resize cropping strategy with margin handling.

    This function takes an input image and creates multiple overlapping crops with
    consistent margins. It produces:
    1. A single global crop resized to base_size
    2. Multiple overlapping local crops that maintain high resolution details
    3. A patch ordering matrix that tracks correspondence between crops

    The overlap strategy ensures:
    - Smooth transitions between adjacent crops
    - No loss of information at crop boundaries
    - Proper handling of features that cross crop boundaries
    - Consistent patch indexing across the full image

    Args:
        image (np.ndarray): Input image as numpy array with shape (H,W,C)
        base_size (tuple[int,int]): Target size for crops, default (378,378)
        patch_size (int): Size of patches in pixels, default 14
        overlap_margin (int): Margin size in patch units, default 4
        max_crops (int): Maximum number of crops allowed, default 12

    Returns:
        OverlapCropOutput: Dictionary containing:
        - crops: A numpy array containing the global crop of the full image (index 0)
          followed by the overlapping cropped regions (indices 1+)
        - tiling: Tuple of (height,width) tile counts
    """
    original_h, original_w = image.shape[:2]

    # Convert margin from patch units to pixels
    margin_pixels = patch_size * overlap_margin
    total_margin_pixels = margin_pixels * 2  # Both sides

    # Calculate crop parameters
    crop_patches = base_size[0] // patch_size  # patches per crop dimension
    crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
    crop_window_size = crop_window_patches * patch_size  # usable size in pixels

    # Determine tiling
    tiling = select_tiling(
        original_h - total_margin_pixels,
        original_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    # Pre-allocate crops.
    n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop
    crops = np.zeros(
        (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
    )

    # Resize image to fit tiling
    target_size = (
        tiling[0] * crop_window_size + total_margin_pixels,
        tiling[1] * crop_window_size + total_margin_pixels,
    )

    if HAS_VIPS:
        # Convert to vips for resizing
        vips_image = pyvips.Image.new_from_array(image)
        scale_x = target_size[1] / image.shape[1]
        scale_y = target_size[0] / image.shape[0]
        resized = vips_image.resize(scale_x, vscale=scale_y)
        image = resized.numpy()

        # Create global crop
        scale_x = base_size[1] / vips_image.width
        scale_y = base_size[0] / vips_image.height
        global_vips = vips_image.resize(scale_x, vscale=scale_y)
        crops[0] = global_vips.numpy()
    else:
        # Fallback to PIL
        pil_img = Image.fromarray(image)
        resized = pil_img.resize(
            (int(target_size[1]), int(target_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
        image = np.asarray(resized)

        # Create global crop
        global_pil = pil_img.resize(
            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
        )
        crops[0] = np.asarray(global_pil)

    for i in range(tiling[0]):
        for j in range(tiling[1]):
            # Calculate crop coordinates
            y0 = i * crop_window_size
            x0 = j * crop_window_size

            # Extract crop with padding if needed
            y_end = min(y0 + base_size[0], image.shape[0])
            x_end = min(x0 + base_size[1], image.shape[1])

            crop_region = image[y0:y_end, x0:x_end]
            crops[
                1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
            ] = crop_region

    return {"crops": crops, "tiling": tiling}


def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    """
    Reconstruct the original image from overlapping crops into a single seamless image.

    Takes a list of overlapping image crops along with their positional metadata and
    reconstructs them into a single coherent image by carefully stitching together
    non-overlapping regions. Handles both numpy arrays and PyTorch tensors.

    Args:
        crops: List of image crops as numpy arrays or PyTorch tensors with shape
            (H,W,C)
        tiling: Tuple of (height,width) indicating crop grid layout
        patch_size: Size in pixels of each patch, default 14
        overlap_margin: Number of overlapping patches on each edge, default 4

    Returns:
        Reconstructed image as numpy array or PyTorch tensor matching input type,
        with shape (H,W,C) where H,W are the original image dimensions
    """
    tiling_h, tiling_w = tiling
    crop_height, crop_width = crops[0].shape[:2]
    margin_pixels = overlap_margin * patch_size

    # Calculate output size (only adding margins once)
    output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
    output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels

    reconstructed = torch.zeros(
        (output_h, output_w, crops[0].shape[2]),
        device=crops[0].device,
        dtype=crops[0].dtype,
    )

    for i, crop in enumerate(crops):
        tile_y = i // tiling_w
        tile_x = i % tiling_w

        # For each tile, determine which part to keep
        # Keep left margin only for first column
        x_start = 0 if tile_x == 0 else margin_pixels
        # Keep right margin only for last column
        x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
        # Keep top margin only for first row
        y_start = 0 if tile_y == 0 else margin_pixels
        # Keep bottom margin only for last row
        y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels

        # Calculate where this piece belongs in the output
        out_x = tile_x * (crop_width - 2 * margin_pixels)
        out_y = tile_y * (crop_height - 2 * margin_pixels)

        # Place the piece
        reconstructed[
            out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
        ] = crop[y_start:y_end, x_start:x_end]

    return reconstructed
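
A small usage sketch of the two halves of this module: tile a high-resolution image into a global crop plus overlapping local crops, then stitch the local crops back together. The image is a random array standing in for a real photo; shapes and the printed tiling are illustrative.

```python
import numpy as np
import torch

# Tile a 1080x1920 RGB image into a global crop plus overlapping local crops.
img = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)
out = overlap_crop_image(img, overlap_margin=4, max_crops=12)
print(out["tiling"])       # (2, 4) for this aspect ratio
print(out["crops"].shape)  # (1 + rows*cols, 378, 378, 3)

# Stitch the local crops (indices 1+) back into one seamless array.
local = torch.from_numpy(out["crops"][1:])
full = reconstruct_from_crops(local, out["tiling"], overlap_margin=4)
print(full.shape)
```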
layers.py
ADDED

@@ -0,0 +1,234 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Literal, Optional

try:
    from torchao import quantize_
    from torchao.quantization import int4_weight_only
except ImportError:

    def quantize_(model, quant_mode):
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        )

    def int4_weight_only(group_size):
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        )


def gelu_approx(x):
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    _step = W_q.shape[0]
    W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
    W_r[:_step] = (W_q & 0b11110000) >> 4
    W_r[_step:] = W_q & 0b00001111
    W_r.sub_(zero).mul_(scale)
    return W_r.reshape(orig_shape)


class QuantizedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        dtype: torch.dtype,
    ):
        # TODO: Take group_size as an input instead of hardcoding it here.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.ParameterDict(
            {
                "packed": nn.Parameter(
                    torch.empty(
                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
                    ),
                    requires_grad=False,
                ),
                "scale": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
                "zero_point": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
            }
        )
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
        self.unpacked = False

    def unpack(self):
        if self.unpacked:
            return

        self.weight = nn.Parameter(
            dequantize_tensor(
                self.weight["packed"],
                self.weight["scale"],
                self.weight["zero_point"],
                (self.out_features, self.in_features),
                torch.bfloat16,
            )
        )
        with torch.device("meta"):
            self.linear = nn.Linear(
                self.in_features, self.out_features, dtype=torch.bfloat16
            )
        self.linear.weight = self.weight
        self.linear.bias = nn.Parameter(
            self.bias.to(torch.bfloat16), requires_grad=False
        )

        del self.weight, self.bias
        quantize_(self, int4_weight_only(group_size=128))
        self.unpacked = True
        torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.unpacked:
            self.unpack()
        return self.linear(x)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
    x0 = w.fc1(x)
    if lora is not None:
        x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
        x = x0 + x1
    else:
        x = x0

    x = gelu_approx(x)

    x0 = w.fc2(x)
    if lora is not None:
        x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
        x = x0 + x1
    else:
        x = x0

    return x


def moe_mlp(
    x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
) -> torch.Tensor:
    B, T, C = x.shape
    x = x.reshape(-1, C)

    # Router computation
    router_logits = mlp_module.router(x)
    topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
    topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
    num_tokens, top_k = topk_idxs.shape

    if T == 1:
        w1_weight = mlp_module.fc1.weight
        w2_weight = mlp_module.fc2.weight

        # Flatten to process all token-expert pairs at once
        flat_idxs = topk_idxs.view(-1)  # [T*A]
        flat_weights = topk_weights.view(-1)  # [T*A]

        # Select expert weights
        w1_selected = w1_weight[flat_idxs]  # [T*A, H, D]
        w2_selected = w2_weight[flat_idxs]  # [T*A, D, H]

        # Expand input for all token-expert pairs
        x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C)  # [T*A, D]

        # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
        x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(-1)  # [T*A, H]
        x1, g = x1_full.chunk(2, dim=-1)
        x1 = F.gelu(x1) * (g + 1)

        # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
        expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1)  # [T*A, D]

        # Apply weights and reshape
        weighted_outs = expert_outs * flat_weights.unsqueeze(-1)  # [T*A, D]
        weighted_outs = weighted_outs.view(num_tokens, top_k, C)  # [T, A, D]

        # Sum over experts
        mlp_out = weighted_outs.sum(dim=1)  # [T, D]
        mlp_out = mlp_out.view(B, T, C)

        return mlp_out
    else:
        out = x.new_zeros(x.size())

        for expert_id in range(mlp_module.fc1.weight.shape[0]):
            token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue

            x_tok = x.index_select(0, token_pos)
            gate_tok = topk_weights[token_pos, which_k]

            h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
            h, g = h_full.chunk(2, dim=-1)
            h = F.gelu(h) * (g + 1)
            y = F.linear(h, mlp_module.fc2.weight[expert_id])

            y.mul_(gate_tok.unsqueeze(-1))
            out.index_add_(0, token_pos, y)

        return out.view(B, T, C)


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
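
The `lora` argument to `mlp` above is a plain nested dict of low-rank factors rather than a module. A sketch of the expected structure (the rank and sizes are illustrative; `A` maps input to rank and `B` maps rank to output, matching the two `F.linear` calls above):

```python
import torch

dim, hidden, rank = 2048, 8192, 16  # illustrative sizes
lora_weights = {
    "fc1": {"A": torch.zeros(rank, dim), "B": torch.zeros(hidden, rank)},
    "fc2": {"A": torch.zeros(rank, hidden), "B": torch.zeros(dim, rank)},
}
# mlp(x, w, lora=lora_weights) then adds B @ (A @ x) to each projection.
```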
lora.py
ADDED

@@ -0,0 +1,82 @@
import functools
import os
import shutil
import torch

from pathlib import Path
from urllib.request import Request, urlopen
from typing import Optional


def variant_cache_dir():
    hf_hub_cache = os.environ.get("HF_HUB_CACHE")
    if hf_hub_cache is not None:
        return Path(hf_hub_cache) / "md_variants"

    hf_home = os.environ.get("HF_HOME")
    if hf_home is not None:
        return Path(hf_home) / "hub" / "md_variants"

    return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"


def cached_variant_path(variant_id: str):
    variant, *rest = variant_id.split("/", 1)
    step = rest[0] if rest else "final"

    cache_dir = variant_cache_dir() / variant
    os.makedirs(cache_dir, exist_ok=True)
    dest = cache_dir / f"{step}.pt"
    if dest.exists():
        return dest

    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")

    headers = {"User-Agent": "moondream-torch"}
    api_key = os.getenv("MOONDREAM_API_KEY")
    if api_key is not None:
        headers["X-Moondream-Auth"] = api_key

    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
    with urlopen(req) as r, open(dest, "wb") as f:
        shutil.copyfileobj(r, f)
    return dest


def nest(flat):
    tree = {}
    for k, v in flat.items():
        parts = k.split(".")
        d = tree
        for p in parts[:-1]:
            d = d.setdefault(p, {})
        d[parts[-1]] = v
    return tree


@functools.lru_cache(maxsize=5)
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
    if variant_id is None:
        return None

    state_dict = torch.load(
        cached_variant_path(variant_id), map_location=device, weights_only=True
    )

    # TODO: Move these into the training code that saves checkpoints...
    rename_rules = [
        ("text_model.transformer.h", "text.blocks"),
        (".mixer", ".attn"),
        (".out_proj", ".proj"),
        (".Wqkv", ".qkv"),
        (".parametrizations.weight.0", ""),
    ]
    new_state_dict = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in rename_rules:
            if old in new_key:
                new_key = new_key.replace(old, new)
        new_state_dict[new_key] = tensor

    return nest(new_state_dict)
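
For clarity, `nest` simply converts the flat, `.`-separated keys of a loaded state dict into nested dictionaries, which is the shape the LoRA-aware layers above consume. A tiny illustration with made-up keys and values:

```python
flat = {
    "text.blocks.0.mlp.fc1.A": 1,
    "text.blocks.0.mlp.fc1.B": 2,
}
print(nest(flat))
# {'text': {'blocks': {'0': {'mlp': {'fc1': {'A': 1, 'B': 2}}}}}}
```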
model-00001-of-00004.safetensors
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd4b3d0d6daae9c4212056cd64f02f408ff083bbb0244114eecd05fcba30037e
size 4907406296

model-00002-of-00004.safetensors
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7cf6d17391db58801b61173510ba629875679dbcbe4bfd3cb38ac0958b3c70a0
size 4736548872

model-00003-of-00004.safetensors
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4391c6d6b46ed49aa00afddf1f7df9dd0845cbc681fdaf424e727b01ea2d3e4
size 4502742464

model-00004-of-00004.safetensors
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6af14858bdd7cdea5d19d786726e48434b02a3c0c52a771a0f25b6a8ca640187
size 4390620392
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
  "metadata": {
    "total_size": 18537245664
  },
  "weight_map": {
    "model.region.coord_decoder.bias": "model-00004-of-00004.safetensors",
    "model.region.coord_decoder.weight": "model-00004-of-00004.safetensors",
    "model.region.coord_encoder.bias": "model-00004-of-00004.safetensors",
    "model.region.coord_encoder.weight": "model-00004-of-00004.safetensors",
    "model.region.coord_features": "model-00004-of-00004.safetensors",
    "model.region.size_decoder.bias": "model-00004-of-00004.safetensors",
    "model.region.size_decoder.weight": "model-00004-of-00004.safetensors",
    "model.region.size_encoder.bias": "model-00004-of-00004.safetensors",
    "model.region.size_encoder.weight": "model-00004-of-00004.safetensors",
    "model.region.size_features": "model-00004-of-00004.safetensors",
    "model.text.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.10.attn.proj.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.attn.proj.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.attn.qkv.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.attn.qkv.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.attn.tau.alpha": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.attn.tau.wq": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.attn.tau.wv": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.ln.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.ln.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.mlp.fc1.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.mlp.fc2.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.mlp.router.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.10.mlp.router.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.proj.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.proj.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.qkv.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.qkv.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.tau.alpha": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.tau.wq": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.attn.tau.wv": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.ln.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.ln.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.mlp.fc1.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.mlp.fc2.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.mlp.router.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.11.mlp.router.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.proj.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.proj.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.qkv.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.qkv.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.tau.alpha": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.tau.wq": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.attn.tau.wv": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.ln.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.ln.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.mlp.fc1.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.mlp.fc2.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.mlp.router.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.12.mlp.router.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.proj.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.proj.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.qkv.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.qkv.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.tau.alpha": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.tau.wq": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.attn.tau.wv": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.ln.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.ln.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.mlp.fc1.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.mlp.fc2.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.13.mlp.router.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.13.mlp.router.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.14.attn.proj.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.attn.proj.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.attn.qkv.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.attn.qkv.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.attn.tau.alpha": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.attn.tau.wq": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.attn.tau.wv": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.ln.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.ln.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.mlp.fc1.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.mlp.fc2.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.mlp.router.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.14.mlp.router.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.proj.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.proj.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.qkv.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.qkv.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.tau.alpha": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.tau.wq": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.attn.tau.wv": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.ln.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.ln.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.mlp.fc1.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.mlp.fc2.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.mlp.router.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.15.mlp.router.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.proj.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.proj.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.qkv.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.qkv.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.tau.alpha": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.tau.wq": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.attn.tau.wv": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.ln.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.ln.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.mlp.fc1.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.mlp.fc2.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.mlp.router.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.16.mlp.router.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.proj.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.proj.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.qkv.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.qkv.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.tau.alpha": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.tau.wq": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.attn.tau.wv": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.ln.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.ln.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.mlp.fc1.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.mlp.fc2.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.mlp.router.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.17.mlp.router.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.proj.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.proj.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.qkv.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.qkv.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.tau.alpha": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.tau.wq": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.attn.tau.wv": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.ln.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.ln.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.mlp.fc1.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.mlp.fc2.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.mlp.router.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.18.mlp.router.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.proj.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.proj.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.qkv.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.qkv.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.tau.alpha": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.tau.wq": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.attn.tau.wv": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.ln.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.ln.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.mlp.fc1.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.19.mlp.fc2.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.19.mlp.router.bias": "model-00003-of-00004.safetensors",
    "model.text.blocks.19.mlp.router.weight": "model-00003-of-00004.safetensors",
    "model.text.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.20.attn.proj.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.attn.proj.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.attn.qkv.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.attn.qkv.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.attn.tau.alpha": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.attn.tau.wq": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.attn.tau.wv": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.ln.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.ln.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.mlp.fc1.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.mlp.fc2.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.mlp.router.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.20.mlp.router.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.proj.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.proj.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.qkv.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.qkv.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.tau.alpha": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.tau.wq": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.attn.tau.wv": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.ln.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.ln.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.mlp.fc1.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.mlp.fc2.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.mlp.router.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.21.mlp.router.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.proj.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.proj.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.qkv.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.qkv.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.tau.alpha": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.tau.wq": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.attn.tau.wv": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.ln.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.ln.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.mlp.fc1.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.mlp.fc2.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.mlp.router.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.22.mlp.router.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.proj.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.proj.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.qkv.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.qkv.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.tau.alpha": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.tau.wq": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.attn.tau.wv": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.ln.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.ln.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.mlp.fc1.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.mlp.fc2.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.mlp.router.bias": "model-00004-of-00004.safetensors",
    "model.text.blocks.23.mlp.router.weight": "model-00004-of-00004.safetensors",
    "model.text.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.mlp.router.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.4.mlp.router.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.mlp.router.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.5.mlp.router.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.mlp.router.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.6.mlp.router.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.mlp.router.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.7.mlp.router.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.tau.alpha": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.tau.wq": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.attn.tau.wv": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.ln.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.ln.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.mlp.fc1.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.8.mlp.fc2.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.8.mlp.router.bias": "model-00001-of-00004.safetensors",
    "model.text.blocks.8.mlp.router.weight": "model-00001-of-00004.safetensors",
    "model.text.blocks.9.attn.proj.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.attn.proj.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.attn.qkv.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.attn.qkv.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.attn.tau.alpha": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.attn.tau.wq": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.attn.tau.wv": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.ln.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.ln.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.mlp.fc1.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.mlp.fc2.weight": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.mlp.router.bias": "model-00002-of-00004.safetensors",
    "model.text.blocks.9.mlp.router.weight": "model-00002-of-00004.safetensors",
    "model.text.lm_head.bias": "model-00004-of-00004.safetensors",
    "model.text.lm_head.weight": "model-00004-of-00004.safetensors",
    "model.text.post_ln.bias": "model-00004-of-00004.safetensors",
    "model.text.post_ln.weight": "model-00004-of-00004.safetensors",
    "model.text.wte": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.ln1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.ln1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.ln2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.ln2.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
    "model.vision.patch_emb.bias": "model-00001-of-00004.safetensors",
    "model.vision.patch_emb.weight": "model-00001-of-00004.safetensors",
    "model.vision.pos_emb": "model-00001-of-00004.safetensors",
    "model.vision.post_ln.bias": "model-00001-of-00004.safetensors",
    "model.vision.post_ln.weight": "model-00001-of-00004.safetensors",
    "model.vision.proj_mlp.fc1.bias": "model-00001-of-00004.safetensors",
    "model.vision.proj_mlp.fc1.weight": "model-00001-of-00004.safetensors",
    "model.vision.proj_mlp.fc2.bias": "model-00001-of-00004.safetensors",
    "model.vision.proj_mlp.fc2.weight": "model-00001-of-00004.safetensors"
  }
}
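The `weight_map` above is what a loader consults to find which of the four `model-0000N-of-00004.safetensors` shards in this commit holds each parameter. As a minimal sketch (not code from this repository), the snippet below assumes the repository has been downloaded to a local directory (`checkpoint_dir` is a placeholder path) and uses the `json` and `safetensors` libraries to assemble a single flat state dict from the index.

```python
import json
from pathlib import Path

from safetensors.torch import load_file

# Placeholder: a local copy of this repository (the index plus the four shard files).
checkpoint_dir = Path("./moondream3-preview")

# Read the index shown above.
with open(checkpoint_dir / "model.safetensors.index.json") as f:
    index = json.load(f)

weight_map = index["weight_map"]  # parameter name -> shard filename
assert index["metadata"]["total_size"] == 18537245664  # bytes across all four shards

# Group parameter names by the shard that stores them, so each shard is read only once.
names_per_shard = {}
for name, shard in weight_map.items():
    names_per_shard.setdefault(shard, []).append(name)

# Load each shard and collect its tensors into one flat state dict.
state_dict = {}
for shard, names in names_per_shard.items():
    tensors = load_file(str(checkpoint_dir / shard))
    for name in names:
        state_dict[name] = tensors[name]

print(f"loaded {len(state_dict)} tensors")  # one tensor per entry in weight_map
```

Loaders such as `transformers` resolve this index automatically when pointed at the repository; the sketch only makes the name-to-shard mapping explicit.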
moondream.py
ADDED
@@ -0,0 +1,1070 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from tokenizers import Tokenizer
|
| 9 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
| 10 |
+
|
| 11 |
+
from .config import MoondreamConfig
|
| 12 |
+
from .image_crops import reconstruct_from_crops
|
| 13 |
+
from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
|
| 14 |
+
from .text import build_text_model, text_encoder, lm_head, text_decoder
|
| 15 |
+
from .region import (
|
| 16 |
+
decode_coordinate,
|
| 17 |
+
encode_coordinate,
|
| 18 |
+
decode_size,
|
| 19 |
+
encode_size,
|
| 20 |
+
encode_spatial_refs,
|
| 21 |
+
SpatialRefs,
|
| 22 |
+
)
|
| 23 |
+
from .layers import QuantizedLinear
|
| 24 |
+
from .lora import variant_state_dict
|
| 25 |
+
from .utils import remove_outlier_points
|
| 26 |
+
|
| 27 |
+
ImageEncodingSettings = TypedDict(
|
| 28 |
+
"ImageEncodingSettings",
|
| 29 |
+
{"variant": str},
|
| 30 |
+
total=False,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
TextSamplingSettings = TypedDict(
|
| 34 |
+
"TextSamplingSettings",
|
| 35 |
+
{
|
| 36 |
+
"max_tokens": int,
|
| 37 |
+
"temperature": float,
|
| 38 |
+
"top_p": float,
|
| 39 |
+
"variant": str,
|
| 40 |
+
},
|
| 41 |
+
total=False,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
ObjectSamplingSettings = TypedDict(
|
| 45 |
+
"ObjectSamplingSettings",
|
| 46 |
+
{"max_objects": int, "variant": str},
|
| 47 |
+
total=False,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
DEFAULT_MAX_TOKENS = 768
|
| 52 |
+
DEFAULT_TEMPERATURE = 0.5
|
| 53 |
+
DEFAULT_TOP_P = 0.9
|
| 54 |
+
DEFAULT_MAX_OBJECTS = 150
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass(frozen=True)
|
| 58 |
+
class EncodedImage:
|
| 59 |
+
pos: int
|
| 60 |
+
caches: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class KVCache(nn.Module):
|
| 64 |
+
|
| 65 |
+
def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
|
| 66 |
+
super().__init__()
|
| 67 |
+
cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
|
| 68 |
+
self.register_buffer(
|
| 69 |
+
"k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 70 |
+
)
|
| 71 |
+
self.register_buffer(
|
| 72 |
+
"v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def update(self, pos_ids, k, v):
|
| 76 |
+
kout, vout = self.k_cache, self.v_cache
|
| 77 |
+
kout[:, :, pos_ids, :] = k
|
| 78 |
+
vout[:, :, pos_ids, :] = v
|
| 79 |
+
return kout, vout
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def causal_mask(b, h, q_idx, kv_idx):
|
| 83 |
+
return q_idx >= kv_idx
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_mask_mod(mask_mod, offset):
|
| 87 |
+
def _mask_mod(b, h, q, kv):
|
| 88 |
+
return mask_mod(b, h, q + offset, kv)
|
| 89 |
+
|
| 90 |
+
return _mask_mod
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class MoondreamModel(nn.Module):
|
| 94 |
+
|
| 95 |
+
def __init__(
|
| 96 |
+
self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
|
| 97 |
+
):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.config = config
|
| 100 |
+
|
| 101 |
+
self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
|
| 102 |
+
self.vision = build_vision_model(config.vision, dtype)
|
| 103 |
+
self.text = build_text_model(config.text, dtype)
|
| 104 |
+
|
| 105 |
+
# Region Model
|
| 106 |
+
linear_cls = (
|
| 107 |
+
QuantizedLinear if config.region.group_size is not None else nn.Linear
|
| 108 |
+
)
|
| 109 |
+
self.region = nn.ModuleDict(
|
| 110 |
+
{
|
| 111 |
+
"coord_encoder": linear_cls(
|
| 112 |
+
config.region.coord_feat_dim, config.region.dim, dtype=dtype
|
| 113 |
+
),
|
| 114 |
+
"coord_decoder": linear_cls(
|
| 115 |
+
config.region.dim, config.region.coord_out_dim, dtype=dtype
|
| 116 |
+
),
|
| 117 |
+
"size_encoder": linear_cls(
|
| 118 |
+
config.region.size_feat_dim, config.region.dim, dtype=dtype
|
| 119 |
+
),
|
| 120 |
+
"size_decoder": linear_cls(
|
| 121 |
+
config.region.dim, config.region.size_out_dim, dtype=dtype
|
| 122 |
+
),
|
| 123 |
+
}
|
| 124 |
+
)
|
| 125 |
+
self.region.coord_features = nn.Parameter(
|
| 126 |
+
torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
|
| 127 |
+
)
|
| 128 |
+
self.region.size_features = nn.Parameter(
|
| 129 |
+
torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
attn_mask = torch.tril(
|
| 133 |
+
torch.ones(
|
| 134 |
+
1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
|
| 135 |
+
)
|
| 136 |
+
)
|
| 137 |
+
patch_w = config.vision.crop_size // config.vision.enc_patch_size
|
| 138 |
+
prefix_attn_len = 1 + patch_w**2
|
| 139 |
+
attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
|
| 140 |
+
self.register_buffer("attn_mask", attn_mask, persistent=False)
|
| 141 |
+
|
| 142 |
+
self.use_flex_decoding = True
|
| 143 |
+
self._causal_block_mask = None
|
| 144 |
+
self._point_gen_indices = None
|
| 145 |
+
|
| 146 |
+
# Initialize KV caches.
|
| 147 |
+
if setup_caches:
|
| 148 |
+
self._setup_caches()
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def causal_block_mask(self):
|
| 152 |
+
# The things we do to deal with ZeroGPU...
|
| 153 |
+
if self._causal_block_mask is None:
|
| 154 |
+
self._causal_block_mask = create_block_mask(
|
| 155 |
+
causal_mask,
|
| 156 |
+
B=None,
|
| 157 |
+
H=None,
|
| 158 |
+
Q_LEN=self.config.text.max_context,
|
| 159 |
+
KV_LEN=self.config.text.max_context,
|
| 160 |
+
)
|
| 161 |
+
return self._causal_block_mask
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def point_gen_indices(self):
|
| 165 |
+
if self._point_gen_indices is None:
|
| 166 |
+
self._point_gen_indices = torch.tensor(
|
| 167 |
+
[self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],
|
| 168 |
+
device=self.device,
|
| 169 |
+
)
|
| 170 |
+
return self._point_gen_indices
|
| 171 |
+
|
| 172 |
+
def _setup_caches(self):
|
| 173 |
+
c = self.config.text
|
| 174 |
+
for b in self.text.blocks:
|
| 175 |
+
b.kv_cache = KVCache(
|
| 176 |
+
c.n_heads,
|
| 177 |
+
c.n_kv_heads,
|
| 178 |
+
c.max_context,
|
| 179 |
+
c.dim,
|
| 180 |
+
device=self.device,
|
| 181 |
+
dtype=self.vision.pos_emb.dtype,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def device(self):
|
| 186 |
+
return self.vision.pos_emb.device
|
| 187 |
+
|
| 188 |
+
def _vis_enc(self, x: torch.Tensor):
|
| 189 |
+
return vision_encoder(x, self.vision, self.config.vision)
|
| 190 |
+
|
| 191 |
+
def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
|
| 192 |
+
return vision_projection(g, r, self.vision, self.config.vision)
|
| 193 |
+
|
| 194 |
+
def _prefill(
|
| 195 |
+
self,
|
| 196 |
+
x: torch.Tensor,
|
| 197 |
+
attn_mask: torch.Tensor,
|
| 198 |
+
pos_ids: torch.Tensor,
|
| 199 |
+
lora: Optional[torch.Tensor],
|
| 200 |
+
):
|
| 201 |
+
return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
|
| 202 |
+
|
| 203 |
+
def _decode_one_tok(
|
| 204 |
+
self,
|
| 205 |
+
x: torch.Tensor,
|
| 206 |
+
attn_mask: torch.Tensor,
|
| 207 |
+
pos_ids: torch.Tensor,
|
| 208 |
+
lora: Optional[torch.Tensor],
|
| 209 |
+
lm_head_indices: Optional[torch.Tensor] = None,
|
| 210 |
+
):
|
| 211 |
+
if self.use_flex_decoding:
|
| 212 |
+
torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape")
|
| 213 |
+
block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]
|
| 214 |
+
mask = self.causal_block_mask[:, :, block_index]
|
| 215 |
+
mask.seq_lengths = (1, mask.seq_lengths[1])
|
| 216 |
+
mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])
|
| 217 |
+
else:
|
| 218 |
+
mask = None
|
| 219 |
+
|
| 220 |
+
hidden = text_decoder(
|
| 221 |
+
x,
|
| 222 |
+
self.text,
|
| 223 |
+
attn_mask,
|
| 224 |
+
pos_ids,
|
| 225 |
+
self.config.text,
|
| 226 |
+
lora=lora,
|
| 227 |
+
flex_block_mask_slice=mask,
|
| 228 |
+
)
|
| 229 |
+
logits = lm_head(hidden, self.text, indices=lm_head_indices)
|
| 230 |
+
return logits, hidden
|
| 231 |
+
|
| 232 |
+
def compile(self):
|
| 233 |
+
for module in self.modules():
|
| 234 |
+
if isinstance(module, QuantizedLinear):
|
| 235 |
+
module.unpack()
|
| 236 |
+
|
| 237 |
+
# Initialize lazy properties to avoid first-call overhead
|
| 238 |
+
self.causal_block_mask
|
| 239 |
+
self.point_gen_indices
|
| 240 |
+
|
| 241 |
+
# TODO: vision_projection and _prefill are not being compiled
|
| 242 |
+
self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
|
| 243 |
+
self._decode_one_tok = torch.compile(
|
| 244 |
+
self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# Warm up compiled methods with dummy forward passes
|
| 248 |
+
device = self.device
|
| 249 |
+
dtype = self.vision.pos_emb.dtype
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
# Warmup vision encoder
|
| 252 |
+
dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype)
|
| 253 |
+
self._vis_enc(dummy_crops)
|
| 254 |
+
|
| 255 |
+
# Warmup _decode_one_tok (both normal and point generation modes)
|
| 256 |
+
dummy_emb = torch.randn(
|
| 257 |
+
1, 1, self.config.text.dim, device=device, dtype=dtype
|
| 258 |
+
)
|
| 259 |
+
dummy_mask = torch.ones(
|
| 260 |
+
1, 1, self.config.text.max_context, device=device, dtype=torch.bool
|
| 261 |
+
)
|
| 262 |
+
dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long)
|
| 263 |
+
self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None)
|
| 264 |
+
self._decode_one_tok(
|
| 265 |
+
dummy_emb,
|
| 266 |
+
dummy_mask,
|
| 267 |
+
dummy_pos_ids,
|
| 268 |
+
None,
|
| 269 |
+
lm_head_indices=self.point_gen_indices,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
|
| 273 |
+
all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
|
| 274 |
+
|
| 275 |
+
torch._dynamo.mark_dynamic(all_crops, 0)
|
| 276 |
+
|
| 277 |
+
outputs = self._vis_enc(all_crops)
|
| 278 |
+
|
| 279 |
+
global_features = outputs[0]
|
| 280 |
+
local_features = outputs[1:].view(
|
| 281 |
+
-1,
|
| 282 |
+
self.config.vision.enc_n_layers,
|
| 283 |
+
self.config.vision.enc_n_layers,
|
| 284 |
+
self.config.vision.enc_dim,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
reconstructed = reconstruct_from_crops(
|
| 288 |
+
local_features,
|
| 289 |
+
tiling,
|
| 290 |
+
patch_size=1,
|
| 291 |
+
overlap_margin=self.config.vision.overlap_margin,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
return self._vis_proj(global_features, reconstructed)
|
| 295 |
+
|
| 296 |
+
def encode_image(
|
| 297 |
+
self,
|
| 298 |
+
image: Union[Image.Image, EncodedImage],
|
| 299 |
+
settings: Optional[ImageEncodingSettings] = None,
|
| 300 |
+
) -> EncodedImage:
|
| 301 |
+
if isinstance(image, EncodedImage):
|
| 302 |
+
return image
|
| 303 |
+
elif not isinstance(image, Image.Image):
|
| 304 |
+
raise ValueError("image must be a PIL Image or EncodedImage")
|
| 305 |
+
|
| 306 |
+
lora = (
|
| 307 |
+
variant_state_dict(settings["variant"], device=self.device)
|
| 308 |
+
if settings is not None and "variant" in settings
|
| 309 |
+
else None
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Run through text model in addition to the vision encoder, to minimize
|
| 313 |
+
# re-computation if multiple queries are performed on this image.
|
| 314 |
+
with torch.inference_mode():
|
| 315 |
+
img_emb = self._run_vision_encoder(image)
|
| 316 |
+
bos_emb = text_encoder(
|
| 317 |
+
torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
|
| 318 |
+
self.text,
|
| 319 |
+
)
|
| 320 |
+
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 321 |
+
mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
|
| 322 |
+
pos_ids = torch.arange(
|
| 323 |
+
inputs_embeds.size(1), dtype=torch.long, device=self.device
|
| 324 |
+
)
|
| 325 |
+
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
| 326 |
+
|
| 327 |
+
return EncodedImage(
|
| 328 |
+
pos=inputs_embeds.size(1),
|
| 329 |
+
caches=[
|
| 330 |
+
(
|
| 331 |
+
b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
|
| 332 |
+
b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
|
| 333 |
+
)
|
| 334 |
+
for b in self.text.blocks
|
| 335 |
+
],
|
| 336 |
+
)
|
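# Added commentary (not part of the original file): EncodedImage stores the
# prefill length plus a clone of every text block's K/V cache, so repeated
# requests on the same image can skip the vision encoder and the image prefill:
#
#   enc = model.encode_image(pil_image)   # heavy step, run once
#   model.load_encoded_image(enc)         # restores the cached K/V tensors
#   ...decoding then continues from position enc.pos...
#
# query/caption/detect/point below all follow this encode -> load -> decode path.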
| 337 |
+
|
| 338 |
+
def _apply_top_p(self, probs: torch.Tensor, top_p: float):
|
| 339 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 340 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
| 341 |
+
mask = probs_sum - probs_sort > top_p
|
| 342 |
+
probs_sort[mask] = 0.0
|
| 343 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
| 344 |
+
next_probs = torch.zeros_like(probs)
|
| 345 |
+
next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
|
| 346 |
+
return next_probs
|
| 347 |
+
|
| 348 |
+
def _prefill_prompt(
|
| 349 |
+
self,
|
| 350 |
+
prompt_tokens: torch.Tensor,
|
| 351 |
+
pos: int,
|
| 352 |
+
temperature: float,
|
| 353 |
+
top_p: float,
|
| 354 |
+
spatial_refs: Optional[SpatialRefs] = None,
|
| 355 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 356 |
+
lora: Optional[dict] = None,
|
| 357 |
+
):
|
| 358 |
+
with torch.inference_mode():
|
| 359 |
+
prompt_emb = text_encoder(prompt_tokens, self.text)
|
| 360 |
+
|
| 361 |
+
if spatial_refs:
|
| 362 |
+
encoded_refs = encode_spatial_refs(spatial_refs, self.region)
|
| 363 |
+
prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
|
| 364 |
+
encoded_refs["coords"]
|
| 365 |
+
)
|
| 366 |
+
if encoded_refs["sizes"] is not None:
|
| 367 |
+
prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
|
| 368 |
+
encoded_refs["sizes"]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
torch._dynamo.mark_dynamic(prompt_emb, 1)
|
| 372 |
+
|
| 373 |
+
if attn_mask is None:
|
| 374 |
+
attn_mask = self.attn_mask
|
| 375 |
+
|
| 376 |
+
mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
|
| 377 |
+
pos_ids = torch.arange(
|
| 378 |
+
pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device
|
| 379 |
+
)
|
| 380 |
+
hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
|
| 381 |
+
logits_BV = lm_head(hidden_BC, self.text)
|
| 382 |
+
|
| 383 |
+
if temperature == 0:
|
| 384 |
+
next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
|
| 385 |
+
else:
|
| 386 |
+
probs = torch.softmax(logits_BV / temperature, dim=-1)
|
| 387 |
+
probs = self._apply_top_p(probs, top_p)
|
| 388 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 389 |
+
|
| 390 |
+
pos = pos + prompt_emb.size(1)
|
| 391 |
+
return logits_BV, hidden_BC, next_token, pos
|
| 392 |
+
|
| 393 |
+
def _generate_reasoning(
|
| 394 |
+
self,
|
| 395 |
+
prompt_tokens,
|
| 396 |
+
pos,
|
| 397 |
+
settings: Optional[TextSamplingSettings] = None,
|
| 398 |
+
spatial_refs: Optional[SpatialRefs] = None,
|
| 399 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 400 |
+
) -> Tuple[int, str, List[dict]]:
|
| 401 |
+
max_tokens = (
|
| 402 |
+
settings.get("max_tokens", DEFAULT_MAX_TOKENS)
|
| 403 |
+
if settings
|
| 404 |
+
else DEFAULT_MAX_TOKENS
|
| 405 |
+
)
|
| 406 |
+
temperature = (
|
| 407 |
+
settings.get("temperature", DEFAULT_TEMPERATURE)
|
| 408 |
+
if settings
|
| 409 |
+
else DEFAULT_TEMPERATURE
|
| 410 |
+
)
|
| 411 |
+
lora = (
|
| 412 |
+
variant_state_dict(settings["variant"], device=self.device)
|
| 413 |
+
if settings is not None and "variant" in settings
|
| 414 |
+
else None
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
|
| 418 |
+
eos_id = self.config.tokenizer.answer_id
|
| 419 |
+
|
| 420 |
+
_, last_hidden_BC, next_token, pos = self._prefill_prompt(
|
| 421 |
+
prompt_tokens,
|
| 422 |
+
pos,
|
| 423 |
+
temperature,
|
| 424 |
+
top_p,
|
| 425 |
+
spatial_refs,
|
| 426 |
+
attn_mask=attn_mask,
|
| 427 |
+
lora=lora,
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
text_token_chunks = [[]]
|
| 431 |
+
grounding_chunks = [[]]
|
| 432 |
+
|
| 433 |
+
mask = torch.zeros(
|
| 434 |
+
1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
|
| 435 |
+
)
|
| 436 |
+
mask[:, :, :pos] = 1
|
| 437 |
+
pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
|
| 438 |
+
generated_tokens = 0
|
| 439 |
+
|
| 440 |
+
while (
|
| 441 |
+
next_token_id := next_token.item()
|
| 442 |
+
) != eos_id and generated_tokens < max_tokens:
|
| 443 |
+
if (
|
| 444 |
+
next_token_id == self.config.tokenizer.start_ground_points_id
|
| 445 |
+
or next_token_id == self.config.tokenizer.end_ground_id
|
| 446 |
+
):
|
| 447 |
+
text_token_chunks.append([])
|
| 448 |
+
grounding_chunks.append([])
|
| 449 |
+
|
| 450 |
+
text_token_chunks[-1].append(next_token_id)
|
| 451 |
+
|
| 452 |
+
with torch.inference_mode():
|
| 453 |
+
if next_token_id == self.config.tokenizer.coord_id:
|
| 454 |
+
coord_logits = decode_coordinate(last_hidden_BC, self.region)
|
| 455 |
+
coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
|
| 456 |
+
grounding_chunks[-1].append(coord.item())
|
| 457 |
+
|
| 458 |
+
next_emb = encode_coordinate(
|
| 459 |
+
coord.to(dtype=coord_logits.dtype), self.region
|
| 460 |
+
).unsqueeze(0)
|
| 461 |
+
else:
|
| 462 |
+
next_emb = text_encoder(next_token, self.text)
|
| 463 |
+
|
| 464 |
+
mask[:, :, pos], pos_ids[0] = 1, pos
|
| 465 |
+
|
| 466 |
+
logits_BV, last_hidden_BC = self._decode_one_tok(
|
| 467 |
+
next_emb, mask, pos_ids, lora
|
| 468 |
+
)
|
| 469 |
+
logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
|
| 470 |
+
logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
|
| 471 |
+
|
| 472 |
+
pos += 1
|
| 473 |
+
|
| 474 |
+
if temperature == 0:
|
| 475 |
+
next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
|
| 476 |
+
else:
|
| 477 |
+
probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
|
| 478 |
+
probs = self._apply_top_p(probs, top_p)
|
| 479 |
+
next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
|
| 480 |
+
|
| 481 |
+
generated_tokens += 1
|
| 482 |
+
|
| 483 |
+
text_chunks = [
|
| 484 |
+
self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
|
| 485 |
+
]
|
| 486 |
+
text = "".join(text_chunks)
|
| 487 |
+
|
| 488 |
+
start_idx = 0
|
| 489 |
+
grounding = []
|
| 490 |
+
for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
|
| 491 |
+
if len(grounding_chunk) > 1:
|
| 492 |
+
points = []
|
| 493 |
+
for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
|
| 494 |
+
points.append((grounding_chunk[i], grounding_chunk[i + 1]))
|
| 495 |
+
grounding.append(
|
| 496 |
+
{
|
| 497 |
+
"start_idx": start_idx,
|
| 498 |
+
"end_idx": start_idx + len(text_chunk),
|
| 499 |
+
"points": points,
|
| 500 |
+
}
|
| 501 |
+
)
|
| 502 |
+
start_idx += len(text_chunk)
|
| 503 |
+
|
| 504 |
+
return pos, text, grounding
|
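# Added commentary: the grounding list returned above maps character spans of
# the reasoning text to the points emitted while that span was generated, e.g.
#
#   text      = "The dog is on the left side."
#   grounding = [{"start_idx": 0, "end_idx": 7, "points": [(0.21, 0.63)]}]
#
# where each point is a normalized (x, y) pair decoded from coord tokens.
# The concrete values here are illustrative, not model output.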
| 505 |
+
|
| 506 |
+
def _generate_answer(
|
| 507 |
+
self,
|
| 508 |
+
prompt_tokens: torch.Tensor,
|
| 509 |
+
pos: int,
|
| 510 |
+
settings: Optional[TextSamplingSettings] = None,
|
| 511 |
+
spatial_refs: Optional[SpatialRefs] = None,
|
| 512 |
+
eos_id: Optional[int] = None,
|
| 513 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 514 |
+
):
|
| 515 |
+
max_tokens = (
|
| 516 |
+
settings.get("max_tokens", DEFAULT_MAX_TOKENS)
|
| 517 |
+
if settings
|
| 518 |
+
else DEFAULT_MAX_TOKENS
|
| 519 |
+
)
|
| 520 |
+
temperature = (
|
| 521 |
+
settings.get("temperature", DEFAULT_TEMPERATURE)
|
| 522 |
+
if settings
|
| 523 |
+
else DEFAULT_TEMPERATURE
|
| 524 |
+
)
|
| 525 |
+
top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
|
| 526 |
+
eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
|
| 527 |
+
lora = (
|
| 528 |
+
variant_state_dict(settings["variant"], device=self.device)
|
| 529 |
+
if settings is not None and "variant" in settings
|
| 530 |
+
else None
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
_, _, next_token, pos = self._prefill_prompt(
|
| 534 |
+
prompt_tokens,
|
| 535 |
+
pos,
|
| 536 |
+
temperature,
|
| 537 |
+
top_p,
|
| 538 |
+
spatial_refs,
|
| 539 |
+
attn_mask=attn_mask,
|
| 540 |
+
lora=lora,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
def generator(next_token, pos):
|
| 544 |
+
mask = torch.zeros(
|
| 545 |
+
1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
|
| 546 |
+
)
|
| 547 |
+
mask[:, :, :pos] = 1
|
| 548 |
+
pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
|
| 549 |
+
generated_tokens = 0
|
| 550 |
+
|
| 551 |
+
# For properly handling token streaming with Unicode
|
| 552 |
+
token_cache = []
|
| 553 |
+
print_len = 0
|
| 554 |
+
|
| 555 |
+
while (
|
| 556 |
+
next_token_id := next_token.item()
|
| 557 |
+
) != eos_id and generated_tokens < max_tokens:
|
| 558 |
+
# Add token to our cache
|
| 559 |
+
token_cache.append(next_token_id)
|
| 560 |
+
|
| 561 |
+
# Decode all tokens collected so far
|
| 562 |
+
text = self.tokenizer.decode(token_cache)
|
| 563 |
+
|
| 564 |
+
# After a newline, we flush the cache completely
|
| 565 |
+
if text.endswith("\n"):
|
| 566 |
+
printable_text = text[print_len:]
|
| 567 |
+
token_cache = []
|
| 568 |
+
print_len = 0
|
| 569 |
+
if printable_text:
|
| 570 |
+
yield printable_text
|
| 571 |
+
# If the last token is a CJK character, we can safely print it
|
| 572 |
+
elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
|
| 573 |
+
printable_text = text[print_len:]
|
| 574 |
+
print_len += len(printable_text)
|
| 575 |
+
if printable_text:
|
| 576 |
+
yield printable_text
|
| 577 |
+
# Otherwise, only yield up to the last space to avoid cutting words
|
| 578 |
+
else:
|
| 579 |
+
last_space_idx = text.rfind(" ", print_len)
|
| 580 |
+
if last_space_idx >= print_len:
|
| 581 |
+
printable_text = text[print_len : last_space_idx + 1]
|
| 582 |
+
print_len += len(printable_text)
|
| 583 |
+
if printable_text:
|
| 584 |
+
yield printable_text
|
| 585 |
+
|
| 586 |
+
with torch.inference_mode():
|
| 587 |
+
next_emb = text_encoder(next_token, self.text)
|
| 588 |
+
mask[:, :, pos], pos_ids[0] = 1, pos
|
| 589 |
+
|
| 590 |
+
logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
|
| 591 |
+
logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
|
| 592 |
+
|
| 593 |
+
pos += 1
|
| 594 |
+
|
| 595 |
+
if temperature == 0:
|
| 596 |
+
next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
|
| 597 |
+
1
|
| 598 |
+
) # (1, 1)
|
| 599 |
+
else:
|
| 600 |
+
probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
|
| 601 |
+
probs = self._apply_top_p(probs, top_p)
|
| 602 |
+
next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
|
| 603 |
+
|
| 604 |
+
generated_tokens += 1
|
| 605 |
+
|
| 606 |
+
# Flush any remaining text in the cache
|
| 607 |
+
if token_cache:
|
| 608 |
+
text = self.tokenizer.decode(token_cache)
|
| 609 |
+
printable_text = text[print_len:]
|
| 610 |
+
if printable_text:
|
| 611 |
+
yield printable_text
|
| 612 |
+
|
| 613 |
+
return generator(next_token, pos)
|
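# Added commentary: the streaming generator above only yields text it will not
# need to retract. If the decoded-so-far text is "A red ca", it yields "A red "
# (up to the last space) and holds "ca" until the next token resolves the word;
# a trailing newline or a CJK character flushes the cache immediately. Chunks
# can therefore be concatenated verbatim by the caller, e.g.:
#
#   for chunk in model.query(image=enc, question=q, stream=True)["answer"]:
#       print(chunk, end="", flush=True)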
| 614 |
+
|
| 615 |
+
def query(
|
| 616 |
+
self,
|
| 617 |
+
image: Optional[Union[Image.Image, EncodedImage]] = None,
|
| 618 |
+
question: str = None,
|
| 619 |
+
reasoning: bool = True,
|
| 620 |
+
spatial_refs: Optional[SpatialRefs] = None,
|
| 621 |
+
stream: bool = False,
|
| 622 |
+
settings: Optional[TextSamplingSettings] = None,
|
| 623 |
+
):
|
| 624 |
+
if self.config.tokenizer.templates["query"] is None:
|
| 625 |
+
raise NotImplementedError("Model does not support querying.")
|
| 626 |
+
|
| 627 |
+
if question is None:
|
| 628 |
+
raise ValueError("question must be provided.")
|
| 629 |
+
|
| 630 |
+
if spatial_refs and image is None:
|
| 631 |
+
raise ValueError("spatial_refs can only be used with an image.")
|
| 632 |
+
|
| 633 |
+
attn_mask = self.attn_mask
|
| 634 |
+
if image is not None:
|
| 635 |
+
image = self.encode_image(image, settings)
|
| 636 |
+
self.load_encoded_image(image)
|
| 637 |
+
pos = image.pos
|
| 638 |
+
prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
|
| 639 |
+
else:
|
| 640 |
+
self._setup_caches()
|
| 641 |
+
pos = 0
|
| 642 |
+
prompt_toks = [
|
| 643 |
+
self.config.tokenizer.bos_id
|
| 644 |
+
] + self.config.tokenizer.templates["query"]["prefix"]
|
| 645 |
+
max_context = self.config.text.max_context
|
| 646 |
+
attn_mask = torch.tril(
|
| 647 |
+
torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
|
| 648 |
+
).to(self.device)
|
| 649 |
+
|
| 650 |
+
spatial_toks = []
|
| 651 |
+
if spatial_refs:
|
| 652 |
+
for ref in spatial_refs:
|
| 653 |
+
coord_id = self.config.tokenizer.coord_id
|
| 654 |
+
size_id = self.config.tokenizer.size_id
|
| 655 |
+
if len(ref) == 2:
|
| 656 |
+
spatial_toks.extend([coord_id, coord_id])
|
| 657 |
+
else:
|
| 658 |
+
spatial_toks.extend([coord_id, coord_id, size_id])
|
| 659 |
+
|
| 660 |
+
prompt_tokens = [
|
| 661 |
+
prompt_toks + spatial_toks + self.tokenizer.encode(question).ids
|
| 662 |
+
]
|
| 663 |
+
|
| 664 |
+
if reasoning:
|
| 665 |
+
prompt_tokens[0] += [self.config.tokenizer.thinking_id]
|
| 666 |
+
prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
|
| 667 |
+
pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
|
| 668 |
+
prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
|
| 669 |
+
)
|
| 670 |
+
prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
|
| 671 |
+
reasoning_dict = {
|
| 672 |
+
"reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
|
| 673 |
+
}
|
| 674 |
+
else:
|
| 675 |
+
prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
|
| 676 |
+
reasoning_dict = {}
|
| 677 |
+
|
| 678 |
+
prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
|
| 679 |
+
|
| 680 |
+
def generator():
|
| 681 |
+
for token in self._generate_answer(
|
| 682 |
+
prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
|
| 683 |
+
):
|
| 684 |
+
yield token
|
| 685 |
+
|
| 686 |
+
if stream:
|
| 687 |
+
return {**reasoning_dict, "answer": generator()}
|
| 688 |
+
else:
|
| 689 |
+
return {**reasoning_dict, "answer": "".join(list(generator()))}
|
| 690 |
+
|
| 691 |
+
def load_encoded_image(self, encoded_image: EncodedImage):
|
| 692 |
+
for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
|
| 693 |
+
b.kv_cache.k_cache[:, :, : k.size(2), :] = k
|
| 694 |
+
b.kv_cache.v_cache[:, :, : v.size(2), :] = v
|
| 695 |
+
|
| 696 |
+
def caption(
|
| 697 |
+
self,
|
| 698 |
+
image: Union[Image.Image, EncodedImage],
|
| 699 |
+
length: Literal["normal", "short", "long"] = "normal",
|
| 700 |
+
stream: bool = False,
|
| 701 |
+
settings: Optional[TextSamplingSettings] = None,
|
| 702 |
+
):
|
| 703 |
+
if self.config.tokenizer.templates["caption"] is None:
|
| 704 |
+
raise NotImplementedError("Model does not support captioning.")
|
| 705 |
+
if length not in self.config.tokenizer.templates["caption"]:
|
| 706 |
+
raise ValueError(f"Model does not support caption length '{length}'.")
|
| 707 |
+
|
| 708 |
+
image = self.encode_image(image, settings)
|
| 709 |
+
self.load_encoded_image(image)
|
| 710 |
+
|
| 711 |
+
prompt_tokens = torch.tensor(
|
| 712 |
+
[self.config.tokenizer.templates["caption"][length]], device=self.device
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
def generator():
|
| 716 |
+
for token in self._generate_answer(prompt_tokens, image.pos, settings):
|
| 717 |
+
yield token
|
| 718 |
+
|
| 719 |
+
if stream:
|
| 720 |
+
return {"caption": generator()}
|
| 721 |
+
else:
|
| 722 |
+
return {"caption": "".join(list(generator()))}
|
| 723 |
+
|
| 724 |
+
def _generate_points(
|
| 725 |
+
self,
|
| 726 |
+
hidden: torch.Tensor,
|
| 727 |
+
next_token: torch.Tensor,
|
| 728 |
+
pos: int,
|
| 729 |
+
include_size: bool = True,
|
| 730 |
+
max_objects: int = DEFAULT_MAX_OBJECTS,
|
| 731 |
+
lora: Optional[dict] = None,
|
| 732 |
+
):
|
| 733 |
+
out = []
|
| 734 |
+
mask = torch.zeros(
|
| 735 |
+
1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
|
| 736 |
+
)
|
| 737 |
+
mask[:, :, :pos] = 1
|
| 738 |
+
pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
|
| 739 |
+
|
| 740 |
+
with torch.inference_mode():
|
| 741 |
+
while (
|
| 742 |
+
next_token.item() != self.config.tokenizer.eos_id
|
| 743 |
+
and len(out) < max_objects
|
| 744 |
+
):
|
| 745 |
+
x_logits = decode_coordinate(hidden, self.region)
|
| 746 |
+
x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
|
| 747 |
+
next_emb = encode_coordinate(
|
| 748 |
+
x_center.to(dtype=x_logits.dtype), self.region
|
| 749 |
+
).unsqueeze(0)
|
| 750 |
+
|
| 751 |
+
# Decode y-coordinate
|
| 752 |
+
mask[:, :, pos], pos_ids[0] = 1, pos
|
| 753 |
+
_, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
|
| 754 |
+
pos += 1
|
| 755 |
+
y_logits = decode_coordinate(hidden, self.region)
|
| 756 |
+
y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
|
| 757 |
+
next_emb = encode_coordinate(
|
| 758 |
+
y_center.to(dtype=y_logits.dtype), self.region
|
| 759 |
+
).unsqueeze(0)
|
| 760 |
+
|
| 761 |
+
# Decode size
|
| 762 |
+
if include_size:
|
| 763 |
+
mask[:, :, pos], pos_ids[0] = 1, pos
|
| 764 |
+
logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
|
| 765 |
+
pos += 1
|
| 766 |
+
size_logits = decode_size(hidden, self.region)
|
| 767 |
+
|
| 768 |
+
# Get bin indices from the logits
|
| 769 |
+
w_bin = torch.argmax(size_logits[0], dim=-1)
|
| 770 |
+
h_bin = torch.argmax(size_logits[1], dim=-1)
|
| 771 |
+
|
| 772 |
+
# Convert from bin indices to actual size values using the inverse of the log-scale mapping
|
| 773 |
+
# Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
|
| 774 |
+
w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
|
| 775 |
+
h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
|
| 776 |
+
|
| 777 |
+
next_emb = (
|
| 778 |
+
encode_size(
|
| 779 |
+
torch.tensor(
|
| 780 |
+
[w, h], device=self.device, dtype=size_logits.dtype
|
| 781 |
+
),
|
| 782 |
+
self.region,
|
| 783 |
+
)
|
| 784 |
+
.unsqueeze(0)
|
| 785 |
+
.unsqueeze(0)
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
# Add object
|
| 789 |
+
out.append(
|
| 790 |
+
{
|
| 791 |
+
"x_min": x_center.item() - w.item() / 2,
|
| 792 |
+
"y_min": y_center.item() - h.item() / 2,
|
| 793 |
+
"x_max": x_center.item() + w.item() / 2,
|
| 794 |
+
"y_max": y_center.item() + h.item() / 2,
|
| 795 |
+
}
|
| 796 |
+
)
|
| 797 |
+
else:
|
| 798 |
+
out.append({"x": x_center.item(), "y": y_center.item()})
|
| 799 |
+
|
| 800 |
+
# Decode next token (x-coordinate, or eos)
|
| 801 |
+
mask[:, :, pos], pos_ids[0] = 1, pos
|
| 802 |
+
logits, hidden = self._decode_one_tok(
|
| 803 |
+
next_emb,
|
| 804 |
+
mask,
|
| 805 |
+
pos_ids,
|
| 806 |
+
lora,
|
| 807 |
+
lm_head_indices=self.point_gen_indices,
|
| 808 |
+
)
|
| 809 |
+
pos += 1
|
| 810 |
+
# Map back: index 0 -> coord_id, index 1 -> eos_id
|
| 811 |
+
next_token_idx = torch.argmax(logits, dim=-1)
|
| 812 |
+
next_token = self.point_gen_indices[next_token_idx]
|
| 813 |
+
|
| 814 |
+
return out
|
| 815 |
+
|
| 816 |
+
def detect(
|
| 817 |
+
self,
|
| 818 |
+
image: Union[Image.Image, EncodedImage],
|
| 819 |
+
object: str,
|
| 820 |
+
settings: Optional[ObjectSamplingSettings] = None,
|
| 821 |
+
):
|
| 822 |
+
if self.config.tokenizer.templates["detect"] is None:
|
| 823 |
+
raise NotImplementedError("Model does not support object detection.")
|
| 824 |
+
|
| 825 |
+
image = self.encode_image(image, settings)
|
| 826 |
+
self.load_encoded_image(image)
|
| 827 |
+
|
| 828 |
+
prompt_tokens = torch.tensor(
|
| 829 |
+
[
|
| 830 |
+
self.config.tokenizer.templates["detect"]["prefix"]
|
| 831 |
+
+ self.tokenizer.encode(" " + object).ids
|
| 832 |
+
+ self.config.tokenizer.templates["detect"]["suffix"]
|
| 833 |
+
],
|
| 834 |
+
device=self.device,
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
lora = (
|
| 838 |
+
variant_state_dict(settings["variant"], device=self.device)
|
| 839 |
+
if settings is not None and "variant" in settings
|
| 840 |
+
else None
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
_, hidden, next_token, pos = self._prefill_prompt(
|
| 844 |
+
prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
|
| 845 |
+
)
|
| 846 |
+
hidden = hidden[:, -1:, :]
|
| 847 |
+
|
| 848 |
+
max_objects = (
|
| 849 |
+
settings.get("max_objects", DEFAULT_MAX_OBJECTS)
|
| 850 |
+
if settings
|
| 851 |
+
else DEFAULT_MAX_OBJECTS
|
| 852 |
+
)
|
| 853 |
+
objects = self._generate_points(
|
| 854 |
+
hidden,
|
| 855 |
+
next_token,
|
| 856 |
+
pos,
|
| 857 |
+
include_size=True,
|
| 858 |
+
max_objects=max_objects,
|
| 859 |
+
lora=lora,
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
return {"objects": objects}
|
| 863 |
+
|
| 864 |
+
def point(
|
| 865 |
+
self,
|
| 866 |
+
image: Union[Image.Image, EncodedImage],
|
| 867 |
+
object: str,
|
| 868 |
+
settings: Optional[ObjectSamplingSettings] = None,
|
| 869 |
+
):
|
| 870 |
+
if self.config.tokenizer.templates["point"] is None:
|
| 871 |
+
raise NotImplementedError("Model does not support pointing.")
|
| 872 |
+
|
| 873 |
+
image = self.encode_image(image, settings)
|
| 874 |
+
self.load_encoded_image(image)
|
| 875 |
+
|
| 876 |
+
prompt_tokens = torch.tensor(
|
| 877 |
+
[
|
| 878 |
+
self.config.tokenizer.templates["point"]["prefix"]
|
| 879 |
+
+ self.tokenizer.encode(" " + object).ids
|
| 880 |
+
+ self.config.tokenizer.templates["point"]["suffix"]
|
| 881 |
+
],
|
| 882 |
+
device=self.device,
|
| 883 |
+
)
|
| 884 |
+
|
| 885 |
+
lora = (
|
| 886 |
+
variant_state_dict(settings["variant"], device=self.device)
|
| 887 |
+
if settings is not None and "variant" in settings
|
| 888 |
+
else None
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
_, hidden, next_token, pos = self._prefill_prompt(
|
| 892 |
+
prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
|
| 893 |
+
)
|
| 894 |
+
hidden = hidden[:, -1:, :]
|
| 895 |
+
|
| 896 |
+
max_objects = (
|
| 897 |
+
settings.get("max_objects", DEFAULT_MAX_OBJECTS)
|
| 898 |
+
if settings
|
| 899 |
+
else DEFAULT_MAX_OBJECTS
|
| 900 |
+
)
|
| 901 |
+
objects = self._generate_points(
|
| 902 |
+
hidden,
|
| 903 |
+
next_token,
|
| 904 |
+
pos,
|
| 905 |
+
include_size=False,
|
| 906 |
+
max_objects=max_objects,
|
| 907 |
+
lora=lora,
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
return {"points": objects}
|
| 911 |
+
|
| 912 |
+
def _detect_gaze(
|
| 913 |
+
self,
|
| 914 |
+
image: EncodedImage,
|
| 915 |
+
source: Tuple[float, float],
|
| 916 |
+
force_detect: bool = False,
|
| 917 |
+
):
|
| 918 |
+
with torch.inference_mode():
|
| 919 |
+
before_emb = text_encoder(
|
| 920 |
+
torch.tensor(
|
| 921 |
+
[self.tokenizer.encode("\n\nPoint:").ids], device=self.device
|
| 922 |
+
),
|
| 923 |
+
self.text,
|
| 924 |
+
)
|
| 925 |
+
after_emb = text_encoder(
|
| 926 |
+
torch.tensor(
|
| 927 |
+
[self.tokenizer.encode(" gaze\n\n").ids], device=self.device
|
| 928 |
+
),
|
| 929 |
+
self.text,
|
| 930 |
+
)
|
| 931 |
+
x_emb = encode_coordinate(
|
| 932 |
+
torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
|
| 933 |
+
self.region,
|
| 934 |
+
)
|
| 935 |
+
y_emb = encode_coordinate(
|
| 936 |
+
torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
|
| 937 |
+
self.region,
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
|
| 941 |
+
|
| 942 |
+
self.load_encoded_image(image)
|
| 943 |
+
|
| 944 |
+
mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
|
| 945 |
+
pos_ids = torch.arange(
|
| 946 |
+
image.pos,
|
| 947 |
+
image.pos + prompt_emb.size(1),
|
| 948 |
+
dtype=torch.long,
|
| 949 |
+
device=self.device,
|
| 950 |
+
)
|
| 951 |
+
hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
|
| 952 |
+
logits = lm_head(hidden, self.text)
|
| 953 |
+
next_token = torch.argmax(logits, dim=-1)
|
| 954 |
+
pos = image.pos + prompt_emb.size(1)
|
| 955 |
+
hidden = hidden[:, -1:, :]
|
| 956 |
+
|
| 957 |
+
if force_detect:
|
| 958 |
+
next_token = torch.tensor([[0]], device=self.device)
|
| 959 |
+
|
| 960 |
+
if next_token.item() == self.config.tokenizer.eos_id:
|
| 961 |
+
return None
|
| 962 |
+
|
| 963 |
+
gaze = self._generate_points(
|
| 964 |
+
hidden, next_token, pos, include_size=False, max_objects=1
|
| 965 |
+
)
|
| 966 |
+
return gaze[0]
|
| 967 |
+
|
| 968 |
+
def detect_gaze(
|
| 969 |
+
self,
|
| 970 |
+
image: Union[Image.Image, EncodedImage],
|
| 971 |
+
eye: Optional[Tuple[float, float]] = None,
|
| 972 |
+
face: Optional[Dict[str, float]] = None,
|
| 973 |
+
unstable_settings: Dict[str, Any] = {},
|
| 974 |
+
):
|
| 975 |
+
if "force_detect" in unstable_settings:
|
| 976 |
+
force_detect = unstable_settings["force_detect"]
|
| 977 |
+
else:
|
| 978 |
+
force_detect = False
|
| 979 |
+
|
| 980 |
+
if "prioritize_accuracy" in unstable_settings:
|
| 981 |
+
prioritize_accuracy = unstable_settings["prioritize_accuracy"]
|
| 982 |
+
else:
|
| 983 |
+
prioritize_accuracy = False
|
| 984 |
+
|
| 985 |
+
if not prioritize_accuracy:
|
| 986 |
+
if eye is None:
|
| 987 |
+
raise ValueError("eye must be provided when prioritize_accuracy=False")
|
| 988 |
+
image = self.encode_image(image)
|
| 989 |
+
return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
|
| 990 |
+
else:
|
| 991 |
+
if (
|
| 992 |
+
not isinstance(image, Image.Image)
|
| 993 |
+
and "flip_enc_img" not in unstable_settings
|
| 994 |
+
):
|
| 995 |
+
raise ValueError(
|
| 996 |
+
"image must be a PIL Image when prioritize_accuracy=True, "
|
| 997 |
+
"or flip_enc_img must be provided"
|
| 998 |
+
)
|
| 999 |
+
if face is None:
|
| 1000 |
+
raise ValueError("face must be provided when prioritize_accuracy=True")
|
| 1001 |
+
|
| 1002 |
+
encoded_image = self.encode_image(image)
|
| 1003 |
+
if (
|
| 1004 |
+
isinstance(image, Image.Image)
|
| 1005 |
+
and "flip_enc_img" not in unstable_settings
|
| 1006 |
+
):
|
| 1007 |
+
flipped_pil = image.copy()
|
| 1008 |
+
flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
|
| 1009 |
+
encoded_flipped_image = self.encode_image(flipped_pil)
|
| 1010 |
+
else:
|
| 1011 |
+
encoded_flipped_image = unstable_settings["flip_enc_img"]
|
| 1012 |
+
|
| 1013 |
+
N = 10
|
| 1014 |
+
|
| 1015 |
+
detections = [
|
| 1016 |
+
self._detect_gaze(
|
| 1017 |
+
encoded_image,
|
| 1018 |
+
(
|
| 1019 |
+
random.uniform(face["x_min"], face["x_max"]),
|
| 1020 |
+
random.uniform(face["y_min"], face["y_max"]),
|
| 1021 |
+
),
|
| 1022 |
+
force_detect=force_detect,
|
| 1023 |
+
)
|
| 1024 |
+
for _ in range(N)
|
| 1025 |
+
]
|
| 1026 |
+
detections = [
|
| 1027 |
+
(gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
|
| 1028 |
+
]
|
| 1029 |
+
flipped_detections = [
|
| 1030 |
+
self._detect_gaze(
|
| 1031 |
+
encoded_flipped_image,
|
| 1032 |
+
(
|
| 1033 |
+
1 - random.uniform(face["x_min"], face["x_max"]),
|
| 1034 |
+
random.uniform(face["y_min"], face["y_max"]),
|
| 1035 |
+
),
|
| 1036 |
+
force_detect=force_detect,
|
| 1037 |
+
)
|
| 1038 |
+
for _ in range(N)
|
| 1039 |
+
]
|
| 1040 |
+
detections.extend(
|
| 1041 |
+
[
|
| 1042 |
+
(1 - gaze["x"], gaze["y"])
|
| 1043 |
+
for gaze in flipped_detections
|
| 1044 |
+
if gaze is not None
|
| 1045 |
+
]
|
| 1046 |
+
)
|
| 1047 |
+
|
| 1048 |
+
if len(detections) < N:
|
| 1049 |
+
return {"gaze": None}
|
| 1050 |
+
|
| 1051 |
+
detections = remove_outlier_points(detections)
|
| 1052 |
+
mean_gaze = (
|
| 1053 |
+
sum(gaze[0] for gaze in detections) / len(detections),
|
| 1054 |
+
sum(gaze[1] for gaze in detections) / len(detections),
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
def _is_cjk_char(cp):
|
| 1061 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
| 1062 |
+
# This defines a "Chinese character" as anything in the CJK Unicode block:
|
| 1063 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
| 1064 |
+
if (
|
| 1065 |
+
(cp >= 0x4E00 and cp <= 0x9FFF)
|
| 1066 |
+
or (cp >= 0x3400 and cp <= 0x4DBF)
|
| 1067 |
+
or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
| 1068 |
+
):
|
| 1069 |
+
return True
|
| 1070 |
+
return False
|
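The class above is the full public surface of the model: `query` for open-ended VQA (optionally with reasoning and grounding), `caption` for image captioning, `detect` for open-vocabulary boxes, and `point` for point/count style outputs. A minimal usage sketch, assuming a `MoondreamModel` instance named `model` has already been constructed and loaded with the checkpoint weights (for example through the `hf_moondream.py` wrapper in this repository; the loading step itself is not shown in this file):

    from PIL import Image

    img = Image.open("example.jpg")        # any RGB image
    enc = model.encode_image(img)          # crops + vision encoder + prefill, done once

    print(model.query(image=enc, question="What is happening here?")["answer"])
    print(model.caption(enc, length="short")["caption"])
    print(model.detect(enc, "person")["objects"])   # [{"x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...}, ...]
    print(model.point(enc, "person")["points"])     # [{"x": ..., "y": ...}, ...]

All coordinates are normalized to [0, 1] relative to the image, as encoded and decoded by `region.py` below.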
open_vocab_detect.png
ADDED
point_count.png
ADDED
region.py
ADDED
|
@@ -0,0 +1,134 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
from typing import List, Tuple, Union
|
| 6 |
+
|
| 7 |
+
SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
| 11 |
+
"""
|
| 12 |
+
Applies Fourier feature mapping to input tensor x using frequency matrix w. This
|
| 13 |
+
projects inputs through sinusoidal functions to create higher dimensional features
|
| 14 |
+
that help mitigate spectral bias - the tendency of neural networks to learn
|
| 15 |
+
low-frequency functions more easily than high-frequency ones. By explicitly
|
| 16 |
+
mapping inputs to higher frequencies through sin/cos transformations, we enable
|
| 17 |
+
better learning of fine details and higher frequency patterns.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
x: Input tensor to transform
|
| 21 |
+
w: Matrix of frequencies for the Fourier features transformation
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Concatenated cosine and sine transformed features as a tensor
|
| 25 |
+
"""
|
| 26 |
+
f = 2 * math.pi * x @ w
|
| 27 |
+
return torch.cat([f.cos(), f.sin()], dim=-1)
|
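# Added note (illustrative, not part of the original file): for coordinates,
# w is coord_features with shape (1, coord_feat_dim // 2), so
#   x: (N, 1) @ w: (1, D/2) -> f: (N, D/2) -> cat(cos, sin): (N, D)
# e.g. with coord_feat_dim = 256 a batch of N scalar coordinates becomes an
# (N, 256) feature matrix before the coord_encoder projection. The value 256
# is only an example; the real dimension comes from config.region.coord_feat_dim.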
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 31 |
+
"""
|
| 32 |
+
Takes as input a tensor containing a single float coordinate value (x or y)
|
| 33 |
+
and encodes it into hidden states for input to the text model.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
coord: Tensor with single float coordinate value
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Encoded hidden states tensor for input to text model
|
| 40 |
+
"""
|
| 41 |
+
return w.coord_encoder(fourier_features(coord, w.coord_features))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 45 |
+
"""
|
| 46 |
+
Takes as input the last hidden state from the text model and outputs bin logits
|
| 47 |
+
representing either an x or y coordinate prediction.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
hidden_state: The final hidden state tensor from the text model.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Logits over coordinate bins representing the predicted coordinate value (x or y)
|
| 54 |
+
"""
|
| 55 |
+
return w.coord_decoder(hidden_state)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Takes a tensor containing width and height values and encodes them into
|
| 61 |
+
hidden states for input to the text model.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
size: Tensor with two floats for width and height
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Encoded hidden states tensor for input to text model
|
| 68 |
+
"""
|
| 69 |
+
return w.size_encoder(fourier_features(size, w.size_features))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 73 |
+
"""
|
| 74 |
+
Takes as input the last hidden state from the text model and outputs logits
|
| 75 |
+
for 1024 bins representing width and height in log-scale.
|
| 76 |
+
|
| 77 |
+
The bins are distributed according to the formula:
|
| 78 |
+
bin = (log2(size) + 10.0) / 10.0 * 1023.0
|
| 79 |
+
where size values are clamped to be at least 1/1024.
|
| 80 |
+
|
| 81 |
+
To convert from bin back to size:
|
| 82 |
+
size = 2^((bin / 1023.0) * 10.0 - 10.0)
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
hidden_state: The final hidden state tensor from the text model.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
A tensor containing logits for 1024 bins for width and height.
|
| 89 |
+
Shape is (2, 1024) where the first dimension corresponds to width and height.
|
| 90 |
+
"""
|
| 91 |
+
return w.size_decoder(hidden_state).view(2, -1)
|
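# Added worked example for the bin mapping documented above (not part of the
# original file), with sizes clamped to at least 1/1024:
#   size = 1/1024  ->  bin = (log2(1/1024) + 10) / 10 * 1023 = 0
#   size = 1.0     ->  bin = (log2(1.0)    + 10) / 10 * 1023 = 1023
#   bin  = 512     ->  size = 2 ** ((512 / 1023) * 10 - 10) ≈ 0.031
# so bin 0 covers roughly 0.1% of the image side, bin 1023 covers the whole
# side, and the bins in between are log-spaced.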
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
|
| 95 |
+
"""
|
| 96 |
+
Takes a list of spatial references (points or regions) and encodes them into
|
| 97 |
+
hidden states for input to the text model.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
spatial_refs: List of spatial references (points or boxes)
|
| 101 |
+
- Points are represented as normalized (x, y) tuples
|
| 102 |
+
- Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
{"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
|
| 106 |
+
"""
|
| 107 |
+
coords, sizes = [], []
|
| 108 |
+
for ref in spatial_refs:
|
| 109 |
+
if len(ref) == 2:
|
| 110 |
+
coords.append(ref[0])
|
| 111 |
+
coords.append(ref[1])
|
| 112 |
+
else:
|
| 113 |
+
x_c = (ref[0] + ref[2]) / 2
|
| 114 |
+
y_c = (ref[1] + ref[3]) / 2
|
| 115 |
+
width = ref[2] - ref[0]
|
| 116 |
+
height = ref[3] - ref[1]
|
| 117 |
+
coords.append(x_c)
|
| 118 |
+
coords.append(y_c)
|
| 119 |
+
sizes.append([width, height])
|
| 120 |
+
|
| 121 |
+
coords = torch.tensor(
|
| 122 |
+
coords, device=w.coord_features.device, dtype=w.coord_features.dtype
|
| 123 |
+
).view(-1, 1)
|
| 124 |
+
coords = encode_coordinate(coords, w)
|
| 125 |
+
|
| 126 |
+
if sizes:
|
| 127 |
+
sizes = torch.tensor(
|
| 128 |
+
sizes, device=w.size_features.device, dtype=w.size_features.dtype
|
| 129 |
+
)
|
| 130 |
+
sizes = encode_size(sizes, w)
|
| 131 |
+
else:
|
| 132 |
+
sizes = None
|
| 133 |
+
|
| 134 |
+
return {"coords": coords, "sizes": sizes}
|
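# Added commentary: the coords/sizes returned above line up one-to-one with the
# coord_id / size_id placeholder tokens that MoondreamModel.query() inserts into
# the prompt for each spatial reference. Illustrative values:
#   point (0.30, 0.40)           -> coords 0.30, 0.40                  -> 2 coord embeddings
#   box (0.10, 0.20, 0.50, 0.60) -> coords 0.30, 0.40 (center),
#                                    sizes [0.40, 0.40] (w, h)         -> 2 coord + 1 size embedding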
rope.py
ADDED
|
@@ -0,0 +1,47 @@
| 1 |
+
# Ethically sourced from https://github.com/xjdr-alt/entropix
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def precompute_freqs_cis(
|
| 7 |
+
dim: int,
|
| 8 |
+
end: int,
|
| 9 |
+
theta: float = 1500000.0,
|
| 10 |
+
dtype: torch.dtype = torch.float32,
|
| 11 |
+
) -> torch.Tensor:
|
| 12 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
|
| 13 |
+
t = torch.arange(end, dtype=dtype).unsqueeze(1)
|
| 14 |
+
freqs = t * freqs.unsqueeze(0)
|
| 15 |
+
freqs = torch.exp(1j * freqs)
|
| 16 |
+
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def apply_rotary_emb(
|
| 20 |
+
x: torch.Tensor,
|
| 21 |
+
freqs_cis: torch.Tensor,
|
| 22 |
+
position_ids: torch.Tensor,
|
| 23 |
+
num_heads: int,
|
| 24 |
+
rot_dim: int = 32,
|
| 25 |
+
interleave: bool = False,
|
| 26 |
+
) -> torch.Tensor:
|
| 27 |
+
assert rot_dim == freqs_cis.shape[-2] * 2
|
| 28 |
+
assert num_heads == x.shape[1]
|
| 29 |
+
|
| 30 |
+
x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
| 31 |
+
|
| 32 |
+
if interleave:
|
| 33 |
+
xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
|
| 34 |
+
xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
|
| 35 |
+
else:
|
| 36 |
+
d_q = x_rot.shape[-1] // 2
|
| 37 |
+
xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
|
| 38 |
+
|
| 39 |
+
freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
|
| 40 |
+
freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
|
| 41 |
+
|
| 42 |
+
# Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
| 43 |
+
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
|
| 44 |
+
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
|
| 45 |
+
xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
|
| 46 |
+
|
| 47 |
+
return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
|
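Only the first `rot_dim` channels (32 by default) of each head are rotated; the remaining channels pass through unchanged, and the base frequency is theta = 1,500,000. A minimal shape-check sketch, with illustrative sizes that are not taken from the model config:

    import torch
    from rope import precompute_freqs_cis, apply_rotary_emb  # the module above

    freqs_cis = precompute_freqs_cis(dim=32, end=4096)        # (4096, 16, 2) cos/sin table
    q = torch.randn(1, 16, 128, 64)                           # (batch, heads, seq, head_dim)
    pos = torch.arange(128)
    q_rot = apply_rotary_emb(q, freqs_cis, pos, num_heads=16) # same shape, first 32 dims rotated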
structured_outputs.png
ADDED
text.py
ADDED
|
@@ -0,0 +1,234 @@
+import torch
+import torch.nn as nn
+
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import flex_attention
+from typing import Optional
+
+from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
+from .rope import apply_rotary_emb, precompute_freqs_cis
+from .config import TextConfig
+
+
+def text_encoder(input_ids: torch.Tensor, w: nn.Module):
+    return F.embedding(input_ids, w.wte)
+
+
+def attn(
+    x: torch.Tensor,
+    w: nn.Module,
+    freqs_cis: torch.Tensor,
+    kv_cache: nn.Module,
+    attn_mask: torch.Tensor,
+    n_heads: int,
+    n_kv_heads: int,
+    position_ids: torch.Tensor,
+    lora: Optional[dict] = None,
+    flex_block_mask_slice=None,
+):
+    bsz, q_len, d_model = x.shape
+    head_dim = d_model // n_heads
+
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    if lora is not None:
+        qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+    q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
+
+    q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+    v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+
+    if hasattr(w, "tau") and w.tau is not None:
+        tok_feat = F.gelu(qkv_out)
+        tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1)
+        tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1)
+        pos = position_ids.to(q.dtype) + 1
+        tau_pos = 1 + (
+            torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5
+        )  # (H,S)
+        tau_q = (tok_q + tau_pos[None]).unsqueeze(-1)  # (B,H,S,1)
+        tau_v = (tok_v + tau_pos[None]).unsqueeze(-1)
+        q = q * tau_q
+        v = v * tau_v
+
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+
+    if kv_cache is not None:
+        k, v = kv_cache.update(position_ids, k, v)
+
+    if flex_block_mask_slice is not None:
+        torch._assert(n_heads == n_kv_heads, "gqa not supported yet")
+        out = flex_attention(q, k, v, block_mask=flex_block_mask_slice)
+    else:
+        out = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+        )
+
+    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+
+    out0 = w.proj(out)
+    if lora is not None:
+        out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
+        out = out0 + out1
+    else:
+        out = out0
+
+    return out
+
+
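As a quick sanity check on the fused qkv split above, the per-token feature budget works out as follows (the numbers are illustrative, not necessarily the shipped configuration):

d_model, n_heads, n_kv_heads = 2048, 32, 8                  # illustrative values only
head_dim = d_model // n_heads                               # 64
qkv_dim = int(d_model * (1 + 2 * n_kv_heads / n_heads))     # 3072, same formula as build_text_model below
q_dim, kv_dim = n_heads * head_dim, n_kv_heads * head_dim   # 2048 and 512
assert q_dim + 2 * kv_dim == qkv_dim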
+def text_decoder(
+    x: torch.Tensor,
+    w: nn.Module,
+    attn_mask: torch.Tensor,
+    position_ids: torch.Tensor,
+    config: TextConfig,
+    lora: Optional[dict] = None,
+    flex_block_mask_slice=None,
+):
+    for i, block in enumerate(w.blocks):
+        if lora is not None:
+            layer_lora = lora["text"]["blocks"][str(i)]
+            mlp_lora = layer_lora["mlp"]
+            attn_lora = layer_lora["attn"]
+        else:
+            mlp_lora = None
+            attn_lora = None
+
+        l_in = layer_norm(x, block.ln)
+        l_attn = attn(
+            l_in,
+            block.attn,
+            freqs_cis=w.freqs_cis,
+            kv_cache=block.kv_cache,
+            attn_mask=attn_mask,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+            position_ids=position_ids,
+            lora=attn_lora,
+            flex_block_mask_slice=flex_block_mask_slice,
+        )
+
+        if config.moe is not None and i >= config.moe.start_layer:
+            l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)
+        else:
+            l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
+
+        x = x + l_attn + l_mlp
+
+    return x
+
+
+def lm_head(
+    hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None
+):
+    hidden_BC = hidden_BTC[:, -1, :]
+    hidden_BC = layer_norm(hidden_BC, w.post_ln)
+    if indices is not None:
+        # Only compute logits for specified token indices
+        logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices]
+    else:
+        logits = w.lm_head(hidden_BC)
+    return logits
+
+
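The indices path in lm_head only materializes the output-projection rows for an allowed token set, which keeps constrained decoding cheap. A hypothetical call (the token ids and the text_module handle are made up for illustration):

allowed_ids = torch.tensor([345, 346, 347])                      # hypothetical candidate token ids
logits = lm_head(hidden_BTC, text_module, indices=allowed_ids)   # logits has shape (B, 3)
next_token = allowed_ids[logits.argmax(dim=-1)]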
+def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
+    return nn.ModuleDict(
+        {
+            "fc1": linear_cls(d_model, d_ffn, dtype=dtype),
+            "fc2": linear_cls(d_ffn, d_model, dtype=dtype),
+        }
+    )
+
+
+def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
+    # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
+    return nn.ModuleDict(
+        {
+            "router": nn.Linear(d_model, n_experts, dtype=dtype),
+            "fc1": nn.ParameterDict(
+                {
+                    "weight": nn.Parameter(
+                        torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype)
+                    )
+                }
+            ),
+            "fc2": nn.ParameterDict(
+                {
+                    "weight": nn.Parameter(
+                        torch.empty(n_experts, d_model, d_ffn, dtype=dtype)
+                    )
+                }
+            ),
+        }
+    )
+
+
+def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
+
+    text = nn.ModuleDict(
+        {
+            "blocks": nn.ModuleList(
+                [
+                    nn.ModuleDict(
+                        {
+                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
+                            "attn": nn.ModuleDict(
+                                {
+                                    "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
+                                    "proj": linear_cls(
+                                        config.dim, config.dim, dtype=dtype
+                                    ),
+                                    "tau": nn.ParameterDict(
+                                        {
+                                            "wq": nn.Parameter(
+                                                torch.empty(
+                                                    config.n_heads, qkv_dim, dtype=dtype
+                                                )
+                                            ),
+                                            "wv": nn.Parameter(
+                                                torch.empty(
+                                                    config.n_heads, qkv_dim, dtype=dtype
+                                                )
+                                            ),
+                                            "alpha": nn.Parameter(
+                                                torch.empty(config.n_heads, dtype=dtype)
+                                            ),
+                                        }
+                                    ),
+                                }
+                            ),
+                            "mlp": (
+                                build_moe_mlp(
+                                    config.dim,
+                                    config.moe.expert_inner_dim,
+                                    config.moe.num_experts,
+                                    dtype,
+                                )
+                                if config.moe is not None
+                                and layer_idx >= config.moe.start_layer
+                                else build_dense_mlp(
+                                    config.dim, config.ff_dim, dtype, linear_cls
+                                )
+                            ),
+                        }
+                    )
+                    for layer_idx in range(config.n_layers)
+                ]
+            ),
+            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
+            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
+        }
+    )
+    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
+    text.register_buffer(
+        "freqs_cis",
+        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
+        persistent=False,
+    )
+
+    return text
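The repository's lora.py builds the actual adapter weights and is not reproduced here. Purely as a shape sketch, a dict carrying the keys that attn and text_decoder above read could be assembled like this (the rank, dtype, and the fc1/fc2 key names on the MLP branch are assumptions, since that convention lives in lora.py and layers.mlp):

import torch

r, d_model, ff_dim, qkv_dim, n_layers = 16, 2048, 8192, 3072, 24   # illustrative sizes only

def ab(d_in, d_out, dtype=torch.bfloat16):
    # Matches F.linear(F.linear(x, A), B): A maps d_in -> r, B maps r -> d_out.
    return {"A": torch.zeros(r, d_in, dtype=dtype), "B": torch.zeros(d_out, r, dtype=dtype)}

lora = {
    "text": {
        "blocks": {
            str(i): {
                "attn": {"qkv": ab(d_model, qkv_dim), "proj": ab(d_model, d_model)},
                "mlp": {"fc1": ab(d_model, ff_dim), "fc2": ab(ff_dim, d_model)},  # key names assumed
            }
            for i in range(n_layers)
        }
    }
}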
utils.py
ADDED
@@ -0,0 +1,41 @@
+import numpy as np
+
+
+def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
+    """
+    Robust outlier detection for a list of (x, y) tuples.
+    Only requires numpy.
+
+    Args:
+        points_tuples: list of (x, y) tuples
+        k_nearest: number of neighbors to consider
+        threshold: multiplier for median distance
+
+    Returns:
+        list: filtered list of (x, y) tuples with outliers removed
+    """
+    points = np.array(points_tuples)
+    n_points = len(points)
+
+    # Calculate pairwise distances manually
+    dist_matrix = np.zeros((n_points, n_points))
+    for i in range(n_points):
+        for j in range(i + 1, n_points):
+            # Euclidean distance between points i and j
+            dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
+            dist_matrix[i, j] = dist
+            dist_matrix[j, i] = dist
+
+    # Get k nearest neighbors' distances
+    k = min(k_nearest, n_points - 1)
+    neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
+    avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
+
+    # Calculate mask using median distance
+    median_dist = np.median(avg_neighbor_dist)
+    mask = avg_neighbor_dist <= threshold * median_dist
+
+    # Return only the points that passed the mask
+    filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
+    return filtered_tuples
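A quick illustration of the helper with made-up coordinates; the lone far-away point exceeds twice the median neighbor distance and is dropped:

points = [(10.0, 10.2), (10.5, 9.8), (11.0, 10.0), (95.0, 3.0)]
kept = remove_outlier_points(points, k_nearest=2, threshold=2.0)
print(kept)  # [(10.0, 10.2), (10.5, 9.8), (11.0, 10.0)]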
vision.py
ADDED
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from typing import Union, Tuple
+from PIL import Image
+
+from .layers import attn, layer_norm, mlp
+from .image_crops import overlap_crop_image
+from .config import VisionConfig
+
+if torch.backends.mps.is_available():
+    # Non-divisible input sizes are not implemented on MPS device yet.
+    # https://github.com/pytorch/pytorch/issues/96056
+    def adaptive_avg_pool2d(input, output_size):
+        return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
+
+else:
+    adaptive_avg_pool2d = F.adaptive_avg_pool2d
+
+DeviceLike = Union[str, torch.device, int]
+
+
+def prepare_crops(
+    image: Image.Image, config: VisionConfig, device: DeviceLike
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    np_image = np.array(image.convert("RGB"))
+    overlap_crops = overlap_crop_image(
+        np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
+    )
+    all_crops = overlap_crops["crops"]
+    all_crops = np.transpose(all_crops, (0, 3, 1, 2))
+    all_crops = (
+        torch.from_numpy(all_crops)
+        .to(device=device, dtype=torch.bfloat16)
+        .div_(255.0)
+        .sub_(0.5)
+        .div_(0.5)
+    )
+    return all_crops, overlap_crops["tiling"]
+
+
+def create_patches(x, patch_size):
+    # Original shape: [B, C, H, W]
+    B, C, H, W = x.shape
+    P1 = P2 = patch_size
+
+    # Step 1: Split H and W dimensions into patches
+    # [B, C, H/P1, P1, W/P2, P2]
+    x = x.reshape(B, C, H // P1, P1, W // P2, P2)
+
+    # Step 2: Rearrange dimensions to match target shape
+    # [B, H/P1, W/P2, C, P1, P2]
+    x = x.permute(0, 2, 4, 1, 3, 5)
+
+    # Step 3: Combine dimensions to get final shape
+    # [B, (H/P1)*(W/P2), C*P1*P2]
+    x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
+
+    return x
+
+
+def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
+    x = create_patches(input_BCHW, config.enc_patch_size)
+
+    x = w.patch_emb(x)
+    x = x + w.pos_emb
+    for block in w.blocks:
+        x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
+        x = x + mlp(layer_norm(x, block.ln2), block.mlp)
+    x = layer_norm(x, w.post_ln)
+
+    return x
+
+
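For intuition on the patchification above, a worked shape example (the 378-pixel crop size and 14-pixel patch size are illustrative; the real values come from VisionConfig):

x = torch.randn(1, 3, 378, 378)
patches = create_patches(x, patch_size=14)
print(patches.shape)  # torch.Size([1, 729, 588]); 27 x 27 = 729 patches, 3 * 14 * 14 = 588 features each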
+def vision_projection(
+    global_features: torch.Tensor,
+    reconstructed: torch.Tensor,
+    w: nn.Module,
+    config: VisionConfig,
+):
+    reconstructed = reconstructed.permute(2, 0, 1)
+    reconstructed = adaptive_avg_pool2d(
+        reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
+    )
+    reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
+    final_features = torch.cat([global_features, reconstructed], dim=-1)
+    return mlp(final_features, w.proj_mlp)
+
+
+def build_vision_model(config: VisionConfig, dtype: torch.dtype):
+    patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
+    grid_size = config.crop_size // config.enc_patch_size
+    num_patches = grid_size * grid_size
+
+    vision = nn.ModuleDict(
+        {
+            "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
+            "blocks": nn.ModuleList(
+                [
+                    nn.ModuleDict(
+                        {
+                            "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
+                            "attn": nn.ModuleDict(
+                                {
+                                    "qkv": nn.Linear(
+                                        config.enc_dim, 3 * config.enc_dim, dtype=dtype
+                                    ),
+                                    "proj": nn.Linear(
+                                        config.enc_dim, config.enc_dim, dtype=dtype
+                                    ),
+                                }
+                            ),
+                            "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
+                            "mlp": nn.ModuleDict(
+                                {
+                                    "fc1": nn.Linear(
+                                        config.enc_dim, config.enc_ff_dim, dtype=dtype
+                                    ),
+                                    "fc2": nn.Linear(
+                                        config.enc_ff_dim, config.enc_dim, dtype=dtype
+                                    ),
+                                }
+                            ),
+                        }
+                    )
+                    for _ in range(config.enc_n_layers)
+                ]
+            ),
+            "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
+            "proj_mlp": nn.ModuleDict(
+                {
+                    "fc1": nn.Linear(
+                        config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
+                    ),
+                    "fc2": nn.Linear(
+                        config.proj_inner_dim, config.proj_out_dim, dtype=dtype
+                    ),
+                }
+            ),
+        }
+    )
+    vision.pos_emb = nn.Parameter(
+        torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
+    )
+    return vision
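Putting the pieces of this file together, a rough end-to-end sketch looks like the following. It assumes a VisionConfig() with usable defaults and that the corresponding safetensors weights have been loaded into the module, neither of which is shown here:

from PIL import Image
import torch

config = VisionConfig()                                # assumption: defaults match the checkpoint
vision = build_vision_model(config, torch.bfloat16)
# ... load the vision weights from the model-*.safetensors shards here ...

crops_BCHW, tiling = prepare_crops(Image.open("example.jpg"), config, device="cpu")
features = vision_encoder(crops_BCHW, vision, config)  # (num_crops, num_patches, enc_dim)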
visual_reasoning.png
ADDED
Git LFS Details