Update moondream.py
Browse files
fix: batched generation
- moondream.py +55 -32
moondream.py
CHANGED
|
@@ -976,7 +976,7 @@ class MoondreamModel(nn.Module):
|
|
| 976 |
def _generate_points_batched(
|
| 977 |
self,
|
| 978 |
hidden, # (B,1,C)
|
| 979 |
-
next_token, # (B,1)
|
| 980 |
pos, # int or Tensor; normalized below
|
| 981 |
include_size: bool = True,
|
| 982 |
max_objects: int = 50,
|
|
@@ -989,10 +989,11 @@ class MoondreamModel(nn.Module):
|
|
| 989 |
eos_id = self.config.tokenizer.eos_id
|
| 990 |
max_ctx = self.config.text.max_context
|
| 991 |
|
|
|
|
| 992 |
if torch.is_tensor(pos):
|
| 993 |
pos = int(pos.max().item())
|
| 994 |
|
| 995 |
-
#
|
| 996 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 997 |
if pos > 0:
|
| 998 |
mask[:, :, :, :pos] = True
|
|
@@ -1001,36 +1002,44 @@ class MoondreamModel(nn.Module):
|
|
| 1001 |
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 1002 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 1003 |
|
| 1004 |
-
def _argmax01(logits):
|
| 1005 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1006 |
if use_soft_argmax:
|
| 1007 |
probs = torch.softmax(logits, dim=-1)
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
|
|
|
| 1011 |
else:
|
| 1012 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1013 |
return idx / float(logits.size(-1) - 1)
|
| 1014 |
|
| 1015 |
with torch.inference_mode():
|
| 1016 |
while alive.any() and (counts < max_objects).any():
|
| 1017 |
-
# x
|
| 1018 |
-
x_logits = decode_coordinate(hidden, self.region)
|
| 1019 |
-
|
| 1020 |
-
x_center
|
| 1021 |
-
|
| 1022 |
-
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)
|
| 1023 |
|
| 1024 |
mask[alive, :, :, pos] = True
|
| 1025 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1026 |
pos_ids[alive, 0] += 1
|
| 1027 |
pos += 1
|
| 1028 |
|
| 1029 |
-
# y
|
| 1030 |
-
y_logits = decode_coordinate(hidden, self.region)
|
| 1031 |
-
|
| 1032 |
-
y_center
|
| 1033 |
-
y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
|
| 1034 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1035 |
|
| 1036 |
mask[alive, :, :, pos] = True
|
|
@@ -1039,27 +1048,40 @@ class MoondreamModel(nn.Module):
|
|
| 1039 |
pos += 1
|
| 1040 |
|
| 1041 |
if include_size:
|
| 1042 |
-
# size
|
| 1043 |
-
size_logits = decode_size(hidden, self.region)
|
| 1044 |
-
w_logits = size_logits
|
| 1045 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
if use_soft_argmax:
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
|
|
|
|
|
|
|
| 1051 |
else:
|
| 1052 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1053 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 1054 |
|
| 1055 |
-
|
| 1056 |
-
|
|
|
|
|
|
|
|
|
|
| 1057 |
|
| 1058 |
-
size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)
|
| 1059 |
-
size_emb = encode_size(size_in, self.region).unsqueeze(1)
|
| 1060 |
|
|
|
|
| 1061 |
for i in range(B):
|
| 1062 |
-
if not alive[i]:
|
|
|
|
| 1063 |
xl = (x_center[i] - w[i] / 2).item()
|
| 1064 |
xr = (x_center[i] + w[i] / 2).item()
|
| 1065 |
yt = (y_center[i] - h[i] / 2).item()
|
|
@@ -1075,7 +1097,7 @@ class MoondreamModel(nn.Module):
|
|
| 1075 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1076 |
pos_ids[alive, 0] += 1
|
| 1077 |
pos += 1
|
| 1078 |
-
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1079 |
else:
|
| 1080 |
for i in range(B):
|
| 1081 |
if alive[i]:
|
|
@@ -1086,6 +1108,7 @@ class MoondreamModel(nn.Module):
|
|
| 1086 |
pos += 1
|
| 1087 |
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1088 |
|
|
|
|
| 1089 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1090 |
counts = counts + ((~finished_now) & alive).to(counts.dtype)
|
| 1091 |
alive &= ~finished_now
|
|
|
|
| 976 |
def _generate_points_batched(
|
| 977 |
self,
|
| 978 |
hidden, # (B,1,C)
|
| 979 |
+
next_token, # (B,1) (kept for API compatibility)
|
| 980 |
pos, # int or Tensor; normalized below
|
| 981 |
include_size: bool = True,
|
| 982 |
max_objects: int = 50,
|
|
|
|
| 989 |
eos_id = self.config.tokenizer.eos_id
|
| 990 |
max_ctx = self.config.text.max_context
|
| 991 |
|
| 992 |
+
# Normalize pos to a scalar int (supports int, (1,), (B,), (B,1))
|
| 993 |
if torch.is_tensor(pos):
|
| 994 |
pos = int(pos.max().item())
|
| 995 |
|
| 996 |
+
# 4-D mask: (B, 1, q_len=1, kv_len) + per-row position ids (B,1)
|
| 997 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 998 |
if pos > 0:
|
| 999 |
mask[:, :, :, :pos] = True
|
|
|
|
| 1002 |
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 1003 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 1004 |
|
| 1005 |
+
def _argmax01(logits: torch.Tensor) -> torch.Tensor:
|
| 1006 |
+
"""
|
| 1007 |
+
logits: (..., bins) -> normalized index in [0,1] per row
|
| 1008 |
+
Accepts (B,1,bins), (B,bins), or (bins,).
|
| 1009 |
+
"""
|
| 1010 |
+
# Canonicalize to (B, bins)
|
| 1011 |
+
if logits.dim() == 3: # (B,1,bins)
|
| 1012 |
+
logits = logits.squeeze(1)
|
| 1013 |
+
elif logits.dim() == 1: # (bins,) -> (1,bins)
|
| 1014 |
+
logits = logits.unsqueeze(0)
|
| 1015 |
+
|
| 1016 |
if use_soft_argmax:
|
| 1017 |
probs = torch.softmax(logits, dim=-1)
|
| 1018 |
+
bins_idx = torch.arange(probs.size(-1), device=probs.device, dtype=torch.float32)
|
| 1019 |
+
# expected-bin (0..bins-1) -> normalize by (bins-1) to [0,1]
|
| 1020 |
+
expbin = (probs * bins_idx).sum(dim=-1)
|
| 1021 |
+
return expbin / float(probs.size(-1) - 1)
|
| 1022 |
else:
|
| 1023 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1024 |
return idx / float(logits.size(-1) - 1)
|
| 1025 |
|
| 1026 |
with torch.inference_mode():
|
| 1027 |
while alive.any() and (counts < max_objects).any():
|
| 1028 |
+
# ---- x ------------------------------------------------------
|
| 1029 |
+
x_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b)
|
| 1030 |
+
x_center = _argmax01(x_logits) # (B,)
|
| 1031 |
+
x_in = x_center.to(dtype=x_logits.dtype if torch.is_tensor(x_logits) else hidden.dtype).unsqueeze(-1)
|
| 1032 |
+
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
|
|
|
|
| 1033 |
|
| 1034 |
mask[alive, :, :, pos] = True
|
| 1035 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1036 |
pos_ids[alive, 0] += 1
|
| 1037 |
pos += 1
|
| 1038 |
|
| 1039 |
+
# ---- y ------------------------------------------------------
|
| 1040 |
+
y_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b)
|
| 1041 |
+
y_center = _argmax01(y_logits) # (B,)
|
| 1042 |
+
y_in = y_center.to(dtype=y_logits.dtype if torch.is_tensor(y_logits) else hidden.dtype).unsqueeze(-1)
|
|
|
|
| 1043 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1044 |
|
| 1045 |
mask[alive, :, :, pos] = True
|
|
|
|
| 1048 |
pos += 1
|
| 1049 |
|
| 1050 |
if include_size:
|
| 1051 |
+
# ---- size ----------------------------------------------
|
| 1052 |
+
size_logits = decode_size(hidden, self.region) # tuple of (w_logits, h_logits)
|
| 1053 |
+
w_logits, h_logits = size_logits
|
| 1054 |
+
|
| 1055 |
+
# Canonicalize to (B, bins) for both
|
| 1056 |
+
if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
|
| 1057 |
+
if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
|
| 1058 |
+
if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
|
| 1059 |
+
if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
|
| 1060 |
+
|
| 1061 |
if use_soft_argmax:
|
| 1062 |
+
w_probs = torch.softmax(w_logits, dim=-1)
|
| 1063 |
+
h_probs = torch.softmax(h_logits, dim=-1)
|
| 1064 |
+
w_bins_idx = torch.arange(w_probs.size(-1), device=device, dtype=torch.float32)
|
| 1065 |
+
h_bins_idx = torch.arange(h_probs.size(-1), device=device, dtype=torch.float32)
|
| 1066 |
+
w_bin = (w_probs * w_bins_idx).sum(dim=-1) # (B,)
|
| 1067 |
+
h_bin = (h_probs * h_bins_idx).sum(dim=-1) # (B,)
|
| 1068 |
else:
|
| 1069 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1070 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 1071 |
|
| 1072 |
+
# bins -> size (inverse log scale), robust to bins != 1024
|
| 1073 |
+
w_den = float(w_logits.size(-1) - 1)
|
| 1074 |
+
h_den = float(h_logits.size(-1) - 1)
|
| 1075 |
+
w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
|
| 1076 |
+
h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
|
| 1077 |
|
| 1078 |
+
size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
|
| 1079 |
+
size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
|
| 1080 |
|
| 1081 |
+
# record boxes only for alive rows
|
| 1082 |
for i in range(B):
|
| 1083 |
+
if not alive[i]:
|
| 1084 |
+
continue
|
| 1085 |
xl = (x_center[i] - w[i] / 2).item()
|
| 1086 |
xr = (x_center[i] + w[i] / 2).item()
|
| 1087 |
yt = (y_center[i] - h[i] / 2).item()
|
|
|
|
| 1097 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1098 |
pos_ids[alive, 0] += 1
|
| 1099 |
pos += 1
|
| 1100 |
+
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 1101 |
else:
|
| 1102 |
for i in range(B):
|
| 1103 |
if alive[i]:
|
|
|
|
| 1108 |
pos += 1
|
| 1109 |
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1110 |
|
| 1111 |
+
# stop only rows that hit eos (or reached max objects)
|
| 1112 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1113 |
counts = counts + ((~finished_now) & alive).to(counts.dtype)
|
| 1114 |
alive &= ~finished_now
|