Update moondream.py
Browse files
fix: scaled_dot_product_attention bombs out when expanding a 3-D mask into (B, n_heads, q_len, kv_len).
- moondream.py +38 -27
moondream.py
CHANGED
|
@@ -26,6 +26,8 @@ from .region import decode_coordinate, encode_coordinate, decode_size, encode_si
|
|
| 26 |
from .text import text_encoder, lm_head
|
| 27 |
from typing import Optional, List, Union
|
| 28 |
from .lora import variant_state_dict
|
|
|
|
|
|
|
| 29 |
|
| 30 |
ImageEncodingSettings = TypedDict(
|
| 31 |
"ImageEncodingSettings",
|
|
@@ -911,24 +913,27 @@ class MoondreamModel(nn.Module):
|
|
| 911 |
|
| 912 |
def _generate_points_batched(
|
| 913 |
self,
|
| 914 |
-
hidden, # (B,1,C)
|
| 915 |
-
next_token, # (B,1)
|
| 916 |
-
pos: int, # shared scalar next position
|
| 917 |
include_size: bool = True,
|
| 918 |
max_objects: int = 50,
|
| 919 |
lora=None,
|
| 920 |
):
|
| 921 |
"""
|
| 922 |
Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
|
| 923 |
-
for all rows in the batch simultaneously. Returns list-of-lists of dicts
|
|
|
|
| 924 |
"""
|
|
|
|
|
|
|
| 925 |
B = hidden.size(0)
|
| 926 |
device = self.device
|
| 927 |
out = [[] for _ in range(B)]
|
| 928 |
eos_id = self.config.tokenizer.eos_id
|
| 929 |
max_ctx = self.config.text.max_context
|
| 930 |
|
| 931 |
-
# 4-D mask: (B, 1, q_len=1, kv_len)
|
| 932 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 933 |
if pos > 0:
|
| 934 |
mask[:, :, :, :pos] = True
|
|
@@ -939,29 +944,29 @@ class MoondreamModel(nn.Module):
|
|
| 939 |
|
| 940 |
with torch.inference_mode():
|
| 941 |
while alive.any() and (counts < max_objects).any():
|
| 942 |
-
# --- x coordinate ---
|
| 943 |
-
x_logits = decode_coordinate(hidden, self.region)
|
| 944 |
if x_logits.dim() == 3:
|
| 945 |
-
x_logits = x_logits.squeeze(1)
|
| 946 |
-
x_bin
|
| 947 |
-
x_center = x_bin / float(x_logits.size(-1))
|
| 948 |
-
x_in
|
| 949 |
-
x_emb
|
| 950 |
|
| 951 |
-
# advance
|
| 952 |
mask[:, :, :, pos] = True
|
| 953 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
|
| 954 |
pos += 1
|
| 955 |
pos_id[0] = pos
|
| 956 |
|
| 957 |
-
# --- y coordinate ---
|
| 958 |
-
y_logits = decode_coordinate(hidden, self.region)
|
| 959 |
if y_logits.dim() == 3:
|
| 960 |
y_logits = y_logits.squeeze(1)
|
| 961 |
-
y_bin
|
| 962 |
-
y_center = y_bin / float(y_logits.size(-1))
|
| 963 |
-
y_in
|
| 964 |
-
y_emb
|
| 965 |
|
| 966 |
mask[:, :, :, pos] = True
|
| 967 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
|
|
@@ -969,18 +974,23 @@ class MoondreamModel(nn.Module):
|
|
| 969 |
pos_id[0] = pos
|
| 970 |
|
| 971 |
if include_size:
|
| 972 |
-
#
|
| 973 |
-
|
| 974 |
-
|
|
|
|
|
|
|
|
|
|
| 975 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 976 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 977 |
-
|
|
|
|
| 978 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 979 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 980 |
-
|
|
|
|
| 981 |
size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
|
| 982 |
|
| 983 |
-
#
|
| 984 |
for i in range(B):
|
| 985 |
if alive[i]:
|
| 986 |
out[i].append({
|
|
@@ -990,21 +1000,22 @@ class MoondreamModel(nn.Module):
|
|
| 990 |
"y_max": (y_center[i] + h[i] / 2).item(),
|
| 991 |
})
|
| 992 |
|
|
|
|
| 993 |
mask[:, :, :, pos] = True
|
| 994 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
|
| 995 |
pos += 1
|
| 996 |
pos_id[0] = pos
|
| 997 |
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 998 |
else:
|
|
|
|
| 999 |
for i in range(B):
|
| 1000 |
if alive[i]:
|
| 1001 |
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
| 1002 |
-
|
| 1003 |
mask[:, :, :, pos] = True
|
| 1004 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
|
| 1005 |
pos += 1
|
| 1006 |
pos_id[0] = pos
|
| 1007 |
-
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1008 |
|
| 1009 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1010 |
counts = counts + (~finished_now & alive).to(counts.dtype)
|
|
|
|
| 26 |
from .text import text_encoder, lm_head
|
| 27 |
from typing import Optional, List, Union
|
| 28 |
from .lora import variant_state_dict
|
| 29 |
+
from .layers import mlp
|
| 30 |
+
|
| 31 |
|
| 32 |
ImageEncodingSettings = TypedDict(
|
| 33 |
"ImageEncodingSettings",
|
|
|
|
| 913 |
|
| 914 |
def _generate_points_batched(
|
| 915 |
self,
|
| 916 |
+
hidden, # (B,1,C) last hidden after prefill (per label row)
|
| 917 |
+
next_token, # (B,1) (kept for parity; not used when temperature=0)
|
| 918 |
+
pos: int, # shared scalar next position for all rows
|
| 919 |
include_size: bool = True,
|
| 920 |
max_objects: int = 50,
|
| 921 |
lora=None,
|
| 922 |
):
|
| 923 |
"""
|
| 924 |
Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
|
| 925 |
+
for all rows in the batch simultaneously. Returns list-of-lists of dicts (len B).
|
| 926 |
+
Batch-safe: uses 4-D masks and avoids region.decode_size() (which flattens batch).
|
| 927 |
"""
|
| 928 |
+
import torch
|
| 929 |
+
|
| 930 |
B = hidden.size(0)
|
| 931 |
device = self.device
|
| 932 |
out = [[] for _ in range(B)]
|
| 933 |
eos_id = self.config.tokenizer.eos_id
|
| 934 |
max_ctx = self.config.text.max_context
|
| 935 |
|
| 936 |
+
# 4-D mask: (B, 1, q_len=1, kv_len), True means "visible" to match model's convention
|
| 937 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 938 |
if pos > 0:
|
| 939 |
mask[:, :, :, :pos] = True
|
|
|
|
| 944 |
|
| 945 |
with torch.inference_mode():
|
| 946 |
while alive.any() and (counts < max_objects).any():
|
| 947 |
+
# --- x coordinate (batched) ---
|
| 948 |
+
x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
|
| 949 |
if x_logits.dim() == 3:
|
| 950 |
+
x_logits = x_logits.squeeze(1) # (B,1024)
|
| 951 |
+
x_bin = x_logits.argmax(dim=-1).to(torch.float32)
|
| 952 |
+
x_center = x_bin / float(x_logits.size(-1)) # (B,)
|
| 953 |
+
x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
|
| 954 |
+
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
|
| 955 |
|
| 956 |
+
# advance one token
|
| 957 |
mask[:, :, :, pos] = True
|
| 958 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
|
| 959 |
pos += 1
|
| 960 |
pos_id[0] = pos
|
| 961 |
|
| 962 |
+
# --- y coordinate (batched) ---
|
| 963 |
+
y_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
|
| 964 |
if y_logits.dim() == 3:
|
| 965 |
y_logits = y_logits.squeeze(1)
|
| 966 |
+
y_bin = y_logits.argmax(dim=-1).to(torch.float32)
|
| 967 |
+
y_center = y_bin / float(y_logits.size(-1))
|
| 968 |
+
y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
|
| 969 |
+
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 970 |
|
| 971 |
mask[:, :, :, pos] = True
|
| 972 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
|
|
|
|
| 974 |
pos_id[0] = pos
|
| 975 |
|
| 976 |
if include_size:
|
| 977 |
+
# ---- size (batched, *without* region.decode_size which flattens batch) ----
|
| 978 |
+
# size_out_dim is 2*1024 (W then H). mlp() preserves (B,1,·).
|
| 979 |
+
size_logits = mlp(hidden, self.region["size_decoder"]).squeeze(1) # (B, 2048)
|
| 980 |
+
half = size_logits.size(-1) // 2
|
| 981 |
+
w_logits, h_logits = size_logits[:, :half], size_logits[:, half:] # (B,1024),(B,1024)
|
| 982 |
+
|
| 983 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 984 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 985 |
+
|
| 986 |
+
# inverse log-scale mapping used by the repo
|
| 987 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 988 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 989 |
+
|
| 990 |
+
size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
|
| 991 |
size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
|
| 992 |
|
| 993 |
+
# commit boxes
|
| 994 |
for i in range(B):
|
| 995 |
if alive[i]:
|
| 996 |
out[i].append({
|
|
|
|
| 1000 |
"y_max": (y_center[i] + h[i] / 2).item(),
|
| 1001 |
})
|
| 1002 |
|
| 1003 |
+
# decide continuation
|
| 1004 |
mask[:, :, :, pos] = True
|
| 1005 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
|
| 1006 |
pos += 1
|
| 1007 |
pos_id[0] = pos
|
| 1008 |
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 1009 |
else:
|
| 1010 |
+
# points mode
|
| 1011 |
for i in range(B):
|
| 1012 |
if alive[i]:
|
| 1013 |
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
|
|
|
| 1014 |
mask[:, :, :, pos] = True
|
| 1015 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
|
| 1016 |
pos += 1
|
| 1017 |
pos_id[0] = pos
|
| 1018 |
+
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1019 |
|
| 1020 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1021 |
counts = counts + (~finished_now & alive).to(counts.dtype)
|