Update moondream.py
fix: tensor mismatch between KV cache and batched generation
- moondream.py +79 -61
moondream.py
CHANGED
|
@@ -77,50 +77,54 @@ class KVCache(nn.Module):
|
|
| 77 |
"v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 78 |
)
|
| 79 |
|
| 80 |
-
# --- replace the whole method in KVCache ---
|
| 81 |
def update(self, pos_ids, k, v):
|
| 82 |
"""
|
| 83 |
-
Supports
|
| 84 |
• Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
|
| 85 |
-
• 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,)
|
| 86 |
-
|
|
|
|
| 87 |
"""
|
| 88 |
kout, vout = self.k_cache, self.v_cache
|
| 89 |
|
| 90 |
if not torch.is_tensor(pos_ids):
|
| 91 |
-
# Scalar
|
| 92 |
kout[:, :, pos_ids, :] = k
|
| 93 |
vout[:, :, pos_ids, :] = v
|
| 94 |
return kout, vout
|
| 95 |
|
| 96 |
-
# Normalize dtype
|
| 97 |
pos_ids = pos_ids.to(dtype=torch.long, device=k.device)
|
| 98 |
|
| 99 |
-
# Shapes
|
| 100 |
if k.dim() != 4 or v.dim() != 4:
|
| 101 |
raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
|
| 102 |
B, Hkv, q_len, D = k.shape
|
| 103 |
|
| 104 |
-
#
|
| 105 |
if kout.size(0) != B:
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
# Case A: PREFILL —
|
| 109 |
if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
|
| 110 |
for i in range(B):
|
| 111 |
-
kout[i, :, pos_ids, :] = k[i]
|
| 112 |
vout[i, :, pos_ids, :] = v[i]
|
| 113 |
return kout, vout
|
| 114 |
|
| 115 |
-
# Case B: STEP
|
| 116 |
-
if
|
|
|
|
| 117 |
for i in range(B):
|
| 118 |
-
pi = int(
|
| 119 |
kout[i, :, pi, :] = k[i, :, 0, :]
|
| 120 |
vout[i, :, pi, :] = v[i, :, 0, :]
|
| 121 |
return kout, vout
|
| 122 |
|
| 123 |
-
#
|
| 124 |
if pos_ids.dim() == 0 and q_len == 1:
|
| 125 |
pi = int(pos_ids.item())
|
| 126 |
kout[:, :, pi, :] = k[:, :, 0, :]
|
|
@@ -135,6 +139,7 @@ class KVCache(nn.Module):
|
|
| 135 |
|
| 136 |
|
| 137 |
|
|
|
|
| 138 |
class MoondreamModel(nn.Module):
|
| 139 |
|
| 140 |
def __init__(
|
|
@@ -968,93 +973,105 @@ class MoondreamModel(nn.Module):
|
|
| 968 |
|
| 969 |
return last_hidden, next_token, pos_vec
|
| 970 |
|
| 971 |
-
|
| 972 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 973 |
B = hidden.size(0)
|
| 974 |
device = self.device
|
| 975 |
out = [[] for _ in range(B)]
|
| 976 |
eos_id = self.config.tokenizer.eos_id
|
| 977 |
max_ctx = self.config.text.max_context
|
| 978 |
|
| 979 |
-
# 4-D mask: (B,1,1,kv_len)
|
| 980 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
|
|
|
| 984 |
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 988 |
|
| 989 |
with torch.inference_mode():
|
| 990 |
while alive.any() and (counts < max_objects).any():
|
| 991 |
-
# --- x ---
|
| 992 |
x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
|
| 993 |
-
if x_logits.dim() == 3:
|
| 994 |
-
|
| 995 |
-
x_center =
|
| 996 |
x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
|
| 997 |
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
|
| 998 |
|
| 999 |
-
# advance one
|
| 1000 |
-
|
| 1001 |
-
if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
|
| 1002 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1003 |
-
|
|
|
|
| 1004 |
|
| 1005 |
-
# --- y ---
|
| 1006 |
y_logits = decode_coordinate(hidden, self.region)
|
| 1007 |
-
if y_logits.dim() == 3:
|
| 1008 |
-
|
| 1009 |
-
y_center =
|
| 1010 |
-
y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
|
| 1011 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1012 |
|
| 1013 |
-
|
| 1014 |
-
if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
|
| 1015 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1016 |
-
|
|
|
|
| 1017 |
|
| 1018 |
if include_size:
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1025 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 1026 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 1027 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 1028 |
-
size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)
|
| 1029 |
size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
|
| 1030 |
|
|
|
|
| 1031 |
for i in range(B):
|
| 1032 |
-
if
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
|
| 1040 |
-
|
| 1041 |
-
if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
|
| 1042 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1043 |
-
|
|
|
|
| 1044 |
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 1045 |
else:
|
| 1046 |
for i in range(B):
|
| 1047 |
if alive[i]:
|
| 1048 |
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
| 1049 |
-
|
| 1050 |
-
for i in range(B):
|
| 1051 |
-
if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
|
| 1052 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1053 |
-
|
|
|
|
| 1054 |
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1055 |
|
| 1056 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1057 |
-
counts = counts + (
|
| 1058 |
alive &= ~finished_now
|
| 1059 |
|
| 1060 |
return out
|
|
@@ -1062,6 +1079,7 @@ class MoondreamModel(nn.Module):
|
|
| 1062 |
|
| 1063 |
|
| 1064 |
|
|
|
|
| 1065 |
def detect_multi(self, image, objects, settings=None):
|
| 1066 |
"""
|
| 1067 |
Parallel multi-label detection.
|
|
|
|
| 77 |
"v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 78 |
)
|
| 79 |
|
|
|
|
| 80 |
def update(self, pos_ids, k, v):
    """Write new key/value states into the cache and return the full caches.

    Supported call shapes:
      * Prefill: k, v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
      * 1-step:  k, v = (B, n_kv_heads, 1, d),     pos_ids = (B,) or (B, 1)
      * Legacy:  pos_ids is a plain Python int (scalar index for all rows)

    Writes into self.k_cache / self.v_cache, each shaped
    (B, n_kv_heads, T_max, d), and returns (k_cache, v_cache).

    Raises:
        RuntimeError: if k/v are not 4-D, or the cache batch size neither
            matches B nor is 1 (expandable).
    """
    kout, vout = self.k_cache, self.v_cache

    if not torch.is_tensor(pos_ids):
        # Scalar legacy path: same position for every batch row.
        kout[:, :, pos_ids, :] = k
        vout[:, :, pos_ids, :] = v
        return kout, vout

    # Normalize index dtype/device so advanced indexing below is valid.
    pos_ids = pos_ids.to(dtype=torch.long, device=k.device)

    if k.dim() != 4 or v.dim() != 4:
        raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
    B, Hkv, q_len, D = k.shape

    # Make sure cache batch matches B (expand-from-1 is ok, otherwise error).
    # .clone() materializes the expansion so per-row writes don't alias.
    if kout.size(0) != B:
        if kout.size(0) == 1:
            self.k_cache = kout.expand(B, -1, -1, -1).clone()
            self.v_cache = vout.expand(B, -1, -1, -1).clone()
            kout, vout = self.k_cache, self.v_cache
        else:
            raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")

    # Case A: PREFILL — pos_ids indexes one position per query token,
    # shared across rows.
    if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
        for i in range(B):
            kout[i, :, pos_ids, :] = k[i]  # (Hkv, q_len, D)
            vout[i, :, pos_ids, :] = v[i]
        return kout, vout

    # Case B: STEP — q_len == 1 and one (possibly different) position per row.
    if q_len == 1 and pos_ids.numel() == B:
        pos_ids_flat = pos_ids.view(-1)  # handle (B,1) or (B,)
        for i in range(B):
            pi = int(pos_ids_flat[i].item())
            kout[i, :, pi, :] = k[i, :, 0, :]
            vout[i, :, pi, :] = v[i, :, 0, :]
        return kout, vout

    # Case C: 0-dim tensor — one scalar position for every row.
    if pos_ids.dim() == 0 and q_len == 1:
        pi = int(pos_ids.item())
        kout[:, :, pi, :] = k[:, :, 0, :]
        # NOTE(review): the pasted diff is truncated here; the matching v-write
        # and return are restored by symmetry with the k-write above — confirm
        # against the upstream file.
        vout[:, :, pi, :] = v[:, :, 0, :]
        return kout, vout
|
|
|
|
| 139 |
|
| 140 |
|
| 141 |
|
| 142 |
+
|
| 143 |
class MoondreamModel(nn.Module):
|
| 144 |
|
| 145 |
def __init__(
|
|
|
|
| 973 |
|
| 974 |
return last_hidden, next_token, pos_vec
|
| 975 |
|
| 976 |
+
def _generate_points_batched(
    self,
    hidden,                     # (B, 1, C) last hidden state per row
    next_token,                 # (B, 1); kept for interface parity, unused here
    pos: int,                   # shared scalar next KV position
    include_size: bool = True,  # True -> emit boxes; False -> emit center points
    max_objects: int = 50,      # per-row cap on decoded objects
    lora=None,
):
    """Vectorized object decoding: x -> y -> (size) -> next-token for all rows.

    Greedy (argmax) decoding for every batch row simultaneously, advancing a
    shared scalar position `pos` through the KV cache each step.

    Returns:
        list of length B; element i is a list of dicts — boxes
        {"x_min","y_min","x_max","y_max"} when include_size, otherwise
        points {"x","y"}.
    """
    B = hidden.size(0)
    device = self.device
    out = [[] for _ in range(B)]
    eos_id = self.config.tokenizer.eos_id
    max_ctx = self.config.text.max_context

    # 4-D attention mask: (B, 1, q_len=1, kv_len); all prior positions visible.
    mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
    if pos > 0:
        mask[:, :, :, :pos] = True
    # position_ids must be (B, 1) for rotary; KVCache.update accepts (B,1) too.
    pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)

    alive = torch.ones(B, dtype=torch.bool, device=device)
    counts = torch.zeros(B, dtype=torch.int32, device=device)

    with torch.inference_mode():
        while alive.any() and (counts < max_objects).any():
            # --- x coordinate ---
            x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
            if x_logits.dim() == 3:
                x_logits = x_logits.squeeze(1)
            # Bin index -> normalized [0,1) coordinate.
            x_center = x_logits.argmax(dim=-1).to(torch.float32) / float(x_logits.size(-1))  # (B,)
            x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)        # (B,1)
            x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)     # (B,1,C)

            # Advance attention one step (shared scalar position).
            mask[:, :, :, pos] = True
            logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
            pos += 1
            pos_ids[:, 0] = pos

            # --- y coordinate ---
            y_logits = decode_coordinate(hidden, self.region)
            if y_logits.dim() == 3:
                y_logits = y_logits.squeeze(1)
            y_center = y_logits.argmax(dim=-1).to(torch.float32) / float(y_logits.size(-1))
            y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)        # (B,1)
            y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)

            mask[:, :, :, pos] = True
            logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
            pos += 1
            pos_ids[:, 0] = pos

            if include_size:
                # --- size ---
                size_logits = decode_size(hidden, self.region)  # [w_logits, h_logits]
                # Support both (B,1,1024) and (B,1024).
                w_logits = size_logits[0].squeeze(1)
                h_logits = size_logits[1].squeeze(1)
                w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                h_bin = h_logits.argmax(dim=-1).to(torch.float32)
                # Log-scale bins: 1023 bins span 2**-10 .. 2**0.
                w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
                size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
                size_emb = encode_size(size_in, self.region).unsqueeze(1)      # (B,1,C)

                # Record boxes for rows still decoding.
                for i in range(B):
                    if alive[i]:
                        out[i].append({
                            "x_min": (x_center[i] - w[i] / 2).item(),
                            "y_min": (y_center[i] - h[i] / 2).item(),
                            "x_max": (x_center[i] + w[i] / 2).item(),
                            "y_max": (y_center[i] + h[i] / 2).item(),
                        })

                mask[:, :, :, pos] = True
                logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
                pos += 1
                pos_ids[:, 0] = pos
                next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
            else:
                for i in range(B):
                    if alive[i]:
                        out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
                # NOTE(review): the y embedding is fed a second time here to
                # produce the continue/EOS token — matches the pasted diff;
                # confirm this is intended rather than a dedicated embedding.
                mask[:, :, :, pos] = True
                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
                pos += 1
                pos_ids[:, 0] = pos
                next_tok = logits.argmax(dim=-1).squeeze(-1)

            # A row finishes on EOS or when it has hit the object cap.
            finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
            counts = counts + (~finished_now & alive).to(counts.dtype)
            alive &= ~finished_now

    return out
|
|
|
|
| 1079 |
|
| 1080 |
|
| 1081 |
|
| 1082 |
+
|
| 1083 |
def detect_multi(self, image, objects, settings=None):
|
| 1084 |
"""
|
| 1085 |
Parallel multi-label detection.
|