johnmalek312
committed on
Commit
·
083e56e
1
Parent(s):
1e6a1e7
torch rope broken
Browse files- .gitignore +3 -1
- .python-version +1 -0
- moondream2/config.py +1 -1
- moondream2/moondream.py +8 -3
- moondream2/rope.py +0 -1
- moondream2/text.py +7 -9
- notes.ipynb +14 -3
- requirements.txt +1 -0
.gitignore
CHANGED
|
@@ -13,4 +13,6 @@ __pycache__/
|
|
| 13 |
# venv
|
| 14 |
venv/
|
| 15 |
|
| 16 |
-
*.safetensors
|
|
|
|
|
|
|
|
|
| 13 |
# venv
|
| 14 |
venv/
|
| 15 |
|
| 16 |
+
*.safetensors
|
| 17 |
+
|
| 18 |
+
log/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
moondream2/config.py
CHANGED
|
@@ -8,7 +8,7 @@ class TextConfig:
|
|
| 8 |
ff_dim: int = 8192
|
| 9 |
n_layers: int = 24
|
| 10 |
vocab_size: int = 51200
|
| 11 |
-
max_context: int =
|
| 12 |
n_heads: int = 32
|
| 13 |
n_kv_heads: int = 32
|
| 14 |
prefix_attn: int = 730
|
|
|
|
| 8 |
ff_dim: int = 8192
|
| 9 |
n_layers: int = 24
|
| 10 |
vocab_size: int = 51200
|
| 11 |
+
max_context: int = 1000
|
| 12 |
n_heads: int = 32
|
| 13 |
n_kv_heads: int = 32
|
| 14 |
prefix_attn: int = 730
|
moondream2/moondream.py
CHANGED
|
@@ -14,7 +14,7 @@ from .text import build_text_model, text_encoder, lm_head, text_decoder
|
|
| 14 |
from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
|
| 15 |
from .utils import remove_outlier_points
|
| 16 |
import os
|
| 17 |
-
|
| 18 |
TextSamplingSettings = TypedDict(
|
| 19 |
"TextSamplingSettings",
|
| 20 |
{
|
|
@@ -71,6 +71,11 @@ class MoondreamModel(nn.Module):
|
|
| 71 |
self.vision = build_vision_model(config.vision, dtype)
|
| 72 |
self.text = build_text_model(config.text, dtype)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
# Region Model
|
| 75 |
self.region = nn.ModuleDict(
|
| 76 |
{
|
|
@@ -149,12 +154,12 @@ class MoondreamModel(nn.Module):
|
|
| 149 |
return vision_projection(g, r, self.vision, self.config.vision)
|
| 150 |
|
| 151 |
def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
|
| 152 |
-
return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
|
| 153 |
|
| 154 |
def _decode_one_tok(
|
| 155 |
self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
|
| 156 |
):
|
| 157 |
-
hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
|
| 158 |
logits = lm_head(hidden, self.text)
|
| 159 |
return logits, hidden
|
| 160 |
|
|
|
|
| 14 |
from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
|
| 15 |
from .utils import remove_outlier_points
|
| 16 |
import os
|
| 17 |
+
from torchtune.modules import RotaryPositionalEmbeddings
|
| 18 |
TextSamplingSettings = TypedDict(
|
| 19 |
"TextSamplingSettings",
|
| 20 |
{
|
|
|
|
| 71 |
self.vision = build_vision_model(config.vision, dtype)
|
| 72 |
self.text = build_text_model(config.text, dtype)
|
| 73 |
|
| 74 |
+
self.rotary_emb = RotaryPositionalEmbeddings(
|
| 75 |
+
config.text.dim // config.text.n_heads,
|
| 76 |
+
config.text.max_context
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
# Region Model
|
| 80 |
self.region = nn.ModuleDict(
|
| 81 |
{
|
|
|
|
| 154 |
return vision_projection(g, r, self.vision, self.config.vision)
|
| 155 |
|
| 156 |
def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
|
| 157 |
+
return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, self.rotary_emb)
|
| 158 |
|
| 159 |
def _decode_one_tok(
|
| 160 |
self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
|
| 161 |
):
|
| 162 |
+
hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, self.rotary_emb)
|
| 163 |
logits = lm_head(hidden, self.text)
|
| 164 |
return logits, hidden
|
| 165 |
|
moondream2/rope.py
CHANGED
|
@@ -62,7 +62,6 @@ def func11(x):
|
|
| 62 |
|
| 63 |
def apply_rotary_emb(
|
| 64 |
x: torch.Tensor,
|
| 65 |
-
freqs_cis: torch.Tensor,
|
| 66 |
position_ids: torch.Tensor,
|
| 67 |
num_heads: int,
|
| 68 |
rot_dim: int = 32,
|
|
|
|
| 62 |
|
| 63 |
def apply_rotary_emb(
|
| 64 |
x: torch.Tensor,
|
|
|
|
| 65 |
position_ids: torch.Tensor,
|
| 66 |
num_heads: int,
|
| 67 |
rot_dim: int = 32,
|
moondream2/text.py
CHANGED
|
@@ -15,11 +15,11 @@ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
|
|
| 15 |
def attn(
|
| 16 |
x: torch.Tensor,
|
| 17 |
w: nn.Module,
|
| 18 |
-
freqs_cis: torch.Tensor,
|
| 19 |
kv_cache: nn.Module,
|
| 20 |
attn_mask: torch.Tensor,
|
| 21 |
n_heads: int,
|
| 22 |
position_ids: torch.Tensor,
|
|
|
|
| 23 |
do_apply_rotary_emb: bool = True,
|
| 24 |
):
|
| 25 |
bsz, q_len, d_model = x.shape
|
|
@@ -37,8 +37,8 @@ def attn(
|
|
| 37 |
# 3. Unpack/Split along the first dimension (which now separates Q, K, V)
|
| 38 |
q, k, v = qkv_permuted[0], qkv_permuted[1], qkv_permuted[2]
|
| 39 |
|
| 40 |
-
q =
|
| 41 |
-
k =
|
| 42 |
|
| 43 |
if kv_cache is not None:
|
| 44 |
k, v = kv_cache.update(position_ids, k, v)
|
|
@@ -57,16 +57,18 @@ def text_decoder(
|
|
| 57 |
attn_mask: torch.Tensor,
|
| 58 |
position_ids: torch.Tensor,
|
| 59 |
config: TextConfig,
|
|
|
|
|
|
|
| 60 |
):
|
| 61 |
for i, block in enumerate(w.blocks):
|
| 62 |
l_in = layer_norm(x, block.ln)
|
| 63 |
l_attn = attn(
|
| 64 |
l_in,
|
| 65 |
block.attn,
|
| 66 |
-
freqs_cis=w.freqs_cis,
|
| 67 |
kv_cache=block.kv_cache,
|
| 68 |
attn_mask=attn_mask,
|
| 69 |
n_heads=config.n_heads,
|
|
|
|
| 70 |
position_ids=position_ids,
|
| 71 |
)
|
| 72 |
l_mlp = mlp(l_in, block.mlp)
|
|
@@ -120,10 +122,6 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
|
|
| 120 |
}
|
| 121 |
)
|
| 122 |
text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
|
| 123 |
-
|
| 124 |
-
"freqs_cis",
|
| 125 |
-
precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
|
| 126 |
-
persistent=False,
|
| 127 |
-
)
|
| 128 |
|
| 129 |
return text
|
|
|
|
| 15 |
def attn(
|
| 16 |
x: torch.Tensor,
|
| 17 |
w: nn.Module,
|
|
|
|
| 18 |
kv_cache: nn.Module,
|
| 19 |
attn_mask: torch.Tensor,
|
| 20 |
n_heads: int,
|
| 21 |
position_ids: torch.Tensor,
|
| 22 |
+
rotary_emb: nn.Module,
|
| 23 |
do_apply_rotary_emb: bool = True,
|
| 24 |
):
|
| 25 |
bsz, q_len, d_model = x.shape
|
|
|
|
| 37 |
# 3. Unpack/Split along the first dimension (which now separates Q, K, V)
|
| 38 |
q, k, v = qkv_permuted[0], qkv_permuted[1], qkv_permuted[2]
|
| 39 |
|
| 40 |
+
q = rotary_emb(q.permute(0, 2, 1, 3))
|
| 41 |
+
k = rotary_emb(k.permute(0, 2, 1, 3))
|
| 42 |
|
| 43 |
if kv_cache is not None:
|
| 44 |
k, v = kv_cache.update(position_ids, k, v)
|
|
|
|
| 57 |
attn_mask: torch.Tensor,
|
| 58 |
position_ids: torch.Tensor,
|
| 59 |
config: TextConfig,
|
| 60 |
+
rotary_emb: nn.Module
|
| 61 |
+
|
| 62 |
):
|
| 63 |
for i, block in enumerate(w.blocks):
|
| 64 |
l_in = layer_norm(x, block.ln)
|
| 65 |
l_attn = attn(
|
| 66 |
l_in,
|
| 67 |
block.attn,
|
|
|
|
| 68 |
kv_cache=block.kv_cache,
|
| 69 |
attn_mask=attn_mask,
|
| 70 |
n_heads=config.n_heads,
|
| 71 |
+
rotary_emb=rotary_emb,
|
| 72 |
position_ids=position_ids,
|
| 73 |
)
|
| 74 |
l_mlp = mlp(l_in, block.mlp)
|
|
|
|
| 122 |
}
|
| 123 |
)
|
| 124 |
text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
|
| 125 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
return text
|
notes.ipynb
CHANGED
|
@@ -54,7 +54,18 @@
|
|
| 54 |
"cell_type": "code",
|
| 55 |
"execution_count": 1,
|
| 56 |
"metadata": {},
|
| 57 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
"source": [
|
| 59 |
"import torch"
|
| 60 |
]
|
|
@@ -398,7 +409,7 @@
|
|
| 398 |
],
|
| 399 |
"metadata": {
|
| 400 |
"kernelspec": {
|
| 401 |
-
"display_name": "
|
| 402 |
"language": "python",
|
| 403 |
"name": "python3"
|
| 404 |
},
|
|
@@ -412,7 +423,7 @@
|
|
| 412 |
"name": "python",
|
| 413 |
"nbconvert_exporter": "python",
|
| 414 |
"pygments_lexer": "ipython3",
|
| 415 |
-
"version": "3.
|
| 416 |
}
|
| 417 |
},
|
| 418 |
"nbformat": 4,
|
|
|
|
| 54 |
"cell_type": "code",
|
| 55 |
"execution_count": 1,
|
| 56 |
"metadata": {},
|
| 57 |
+
"outputs": [
|
| 58 |
+
{
|
| 59 |
+
"ename": "",
|
| 60 |
+
"evalue": "",
|
| 61 |
+
"output_type": "error",
|
| 62 |
+
"traceback": [
|
| 63 |
+
"\u001b[1;31mRunning cells with 'venv12 (Python 3.12.10)' requires the ipykernel package.\n",
|
| 64 |
+
"\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
|
| 65 |
+
"\u001b[1;31mCommand: '/home/pixel/Desktop/moondream/venv12/bin/python3.12 -m pip install ipykernel -U --force-reinstall'"
|
| 66 |
+
]
|
| 67 |
+
}
|
| 68 |
+
],
|
| 69 |
"source": [
|
| 70 |
"import torch"
|
| 71 |
]
|
|
|
|
| 409 |
],
|
| 410 |
"metadata": {
|
| 411 |
"kernelspec": {
|
| 412 |
+
"display_name": "venv12",
|
| 413 |
"language": "python",
|
| 414 |
"name": "python3"
|
| 415 |
},
|
|
|
|
| 423 |
"name": "python",
|
| 424 |
"nbconvert_exporter": "python",
|
| 425 |
"pygments_lexer": "ipython3",
|
| 426 |
+
"version": "3.12.10"
|
| 427 |
}
|
| 428 |
},
|
| 429 |
"nbformat": 4,
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
--index-url https://download.pytorch.org/whl/cu128
torch==2.7.0+cu128
torchvision==0.22.0+cu128
torchaudio==2.7.0+cu128
|