Update moondream.py
Browse files
- moondream.py +44 -61
moondream.py
CHANGED
|
@@ -77,18 +77,16 @@ class KVCache(nn.Module):
|
|
| 77 |
"v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 78 |
)
|
| 79 |
|
| 80 |
-
# In class KVCache, REPLACE the whole update() with this:
|
| 81 |
def update(self, pos_ids, k, v):
|
| 82 |
"""
|
| 83 |
Supports:
|
| 84 |
• Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
|
| 85 |
• 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
|
| 86 |
-
•
|
| 87 |
Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
|
| 88 |
"""
|
| 89 |
kout, vout = self.k_cache, self.v_cache
|
| 90 |
|
| 91 |
-
# Normalize pos_ids
|
| 92 |
if not torch.is_tensor(pos_ids):
|
| 93 |
pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
|
| 94 |
else:
|
|
@@ -98,7 +96,7 @@ class KVCache(nn.Module):
|
|
| 98 |
raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
|
| 99 |
B, Hkv, q_len, D = k.shape
|
| 100 |
|
| 101 |
-
#
|
| 102 |
if kout.size(0) != B:
|
| 103 |
if kout.size(0) == 1:
|
| 104 |
self.k_cache = kout.expand(B, -1, -1, -1).clone()
|
|
@@ -107,14 +105,14 @@ class KVCache(nn.Module):
|
|
| 107 |
else:
|
| 108 |
raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
|
| 109 |
|
| 110 |
-
#
|
| 111 |
if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
|
| 112 |
for i in range(B):
|
| 113 |
-
kout[i, :, pos_ids, :] = k[i]
|
| 114 |
vout[i, :, pos_ids, :] = v[i]
|
| 115 |
return kout, vout
|
| 116 |
|
| 117 |
-
#
|
| 118 |
if q_len == 1 and pos_ids.numel() == B:
|
| 119 |
pos_ids = pos_ids.view(B)
|
| 120 |
for i in range(B):
|
|
@@ -123,16 +121,15 @@ class KVCache(nn.Module):
|
|
| 123 |
vout[i, :, pi, :] = v[i, :, 0, :]
|
| 124 |
return kout, vout
|
| 125 |
|
| 126 |
-
#
|
| 127 |
if pos_ids.dim() == 0 and q_len == 1:
|
| 128 |
pi = int(pos_ids.item())
|
| 129 |
kout[:, :, pi, :] = k[:, :, 0, :]
|
| 130 |
vout[:, :, pi, :] = v[:, :, 0, :]
|
| 131 |
return kout, vout
|
| 132 |
|
| 133 |
-
raise RuntimeError(
|
| 134 |
-
|
| 135 |
-
)
|
| 136 |
|
| 137 |
|
| 138 |
|
|
@@ -213,10 +210,6 @@ class MoondreamModel(nn.Module):
|
|
| 213 |
|
| 214 |
|
| 215 |
def _reset_kv_caches(self, batch_size: int = 1):
|
| 216 |
-
"""
|
| 217 |
-
Recreate KV caches with the requested batch size so subsequent calls
|
| 218 |
-
(e.g., encode_image) start from a consistent shape.
|
| 219 |
-
"""
|
| 220 |
c = self.config.text
|
| 221 |
head_dim = c.dim // c.n_heads
|
| 222 |
for blk in self.text.blocks:
|
|
@@ -225,6 +218,7 @@ class MoondreamModel(nn.Module):
|
|
| 225 |
shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
|
| 226 |
blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 227 |
blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
|
|
|
|
| 228 |
|
| 229 |
|
| 230 |
def _setup_caches(self):
|
|
@@ -589,12 +583,13 @@ class MoondreamModel(nn.Module):
|
|
| 589 |
elif not isinstance(image, Image.Image):
|
| 590 |
raise ValueError("image must be a PIL Image or EncodedImage")
|
| 591 |
|
| 592 |
-
#
|
| 593 |
for blk in self.text.blocks:
|
| 594 |
if blk.kv_cache.k_cache.size(0) != 1:
|
| 595 |
blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
|
| 596 |
blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
|
| 597 |
|
|
|
|
| 598 |
|
| 599 |
lora = (
|
| 600 |
variant_state_dict(settings["variant"], device=self.device)
|
|
@@ -933,6 +928,7 @@ class MoondreamModel(nn.Module):
|
|
| 933 |
b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
|
| 934 |
b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
|
| 935 |
|
|
|
|
| 936 |
def _prefill_prompt_batched(self, labels, pos: int, lora=None,
|
| 937 |
temperature: float = 0.0, top_p: float = 0.0):
|
| 938 |
tpl = self.config.tokenizer.templates["detect"]
|
|
@@ -949,12 +945,11 @@ class MoondreamModel(nn.Module):
|
|
| 949 |
|
| 950 |
prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
|
| 951 |
for i, ids in enumerate(rows):
|
| 952 |
-
prompt_ids[i, :
|
| 953 |
|
| 954 |
prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
|
| 955 |
torch._dynamo.mark_dynamic(prompt_emb, 1)
|
| 956 |
|
| 957 |
-
# 4-D mask: (B,1,T,kv_len) for SDPA
|
| 958 |
base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,kv_len)
|
| 959 |
mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,kv_len)
|
| 960 |
|
|
@@ -967,29 +962,26 @@ class MoondreamModel(nn.Module):
|
|
| 967 |
last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
|
| 968 |
|
| 969 |
if temperature == 0.0:
|
| 970 |
-
next_token = last_logits.argmax(dim=-1, keepdim=True)
|
| 971 |
else:
|
| 972 |
probs = torch.softmax(last_logits / temperature, dim=-1)
|
| 973 |
probs = self._apply_top_p(probs, top_p)
|
| 974 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 975 |
|
| 976 |
-
#
|
| 977 |
-
pos_vec = torch.tensor(lens, device=self.device, dtype=torch.long) + pos # (B,)
|
| 978 |
-
|
| 979 |
-
# At the end of _prefill_prompt_batched(), return a Python int:
|
| 980 |
-
pos_end = int((pos + T))
|
| 981 |
return last_hidden, next_token, pos_end
|
| 982 |
|
| 983 |
|
|
|
|
| 984 |
def _generate_points_batched(
|
| 985 |
self,
|
| 986 |
hidden, # (B,1,C)
|
| 987 |
-
next_token, # (B,1)
|
| 988 |
pos, # int or Tensor; normalized below
|
| 989 |
include_size: bool = True,
|
| 990 |
max_objects: int = 50,
|
| 991 |
lora=None,
|
| 992 |
-
use_soft_argmax: bool = True,
|
| 993 |
):
|
| 994 |
B = hidden.size(0)
|
| 995 |
device = self.device
|
|
@@ -997,55 +989,48 @@ class MoondreamModel(nn.Module):
|
|
| 997 |
eos_id = self.config.tokenizer.eos_id
|
| 998 |
max_ctx = self.config.text.max_context
|
| 999 |
|
| 1000 |
-
# Normalize pos to a scalar int (supports int, (1,), (B,), (B,1))
|
| 1001 |
if torch.is_tensor(pos):
|
| 1002 |
-
pos = int(pos.max().item())
|
| 1003 |
|
| 1004 |
-
#
|
| 1005 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 1006 |
if pos > 0:
|
| 1007 |
mask[:, :, :, :pos] = True
|
| 1008 |
-
|
| 1009 |
-
# position_ids must be (B,1)
|
| 1010 |
pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
|
| 1011 |
|
| 1012 |
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 1013 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 1014 |
|
| 1015 |
-
# helpers ---------------------------------------------------------
|
| 1016 |
def _argmax01(logits):
|
| 1017 |
-
# logits: (B, bins)
|
| 1018 |
if use_soft_argmax:
|
| 1019 |
probs = torch.softmax(logits, dim=-1)
|
| 1020 |
bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
|
| 1021 |
idx = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
|
| 1022 |
-
return idx
|
| 1023 |
else:
|
| 1024 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1025 |
return idx / float(logits.size(-1) - 1)
|
| 1026 |
|
| 1027 |
with torch.inference_mode():
|
| 1028 |
while alive.any() and (counts < max_objects).any():
|
| 1029 |
-
#
|
| 1030 |
-
x_logits = decode_coordinate(hidden, self.region)
|
| 1031 |
-
if x_logits.dim() == 3:
|
| 1032 |
-
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
|
| 1036 |
|
| 1037 |
-
# advance attention one step FOR ALIVE ROWS ONLY
|
| 1038 |
mask[alive, :, :, pos] = True
|
| 1039 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1040 |
pos_ids[alive, 0] += 1
|
| 1041 |
-
pos += 1
|
| 1042 |
|
| 1043 |
-
#
|
| 1044 |
y_logits = decode_coordinate(hidden, self.region)
|
| 1045 |
-
if y_logits.dim() == 3:
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
|
| 1049 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1050 |
|
| 1051 |
mask[alive, :, :, pos] = True
|
|
@@ -1054,32 +1039,31 @@ class MoondreamModel(nn.Module):
|
|
| 1054 |
pos += 1
|
| 1055 |
|
| 1056 |
if include_size:
|
| 1057 |
-
#
|
| 1058 |
size_logits = decode_size(hidden, self.region)
|
| 1059 |
w_logits = size_logits[0].squeeze(1)
|
| 1060 |
h_logits = size_logits[1].squeeze(1)
|
| 1061 |
if use_soft_argmax:
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
h_bin = (torch.softmax(h_logits, dim=-1) *
|
|
|
|
| 1065 |
else:
|
| 1066 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1067 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
|
|
|
| 1068 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 1069 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 1070 |
|
| 1071 |
-
size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)
|
| 1072 |
-
size_emb = encode_size(size_in, self.region).unsqueeze(1)
|
| 1073 |
|
| 1074 |
-
# record boxes only for alive rows
|
| 1075 |
for i in range(B):
|
| 1076 |
-
if not alive[i]:
|
| 1077 |
-
continue
|
| 1078 |
xl = (x_center[i] - w[i] / 2).item()
|
| 1079 |
xr = (x_center[i] + w[i] / 2).item()
|
| 1080 |
yt = (y_center[i] - h[i] / 2).item()
|
| 1081 |
yb = (y_center[i] + h[i] / 2).item()
|
| 1082 |
-
# clamp for safety
|
| 1083 |
out[i].append({
|
| 1084 |
"x_min": max(0.0, min(1.0, xl)),
|
| 1085 |
"y_min": max(0.0, min(1.0, yt)),
|
|
@@ -1091,7 +1075,7 @@ class MoondreamModel(nn.Module):
|
|
| 1091 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1092 |
pos_ids[alive, 0] += 1
|
| 1093 |
pos += 1
|
| 1094 |
-
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1095 |
else:
|
| 1096 |
for i in range(B):
|
| 1097 |
if alive[i]:
|
|
@@ -1102,7 +1086,6 @@ class MoondreamModel(nn.Module):
|
|
| 1102 |
pos += 1
|
| 1103 |
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1104 |
|
| 1105 |
-
# stop only rows that hit eos (or reached max objects)
|
| 1106 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1107 |
counts = counts + ((~finished_now) & alive).to(counts.dtype)
|
| 1108 |
alive &= ~finished_now
|
|
|
|
| 77 |
"v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 78 |
)
|
| 79 |
|
|
|
|
| 80 |
def update(self, pos_ids, k, v):
|
| 81 |
"""
|
| 82 |
Supports:
|
| 83 |
• Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
|
| 84 |
• 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
|
| 85 |
+
• Scalar: k,v = (B, n_kv_heads, 1, d), pos_ids = ()
|
| 86 |
Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
|
| 87 |
"""
|
| 88 |
kout, vout = self.k_cache, self.v_cache
|
| 89 |
|
|
|
|
| 90 |
if not torch.is_tensor(pos_ids):
|
| 91 |
pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
|
| 92 |
else:
|
|
|
|
| 96 |
raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
|
| 97 |
B, Hkv, q_len, D = k.shape
|
| 98 |
|
| 99 |
+
# Expand cache to batch B if needed (expand-from-1 allowed)
|
| 100 |
if kout.size(0) != B:
|
| 101 |
if kout.size(0) == 1:
|
| 102 |
self.k_cache = kout.expand(B, -1, -1, -1).clone()
|
|
|
|
| 105 |
else:
|
| 106 |
raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
|
| 107 |
|
| 108 |
+
# A) Prefill: pos_ids = (q_len,)
|
| 109 |
if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
|
| 110 |
for i in range(B):
|
| 111 |
+
kout[i, :, pos_ids, :] = k[i]
|
| 112 |
vout[i, :, pos_ids, :] = v[i]
|
| 113 |
return kout, vout
|
| 114 |
|
| 115 |
+
# B) One-step: q_len == 1 and pos_ids per row: (B,) or (B,1)
|
| 116 |
if q_len == 1 and pos_ids.numel() == B:
|
| 117 |
pos_ids = pos_ids.view(B)
|
| 118 |
for i in range(B):
|
|
|
|
| 121 |
vout[i, :, pi, :] = v[i, :, 0, :]
|
| 122 |
return kout, vout
|
| 123 |
|
| 124 |
+
# C) Scalar position for everyone & q_len == 1
|
| 125 |
if pos_ids.dim() == 0 and q_len == 1:
|
| 126 |
pi = int(pos_ids.item())
|
| 127 |
kout[:, :, pi, :] = k[:, :, 0, :]
|
| 128 |
vout[:, :, pi, :] = v[:, :, 0, :]
|
| 129 |
return kout, vout
|
| 130 |
|
| 131 |
+
raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
|
| 132 |
+
|
|
|
|
| 133 |
|
| 134 |
|
| 135 |
|
|
|
|
| 210 |
|
| 211 |
|
| 212 |
def _reset_kv_caches(self, batch_size: int = 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
c = self.config.text
|
| 214 |
head_dim = c.dim // c.n_heads
|
| 215 |
for blk in self.text.blocks:
|
|
|
|
| 218 |
shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
|
| 219 |
blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 220 |
blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 221 |
+
|
| 222 |
|
| 223 |
|
| 224 |
def _setup_caches(self):
|
|
|
|
| 583 |
elif not isinstance(image, Image.Image):
|
| 584 |
raise ValueError("image must be a PIL Image or EncodedImage")
|
| 585 |
|
| 586 |
+
# Always start from single-row caches to avoid leftovers from batched runs
|
| 587 |
for blk in self.text.blocks:
|
| 588 |
if blk.kv_cache.k_cache.size(0) != 1:
|
| 589 |
blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
|
| 590 |
blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
|
| 591 |
|
| 592 |
+
|
| 593 |
|
| 594 |
lora = (
|
| 595 |
variant_state_dict(settings["variant"], device=self.device)
|
|
|
|
| 928 |
b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
|
| 929 |
b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
|
| 930 |
|
| 931 |
+
|
| 932 |
def _prefill_prompt_batched(self, labels, pos: int, lora=None,
|
| 933 |
temperature: float = 0.0, top_p: float = 0.0):
|
| 934 |
tpl = self.config.tokenizer.templates["detect"]
|
|
|
|
| 945 |
|
| 946 |
prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
|
| 947 |
for i, ids in enumerate(rows):
|
| 948 |
+
prompt_ids[i, :ids.numel()] = ids
|
| 949 |
|
| 950 |
prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
|
| 951 |
torch._dynamo.mark_dynamic(prompt_emb, 1)
|
| 952 |
|
|
|
|
| 953 |
base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,kv_len)
|
| 954 |
mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,kv_len)
|
| 955 |
|
|
|
|
| 962 |
last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
|
| 963 |
|
| 964 |
if temperature == 0.0:
|
| 965 |
+
next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
|
| 966 |
else:
|
| 967 |
probs = torch.softmax(last_logits / temperature, dim=-1)
|
| 968 |
probs = self._apply_top_p(probs, top_p)
|
| 969 |
+
next_token = torch.multinomial(probs, num_samples=1) # (B,1)
|
| 970 |
|
| 971 |
+
pos_end = int(pos + T) # shared scalar end position
|
|
|
|
|
|
|
|
|
|
|
|
|
| 972 |
return last_hidden, next_token, pos_end
|
| 973 |
|
| 974 |
|
| 975 |
+
|
| 976 |
def _generate_points_batched(
|
| 977 |
self,
|
| 978 |
hidden, # (B,1,C)
|
| 979 |
+
next_token, # (B,1)
|
| 980 |
pos, # int or Tensor; normalized below
|
| 981 |
include_size: bool = True,
|
| 982 |
max_objects: int = 50,
|
| 983 |
lora=None,
|
| 984 |
+
use_soft_argmax: bool = True, # reduces bbox jitter
|
| 985 |
):
|
| 986 |
B = hidden.size(0)
|
| 987 |
device = self.device
|
|
|
|
| 989 |
eos_id = self.config.tokenizer.eos_id
|
| 990 |
max_ctx = self.config.text.max_context
|
| 991 |
|
|
|
|
| 992 |
if torch.is_tensor(pos):
|
| 993 |
+
pos = int(pos.max().item())
|
| 994 |
|
| 995 |
+
# SDPA mask and position ids
|
| 996 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 997 |
if pos > 0:
|
| 998 |
mask[:, :, :, :pos] = True
|
|
|
|
|
|
|
| 999 |
pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
|
| 1000 |
|
| 1001 |
alive = torch.ones(B, dtype=torch.bool, device=device)
|
| 1002 |
counts = torch.zeros(B, dtype=torch.int32, device=device)
|
| 1003 |
|
|
|
|
| 1004 |
def _argmax01(logits):
|
| 1005 |
+
# logits: (B, bins) -> normalized index in [0,1]
|
| 1006 |
if use_soft_argmax:
|
| 1007 |
probs = torch.softmax(logits, dim=-1)
|
| 1008 |
bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
|
| 1009 |
idx = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
|
| 1010 |
+
return idx
|
| 1011 |
else:
|
| 1012 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1013 |
return idx / float(logits.size(-1) - 1)
|
| 1014 |
|
| 1015 |
with torch.inference_mode():
|
| 1016 |
while alive.any() and (counts < max_objects).any():
|
| 1017 |
+
# x
|
| 1018 |
+
x_logits = decode_coordinate(hidden, self.region)
|
| 1019 |
+
if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
|
| 1020 |
+
x_center = _argmax01(x_logits)
|
| 1021 |
+
x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)
|
| 1022 |
+
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)
|
|
|
|
| 1023 |
|
|
|
|
| 1024 |
mask[alive, :, :, pos] = True
|
| 1025 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
| 1026 |
pos_ids[alive, 0] += 1
|
| 1027 |
+
pos += 1
|
| 1028 |
|
| 1029 |
+
# y
|
| 1030 |
y_logits = decode_coordinate(hidden, self.region)
|
| 1031 |
+
if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
|
| 1032 |
+
y_center = _argmax01(y_logits)
|
| 1033 |
+
y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
|
|
|
|
| 1034 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1035 |
|
| 1036 |
mask[alive, :, :, pos] = True
|
|
|
|
| 1039 |
pos += 1
|
| 1040 |
|
| 1041 |
if include_size:
|
| 1042 |
+
# size
|
| 1043 |
size_logits = decode_size(hidden, self.region)
|
| 1044 |
w_logits = size_logits[0].squeeze(1)
|
| 1045 |
h_logits = size_logits[1].squeeze(1)
|
| 1046 |
if use_soft_argmax:
|
| 1047 |
+
w_bin = (torch.softmax(w_logits, dim=-1) *
|
| 1048 |
+
torch.arange(w_logits.size(-1), device=device)).sum(dim=-1)
|
| 1049 |
+
h_bin = (torch.softmax(h_logits, dim=-1) *
|
| 1050 |
+
torch.arange(h_logits.size(-1), device=device)).sum(dim=-1)
|
| 1051 |
else:
|
| 1052 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1053 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
| 1054 |
+
|
| 1055 |
w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
|
| 1056 |
h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
|
| 1057 |
|
| 1058 |
+
size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)
|
| 1059 |
+
size_emb = encode_size(size_in, self.region).unsqueeze(1)
|
| 1060 |
|
|
|
|
| 1061 |
for i in range(B):
|
| 1062 |
+
if not alive[i]: continue
|
|
|
|
| 1063 |
xl = (x_center[i] - w[i] / 2).item()
|
| 1064 |
xr = (x_center[i] + w[i] / 2).item()
|
| 1065 |
yt = (y_center[i] - h[i] / 2).item()
|
| 1066 |
yb = (y_center[i] + h[i] / 2).item()
|
|
|
|
| 1067 |
out[i].append({
|
| 1068 |
"x_min": max(0.0, min(1.0, xl)),
|
| 1069 |
"y_min": max(0.0, min(1.0, yt)),
|
|
|
|
| 1075 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1076 |
pos_ids[alive, 0] += 1
|
| 1077 |
pos += 1
|
| 1078 |
+
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1079 |
else:
|
| 1080 |
for i in range(B):
|
| 1081 |
if alive[i]:
|
|
|
|
| 1086 |
pos += 1
|
| 1087 |
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1088 |
|
|
|
|
| 1089 |
finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
|
| 1090 |
counts = counts + ((~finished_now) & alive).to(counts.dtype)
|
| 1091 |
alive &= ~finished_now
|