Update moondream.py
feat: make prefill granular.
- moondream.py +28 -22
moondream.py
CHANGED
|
@@ -943,13 +943,19 @@ class MoondreamModel(nn.Module):
|
|
| 943 |
b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
|
| 944 |
|
| 945 |
|
| 946 |
-
def _prefill_prompt_batched(
|
| 947 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 948 |
tpl = self.config.tokenizer.templates["detect"]
|
| 949 |
if tpl is None:
|
| 950 |
raise NotImplementedError("Model does not support object detection.")
|
| 951 |
|
| 952 |
-
# Build
|
| 953 |
rows_ids, lens = [], []
|
| 954 |
for lab in labels:
|
| 955 |
ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
|
|
@@ -957,44 +963,44 @@ class MoondreamModel(nn.Module):
|
|
| 957 |
rows_ids.append(t)
|
| 958 |
lens.append(t.numel())
|
| 959 |
|
| 960 |
-
B
|
|
|
|
| 961 |
|
| 962 |
-
# Embed
|
| 963 |
-
|
| 964 |
-
embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids] # list of (Li, C)
|
| 965 |
padded = []
|
| 966 |
for e, L in zip(embs, lens):
|
| 967 |
pad = T - L
|
| 968 |
if pad > 0:
|
| 969 |
-
e = torch.cat([e[:1].repeat(pad, 1), e], dim=0)
|
| 970 |
padded.append(e)
|
| 971 |
-
prompt_emb = torch.stack(padded, dim=0)
|
| 972 |
torch._dynamo.mark_dynamic(prompt_emb, 1)
|
| 973 |
|
| 974 |
-
#
|
| 975 |
-
base = self.attn_mask[:, :, pos:pos+T, :]
|
| 976 |
-
mask = base.expand(B, -1, -1, -1).contiguous()
|
| 977 |
pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
|
|
|
|
|
|
|
| 978 |
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
|
| 984 |
-
last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
|
| 985 |
-
last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
|
| 986 |
|
| 987 |
if temperature == 0.0:
|
| 988 |
-
next_token = last_logits.argmax(dim=-1, keepdim=True)
|
| 989 |
else:
|
| 990 |
probs = torch.softmax(last_logits / temperature, dim=-1)
|
| 991 |
probs = self._apply_top_p(probs, top_p)
|
| 992 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 993 |
|
| 994 |
-
pos_end = int(pos + T)
|
| 995 |
return last_hidden, next_token, pos_end
|
| 996 |
|
| 997 |
|
|
|
|
| 998 |
def _generate_points_batched(
|
| 999 |
self,
|
| 1000 |
hidden, # (B,1,C) - last token hidden state per row
|
|
|
|
| 943 |
b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
|
| 944 |
|
| 945 |
|
| 946 |
+
def _prefill_prompt_batched(
    self,
    labels: List[str],
    pos: int,
    lora=None,
    temperature: float = 0.0,
    top_p: float = 0.0,
):
    """Batched prompt prefill for object detection.

    Builds one "detect" prompt per label, LEFT-pads all rows to a common
    length ``T``, runs a single batched prefill over cache positions
    ``[pos, pos + T)``, and samples the first generated token per row.

    Args:
        labels: Object names to detect; one prompt row is built per label.
        pos: KV-cache position at which the prompt starts (end of the
            shared image prefix already written to the cache).
        lora: Optional adapter weights forwarded to the prefill pass.
        temperature: ``0.0`` selects greedy (argmax) decoding; any other
            value enables softmax sampling.
        top_p: Nucleus-sampling threshold, applied only when sampling.

    Returns:
        Tuple ``(last_hidden, next_token, pos_end)`` where ``last_hidden``
        is ``(B, 1, C)``, ``next_token`` is ``(B, 1)``, and ``pos_end ==
        pos + T`` is the cache position after the prompt.

    Raises:
        NotImplementedError: If the tokenizer has no "detect" template.
    """
    tpl = self.config.tokenizer.templates["detect"]
    if tpl is None:
        raise NotImplementedError("Model does not support object detection.")

    # 1) Build token ids for each label (variable length per row).
    rows_ids, lens = [], []
    for lab in labels:
        ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
        # NOTE(review): this line was elided by the diff hunk; reconstructed
        # from the `t.unsqueeze(0)` / `t.numel()` usage below — confirm
        # against the actual file.
        t = torch.tensor(ids, device=self.device, dtype=torch.long)
        rows_ids.append(t)
        lens.append(t.numel())

    B = len(rows_ids)
    T = max(lens)

    # 2) Embed, then LEFT-pad each row to length T by repeating the row's
    # first token embedding. Left-padding keeps every row's last real token
    # at index T - 1, which makes the final-token gather below uniform.
    embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # list[(Li, C)]
    padded = []
    for e, L in zip(embs, lens):
        pad = T - L
        if pad > 0:
            e = torch.cat([e[:1].repeat(pad, 1), e], dim=0)  # (T, C)
        padded.append(e)
    prompt_emb = torch.stack(padded, dim=0)  # (B, T, C)
    # T varies call-to-call; tell the compiler not to specialize on it.
    torch._dynamo.mark_dynamic(prompt_emb, 1)

    # 3) Batched prefill over the cache span [pos, pos + T).
    base = self.attn_mask[:, :, pos:pos + T, :]  # (1, 1, T, K)
    mask = base.expand(B, -1, -1, -1).contiguous()  # (B, 1, T, K)
    pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
    hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
    logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)

    # After left-padding, the last real token sits at T - 1 in EVERY row,
    # so a plain slice replaces the per-row index gather. (Indexing at
    # lens[i] - 1 would only be correct for RIGHT-padding — that was the
    # bug this change fixes.)
    last_hidden = hidden_BTC[:, -1:, :]  # (B, 1, C)
    last_logits = logits_BTV[:, -1]  # (B, V)

    if temperature == 0.0:
        next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
    else:
        probs = torch.softmax(last_logits / temperature, dim=-1)
        probs = self._apply_top_p(probs, top_p)
        next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

    pos_end = int(pos + T)
    return last_hidden, next_token, pos_end
|
| 1001 |
|
| 1002 |
|
| 1003 |
+
|
| 1004 |
def _generate_points_batched(
|
| 1005 |
self,
|
| 1006 |
hidden, # (B,1,C) - last token hidden state per row
|