Update moondream.py
Browse files- moondream.py +90 -70
moondream.py
CHANGED
|
@@ -895,81 +895,96 @@ class MoondreamModel(nn.Module):
|
|
| 895 |
probs = self._apply_top_p(probs, top_p)
|
| 896 |
next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
|
| 897 |
|
| 898 |
-
pos_vec = torch.
|
| 899 |
|
| 900 |
return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
|
| 901 |
|
| 902 |
-
def _generate_points_batched(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 903 |
"""
|
| 904 |
-
|
| 905 |
-
|
|
|
|
|
|
|
| 906 |
Returns: list-of-lists of dicts, length B.
|
| 907 |
"""
|
| 908 |
-
|
| 909 |
B = hidden.size(0)
|
| 910 |
device = self.device
|
| 911 |
out = [[] for _ in range(B)]
|
| 912 |
eos_id = self.config.tokenizer.eos_id
|
| 913 |
-
|
| 914 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 915 |
max_ctx = self.config.text.max_context
|
| 916 |
mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 922 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 923 |
-
|
| 924 |
with torch.inference_mode():
|
| 925 |
while alive.any() and (counts < max_objects).any():
|
| 926 |
-
# --- x coordinate
|
| 927 |
-
x_logits = decode_coordinate(hidden, self.region)
|
| 928 |
if x_logits.dim() == 3:
|
| 929 |
-
x_logits = x_logits.squeeze(1)
|
| 930 |
-
x_bin
|
| 931 |
x_center = x_bin / float(x_logits.size(-1)) # (B,)
|
| 932 |
-
x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,
|
| 933 |
x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
|
|
|
|
|
|
| 943 |
# --- y coordinate ---
|
| 944 |
y_logits = decode_coordinate(hidden, self.region)
|
| 945 |
if y_logits.dim() == 3:
|
| 946 |
-
y_logits = y_logits.squeeze(1)
|
| 947 |
-
y_bin
|
| 948 |
-
y_center = y_bin / float(y_logits.size(-1))
|
| 949 |
-
y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
|
| 950 |
-
y_emb = encode_coordinate(y_input, self.region).unsqueeze(1)
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
|
|
|
|
|
|
| 960 |
if include_size:
|
| 961 |
-
# --- size
|
| 962 |
-
size_logits = decode_size(hidden, self.region)
|
| 963 |
-
w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1) # (B,1024)
|
| 964 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 965 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 966 |
-
# Convert
|
| 967 |
-
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 968 |
-
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 969 |
-
size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,
|
| 970 |
-
size_emb = encode_size(size_input, self.region).unsqueeze(1)
|
| 971 |
-
|
| 972 |
-
#
|
| 973 |
for i in range(B):
|
| 974 |
if not alive[i]:
|
| 975 |
continue
|
|
@@ -979,35 +994,40 @@ class MoondreamModel(nn.Module):
|
|
| 979 |
"x_max": (x_center[i] + w[i] / 2).item(),
|
| 980 |
"y_max": (y_center[i] + h[i] / 2).item(),
|
| 981 |
})
|
| 982 |
-
|
| 983 |
# step: decode "next token" to decide continuation
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
|
|
|
|
|
|
|
|
|
| 989 |
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 990 |
else:
|
| 991 |
# Points mode (no size)
|
| 992 |
-
for i in range(B):
|
| 993 |
-
if not alive[i]:
|
| 994 |
-
continue
|
| 995 |
-
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
| 996 |
-
# step: decode next token from y_emb
|
| 997 |
for i in range(B):
|
| 998 |
if alive[i]:
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1005 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1006 |
counts = counts + (~finished_now & alive).to(counts.dtype)
|
| 1007 |
alive &= ~finished_now
|
| 1008 |
-
|
| 1009 |
return out
|
| 1010 |
|
|
|
|
| 1011 |
def detect_multi(self, image, objects, settings=None):
|
| 1012 |
"""
|
| 1013 |
Parallel multi-label detection.
|
|
@@ -1053,7 +1073,7 @@ class MoondreamModel(nn.Module):
|
|
| 1053 |
d["label"] = lab
|
| 1054 |
res[lab] = lst
|
| 1055 |
return {"objects": res}
|
| 1056 |
-
|
| 1057 |
def _detect_gaze(
|
| 1058 |
self,
|
| 1059 |
image: EncodedImage,
|
|
|
|
| 895 |
probs = self._apply_top_p(probs, top_p)
|
| 896 |
next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
|
| 897 |
|
| 898 |
+
pos_vec = torch.full((B,), pos + T, device=self.device, dtype=torch.long)
|
| 899 |
|
| 900 |
return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
|
| 901 |
|
| 902 |
+
def _generate_points_batched(
|
| 903 |
+
self,
|
| 904 |
+
hidden: torch.Tensor, # (B, 1, C) last hidden per row from prefill
|
| 905 |
+
next_token: torch.Tensor, # (B, 1) not used directly (kept for parity)
|
| 906 |
+
pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
|
| 907 |
+
include_size: bool = True,
|
| 908 |
+
max_objects: int = 50,
|
| 909 |
+
lora=None,
|
| 910 |
+
):
|
| 911 |
"""
|
| 912 |
+
Batched decode loop for multi-label detection.
|
| 913 |
+
- Uses a *shared* scalar position id per step (q_len = 1), as expected by RoPE.
|
| 914 |
+
- Maintains a per-row attention mask and 'alive' flags.
|
| 915 |
+
- Feeds coord encoders with (B,1) tensors; size encoder with (B,2).
|
| 916 |
Returns: list-of-lists of dicts, length B.
|
| 917 |
"""
|
|
|
|
| 918 |
B = hidden.size(0)
|
| 919 |
device = self.device
|
| 920 |
out = [[] for _ in range(B)]
|
| 921 |
eos_id = self.config.tokenizer.eos_id
|
| 922 |
+
|
| 923 |
+
# --- Shared write position (scalar) consistent with RoPE q_len=1 ---
|
| 924 |
+
# We align rows by padding; using the maximum ensures all KV rows can decode in lockstep.
|
| 925 |
+
pos = int(pos_vec.max().item())
|
| 926 |
+
|
| 927 |
+
# Per-row attention mask (1 = visible). Mark everything up to 'pos' as visible.
|
| 928 |
max_ctx = self.config.text.max_context
|
| 929 |
mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
|
| 930 |
+
mask[:, :, :pos] = 1
|
| 931 |
+
|
| 932 |
+
alive = torch.ones(B, dtype=torch.bool, device=device)
|
|
|
|
|
|
|
| 933 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 934 |
+
|
| 935 |
with torch.inference_mode():
|
| 936 |
while alive.any() and (counts < max_objects).any():
|
| 937 |
+
# --- x coordinate ---
|
| 938 |
+
x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
|
| 939 |
if x_logits.dim() == 3:
|
| 940 |
+
x_logits = x_logits.squeeze(1) # (B,1024)
|
| 941 |
+
x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
|
| 942 |
x_center = x_bin / float(x_logits.size(-1)) # (B,)
|
| 943 |
+
x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
|
| 944 |
x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
|
| 945 |
+
|
| 946 |
+
# step: decode hidden for y (advance shared pos)
|
| 947 |
+
mask[:, :, pos] = 1
|
| 948 |
+
logits, hidden = self._decode_one_tok(
|
| 949 |
+
x_emb,
|
| 950 |
+
mask,
|
| 951 |
+
torch.tensor([pos], device=device, dtype=torch.long), # length-1 (q_len=1)
|
| 952 |
+
lora,
|
| 953 |
+
)
|
| 954 |
+
pos += 1
|
| 955 |
+
|
| 956 |
# --- y coordinate ---
|
| 957 |
y_logits = decode_coordinate(hidden, self.region)
|
| 958 |
if y_logits.dim() == 3:
|
| 959 |
+
y_logits = y_logits.squeeze(1)
|
| 960 |
+
y_bin = y_logits.argmax(dim=-1).to(torch.float32)
|
| 961 |
+
y_center = y_bin / float(y_logits.size(-1)) # (B,)
|
| 962 |
+
y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
|
| 963 |
+
y_emb = encode_coordinate(y_input, self.region).unsqueeze(1) # (B,1,C)
|
| 964 |
+
|
| 965 |
+
# step: decode hidden for size / eos (advance shared pos)
|
| 966 |
+
mask[:, :, pos] = 1
|
| 967 |
+
logits, hidden = self._decode_one_tok(
|
| 968 |
+
y_emb,
|
| 969 |
+
mask,
|
| 970 |
+
torch.tensor([pos], device=device, dtype=torch.long),
|
| 971 |
+
lora,
|
| 972 |
+
)
|
| 973 |
+
pos += 1
|
| 974 |
+
|
| 975 |
if include_size:
|
| 976 |
+
# --- size (batched) ---
|
| 977 |
+
size_logits = decode_size(hidden, self.region) # ([B,1,1024],[B,1,1024])
|
| 978 |
+
w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1) # (B,1024)
|
| 979 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 980 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 981 |
+
# Convert log-scale bins -> sizes in [0,1]
|
| 982 |
+
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0) # (B,)
|
| 983 |
+
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0) # (B,)
|
| 984 |
+
size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
|
| 985 |
+
size_emb = encode_size(size_input, self.region).unsqueeze(1) # (B,1,C)
|
| 986 |
+
|
| 987 |
+
# Record boxes for alive rows
|
| 988 |
for i in range(B):
|
| 989 |
if not alive[i]:
|
| 990 |
continue
|
|
|
|
| 994 |
"x_max": (x_center[i] + w[i] / 2).item(),
|
| 995 |
"y_max": (y_center[i] + h[i] / 2).item(),
|
| 996 |
})
|
| 997 |
+
|
| 998 |
# step: decode "next token" to decide continuation
|
| 999 |
+
mask[:, :, pos] = 1
|
| 1000 |
+
logits, hidden = self._decode_one_tok(
|
| 1001 |
+
size_emb,
|
| 1002 |
+
mask,
|
| 1003 |
+
torch.tensor([pos], device=device, dtype=torch.long),
|
| 1004 |
+
lora,
|
| 1005 |
+
)
|
| 1006 |
+
pos += 1
|
| 1007 |
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 1008 |
else:
|
| 1009 |
# Points mode (no size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1010 |
for i in range(B):
|
| 1011 |
if alive[i]:
|
| 1012 |
+
out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
|
| 1013 |
+
mask[:, :, pos] = 1
|
| 1014 |
+
logits, hidden = self._decode_one_tok(
|
| 1015 |
+
y_emb,
|
| 1016 |
+
mask,
|
| 1017 |
+
torch.tensor([pos], device=device, dtype=torch.long),
|
| 1018 |
+
lora,
|
| 1019 |
+
)
|
| 1020 |
+
pos += 1
|
| 1021 |
+
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 1022 |
+
|
| 1023 |
+
# Update finished/alive bookkeeping
|
| 1024 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1025 |
counts = counts + (~finished_now & alive).to(counts.dtype)
|
| 1026 |
alive &= ~finished_now
|
| 1027 |
+
|
| 1028 |
return out
|
| 1029 |
|
| 1030 |
+
|
| 1031 |
def detect_multi(self, image, objects, settings=None):
|
| 1032 |
"""
|
| 1033 |
Parallel multi-label detection.
|
|
|
|
| 1073 |
d["label"] = lab
|
| 1074 |
res[lab] = lst
|
| 1075 |
return {"objects": res}
|
| 1076 |
+
|
| 1077 |
def _detect_gaze(
|
| 1078 |
self,
|
| 1079 |
image: EncodedImage,
|