johnmalek312 commited on
Commit
083e56e
·
1 Parent(s): 1e6a1e7

Replace custom torch RoPE (apply_rotary_emb / freqs_cis) with torchtune RotaryPositionalEmbeddings — previous torch RoPE implementation was broken

Browse files
.gitignore CHANGED
@@ -13,4 +13,6 @@ __pycache__/
13
  # venv
14
  venv/
15
 
16
- *.safetensors
 
 
 
13
  # venv
14
  venv/
15
 
16
+ *.safetensors
17
+
18
+ log/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
moondream2/config.py CHANGED
@@ -8,7 +8,7 @@ class TextConfig:
8
  ff_dim: int = 8192
9
  n_layers: int = 24
10
  vocab_size: int = 51200
11
- max_context: int = 2048
12
  n_heads: int = 32
13
  n_kv_heads: int = 32
14
  prefix_attn: int = 730
 
8
  ff_dim: int = 8192
9
  n_layers: int = 24
10
  vocab_size: int = 51200
11
+ max_context: int = 1000
12
  n_heads: int = 32
13
  n_kv_heads: int = 32
14
  prefix_attn: int = 730
moondream2/moondream.py CHANGED
@@ -14,7 +14,7 @@ from .text import build_text_model, text_encoder, lm_head, text_decoder
14
  from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
15
  from .utils import remove_outlier_points
16
  import os
17
-
18
  TextSamplingSettings = TypedDict(
19
  "TextSamplingSettings",
20
  {
@@ -71,6 +71,11 @@ class MoondreamModel(nn.Module):
71
  self.vision = build_vision_model(config.vision, dtype)
72
  self.text = build_text_model(config.text, dtype)
73
 
 
 
 
 
 
74
  # Region Model
75
  self.region = nn.ModuleDict(
76
  {
@@ -149,12 +154,12 @@ class MoondreamModel(nn.Module):
149
  return vision_projection(g, r, self.vision, self.config.vision)
150
 
151
  def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
152
- return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
153
 
154
  def _decode_one_tok(
155
  self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
156
  ):
157
- hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
158
  logits = lm_head(hidden, self.text)
159
  return logits, hidden
160
 
 
14
  from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
15
  from .utils import remove_outlier_points
16
  import os
17
+ from torchtune.modules import RotaryPositionalEmbeddings
18
  TextSamplingSettings = TypedDict(
19
  "TextSamplingSettings",
20
  {
 
71
  self.vision = build_vision_model(config.vision, dtype)
72
  self.text = build_text_model(config.text, dtype)
73
 
74
+ self.rotary_emb = RotaryPositionalEmbeddings(
75
+ config.text.dim // config.text.n_heads,
76
+ config.text.max_context
77
+ )
78
+
79
  # Region Model
80
  self.region = nn.ModuleDict(
81
  {
 
154
  return vision_projection(g, r, self.vision, self.config.vision)
155
 
156
  def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
157
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, self.rotary_emb)
158
 
159
  def _decode_one_tok(
160
  self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
161
  ):
162
+ hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, self.rotary_emb)
163
  logits = lm_head(hidden, self.text)
164
  return logits, hidden
165
 
moondream2/rope.py CHANGED
@@ -62,7 +62,6 @@ def func11(x):
62
 
63
  def apply_rotary_emb(
64
  x: torch.Tensor,
65
- freqs_cis: torch.Tensor,
66
  position_ids: torch.Tensor,
67
  num_heads: int,
68
  rot_dim: int = 32,
 
62
 
63
  def apply_rotary_emb(
64
  x: torch.Tensor,
 
65
  position_ids: torch.Tensor,
66
  num_heads: int,
67
  rot_dim: int = 32,
moondream2/text.py CHANGED
@@ -15,11 +15,11 @@ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
15
  def attn(
16
  x: torch.Tensor,
17
  w: nn.Module,
18
- freqs_cis: torch.Tensor,
19
  kv_cache: nn.Module,
20
  attn_mask: torch.Tensor,
21
  n_heads: int,
22
  position_ids: torch.Tensor,
 
23
  do_apply_rotary_emb: bool = True,
24
  ):
25
  bsz, q_len, d_model = x.shape
@@ -37,8 +37,8 @@ def attn(
37
  # 3. Unpack/Split along the first dimension (which now separates Q, K, V)
38
  q, k, v = qkv_permuted[0], qkv_permuted[1], qkv_permuted[2]
39
 
40
- q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
41
- k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads)
42
 
43
  if kv_cache is not None:
44
  k, v = kv_cache.update(position_ids, k, v)
@@ -57,16 +57,18 @@ def text_decoder(
57
  attn_mask: torch.Tensor,
58
  position_ids: torch.Tensor,
59
  config: TextConfig,
 
 
60
  ):
61
  for i, block in enumerate(w.blocks):
62
  l_in = layer_norm(x, block.ln)
63
  l_attn = attn(
64
  l_in,
65
  block.attn,
66
- freqs_cis=w.freqs_cis,
67
  kv_cache=block.kv_cache,
68
  attn_mask=attn_mask,
69
  n_heads=config.n_heads,
 
70
  position_ids=position_ids,
71
  )
72
  l_mlp = mlp(l_in, block.mlp)
@@ -120,10 +122,6 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
120
  }
121
  )
122
  text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
123
- text.register_buffer(
124
- "freqs_cis",
125
- precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
126
- persistent=False,
127
- )
128
 
129
  return text
 
15
  def attn(
16
  x: torch.Tensor,
17
  w: nn.Module,
 
18
  kv_cache: nn.Module,
19
  attn_mask: torch.Tensor,
20
  n_heads: int,
21
  position_ids: torch.Tensor,
22
+ rotary_emb: nn.Module,
23
  do_apply_rotary_emb: bool = True,
24
  ):
25
  bsz, q_len, d_model = x.shape
 
37
  # 3. Unpack/Split along the first dimension (which now separates Q, K, V)
38
  q, k, v = qkv_permuted[0], qkv_permuted[1], qkv_permuted[2]
39
 
40
+ q = rotary_emb(q.permute(0, 2, 1, 3))
41
+ k = rotary_emb(k.permute(0, 2, 1, 3))
42
 
43
  if kv_cache is not None:
44
  k, v = kv_cache.update(position_ids, k, v)
 
57
  attn_mask: torch.Tensor,
58
  position_ids: torch.Tensor,
59
  config: TextConfig,
60
+ rotary_emb: nn.Module
61
+
62
  ):
63
  for i, block in enumerate(w.blocks):
64
  l_in = layer_norm(x, block.ln)
65
  l_attn = attn(
66
  l_in,
67
  block.attn,
 
68
  kv_cache=block.kv_cache,
69
  attn_mask=attn_mask,
70
  n_heads=config.n_heads,
71
+ rotary_emb=rotary_emb,
72
  position_ids=position_ids,
73
  )
74
  l_mlp = mlp(l_in, block.mlp)
 
122
  }
123
  )
124
  text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
125
+
 
 
 
 
126
 
127
  return text
notes.ipynb CHANGED
@@ -54,7 +54,18 @@
54
  "cell_type": "code",
55
  "execution_count": 1,
56
  "metadata": {},
57
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
58
  "source": [
59
  "import torch"
60
  ]
@@ -398,7 +409,7 @@
398
  ],
399
  "metadata": {
400
  "kernelspec": {
401
- "display_name": "venv",
402
  "language": "python",
403
  "name": "python3"
404
  },
@@ -412,7 +423,7 @@
412
  "name": "python",
413
  "nbconvert_exporter": "python",
414
  "pygments_lexer": "ipython3",
415
- "version": "3.13.3"
416
  }
417
  },
418
  "nbformat": 4,
 
54
  "cell_type": "code",
55
  "execution_count": 1,
56
  "metadata": {},
57
+ "outputs": [
58
+ {
59
+ "ename": "",
60
+ "evalue": "",
61
+ "output_type": "error",
62
+ "traceback": [
63
+ "\u001b[1;31mRunning cells with 'venv12 (Python 3.12.10)' requires the ipykernel package.\n",
64
+ "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
65
+ "\u001b[1;31mCommand: '/home/pixel/Desktop/moondream/venv12/bin/python3.12 -m pip install ipykernel -U --force-reinstall'"
66
+ ]
67
+ }
68
+ ],
69
  "source": [
70
  "import torch"
71
  ]
 
409
  ],
410
  "metadata": {
411
  "kernelspec": {
412
+ "display_name": "venv12",
413
  "language": "python",
414
  "name": "python3"
415
  },
 
423
  "name": "python",
424
  "nbconvert_exporter": "python",
425
  "pygments_lexer": "ipython3",
426
+ "version": "3.12.10"
427
  }
428
  },
429
  "nbformat": 4,
requirements.txt ADDED
@@ -0,0 +4 @@
 
 
1
+ --index-url https://download.pytorch.org/whl/cu128
2
+ torch==2.7.0+cu128
3
+ torchvision==0.22.0+cu128
4
+ torchaudio==2.7.0+cu128