Update moondream.py
Browse files- moondream.py +99 -106
moondream.py
CHANGED
|
@@ -64,39 +64,35 @@ class EncodedImage:
|
|
| 64 |
pos: int
|
| 65 |
caches: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 66 |
|
| 67 |
-
|
| 68 |
class KVCache(nn.Module):
|
| 69 |
-
|
| 70 |
def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
|
| 71 |
super().__init__()
|
| 72 |
cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
|
| 73 |
-
self.register_buffer(
|
| 74 |
-
|
| 75 |
-
)
|
| 76 |
-
self.register_buffer(
|
| 77 |
-
"v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
|
| 78 |
-
)
|
| 79 |
|
| 80 |
def update(self, pos_ids, k, v):
|
| 81 |
"""
|
| 82 |
Supports:
|
| 83 |
• Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
|
| 84 |
• 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
|
| 85 |
-
•
|
| 86 |
Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
|
| 87 |
"""
|
| 88 |
kout, vout = self.k_cache, self.v_cache
|
| 89 |
-
|
|
|
|
| 90 |
if not torch.is_tensor(pos_ids):
|
| 91 |
pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
|
| 92 |
else:
|
| 93 |
pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
|
| 94 |
-
|
| 95 |
if k.dim() != 4 or v.dim() != 4:
|
| 96 |
-
raise RuntimeError(f"KV update expects k,v
|
|
|
|
| 97 |
B, Hkv, q_len, D = k.shape
|
| 98 |
-
|
| 99 |
-
#
|
| 100 |
if kout.size(0) != B:
|
| 101 |
if kout.size(0) == 1:
|
| 102 |
self.k_cache = kout.expand(B, -1, -1, -1).clone()
|
|
@@ -104,34 +100,31 @@ class KVCache(nn.Module):
|
|
| 104 |
kout, vout = self.k_cache, self.v_cache
|
| 105 |
else:
|
| 106 |
raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
|
| 107 |
-
|
| 108 |
-
# A
|
| 109 |
if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
|
| 110 |
for i in range(B):
|
| 111 |
kout[i, :, pos_ids, :] = k[i]
|
| 112 |
vout[i, :, pos_ids, :] = v[i]
|
| 113 |
return kout, vout
|
| 114 |
-
|
| 115 |
-
# B
|
| 116 |
-
if q_len == 1 and pos_ids.numel() == B:
|
| 117 |
pos_ids = pos_ids.view(B)
|
| 118 |
for i in range(B):
|
| 119 |
pi = int(pos_ids[i].item())
|
| 120 |
kout[i, :, pi, :] = k[i, :, 0, :]
|
| 121 |
vout[i, :, pi, :] = v[i, :, 0, :]
|
| 122 |
return kout, vout
|
| 123 |
-
|
| 124 |
-
# C
|
| 125 |
if pos_ids.dim() == 0 and q_len == 1:
|
| 126 |
pi = int(pos_ids.item())
|
| 127 |
kout[:, :, pi, :] = k[:, :, 0, :]
|
| 128 |
vout[:, :, pi, :] = v[:, :, 0, :]
|
| 129 |
return kout, vout
|
| 130 |
-
|
| 131 |
-
raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
|
|
|
| 135 |
|
| 136 |
|
| 137 |
|
|
@@ -214,11 +207,12 @@ class MoondreamModel(nn.Module):
|
|
| 214 |
head_dim = c.dim // c.n_heads
|
| 215 |
for blk in self.text.blocks:
|
| 216 |
device = blk.kv_cache.k_cache.device
|
| 217 |
-
dtype
|
| 218 |
-
shape
|
| 219 |
blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 220 |
blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 221 |
|
|
|
|
| 222 |
|
| 223 |
|
| 224 |
def _setup_caches(self):
|
|
@@ -575,52 +569,41 @@ class MoondreamModel(nn.Module):
|
|
| 575 |
image: Union[Image.Image, EncodedImage],
|
| 576 |
settings: Optional[ImageEncodingSettings] = None,
|
| 577 |
) -> EncodedImage:
|
| 578 |
-
# Always start from single-row caches; avoids leftovers from batched runs
|
| 579 |
self._setup_caches()
|
| 580 |
-
|
| 581 |
-
if isinstance(image, EncodedImage):
|
| 582 |
-
return image
|
| 583 |
-
elif not isinstance(image, Image.Image):
|
| 584 |
-
raise ValueError("image must be a PIL Image or EncodedImage")
|
| 585 |
-
|
| 586 |
-
# Always start from single-row caches to avoid leftovers from batched runs
|
| 587 |
for blk in self.text.blocks:
|
| 588 |
if blk.kv_cache.k_cache.size(0) != 1:
|
| 589 |
blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
|
| 590 |
blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
|
| 591 |
-
|
| 592 |
-
|
| 593 |
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
|
|
|
|
|
|
| 599 |
|
| 600 |
with torch.inference_mode():
|
| 601 |
img_emb = self._run_vision_encoder(image)
|
| 602 |
-
bos_emb = text_encoder(
|
| 603 |
-
torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
|
| 604 |
-
)
|
| 605 |
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
|
| 609 |
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
| 610 |
-
|
| 611 |
-
|
| 612 |
|
| 613 |
return EncodedImage(
|
| 614 |
pos=inputs_embeds.size(1),
|
| 615 |
caches=[
|
| 616 |
(
|
| 617 |
-
b.kv_cache.k_cache[:, :, :
|
| 618 |
-
b.kv_cache.v_cache[:, :, :
|
| 619 |
)
|
| 620 |
for b in self.text.blocks
|
| 621 |
],
|
| 622 |
)
|
| 623 |
|
|
|
|
| 624 |
def query(
|
| 625 |
self,
|
| 626 |
image: Optional[Union[Image.Image, EncodedImage]] = None,
|
|
@@ -913,22 +896,18 @@ class MoondreamModel(nn.Module):
|
|
| 913 |
|
| 914 |
|
| 915 |
def _load_encoded_image_batched(self, encoded_image, batch_size: int):
|
| 916 |
-
"""
|
| 917 |
-
Clone single-image KV caches into a batch-B cache so we can decode B labels in parallel.
|
| 918 |
-
"""
|
| 919 |
for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
|
| 920 |
T = k.size(2)
|
| 921 |
-
# Allocate new [B, n_kv_heads, T_max, head_dim] caches if needed
|
| 922 |
if b.kv_cache.k_cache.size(0) != batch_size:
|
| 923 |
new_k = b.kv_cache.k_cache.new_zeros((batch_size,) + b.kv_cache.k_cache.shape[1:])
|
| 924 |
new_v = b.kv_cache.v_cache.new_zeros((batch_size,) + b.kv_cache.v_cache.shape[1:])
|
| 925 |
b.kv_cache.k_cache = new_k
|
| 926 |
b.kv_cache.v_cache = new_v
|
| 927 |
-
# Copy current prefix from the encoded image into all B rows
|
| 928 |
b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
|
| 929 |
b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
|
| 930 |
|
| 931 |
|
|
|
|
| 932 |
def _prefill_prompt_batched(self, labels, pos: int, lora=None,
|
| 933 |
temperature: float = 0.0, top_p: float = 0.0):
|
| 934 |
tpl = self.config.tokenizer.templates["detect"]
|
|
@@ -945,34 +924,35 @@ class MoondreamModel(nn.Module):
|
|
| 945 |
|
| 946 |
prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
|
| 947 |
for i, ids in enumerate(rows):
|
| 948 |
-
prompt_ids[i, :ids.numel()] = ids
|
| 949 |
|
| 950 |
-
prompt_emb = text_encoder(prompt_ids, self.text)
|
| 951 |
torch._dynamo.mark_dynamic(prompt_emb, 1)
|
| 952 |
|
| 953 |
-
base = self.attn_mask[:, :, pos:pos+T, :]
|
| 954 |
-
mask = base.expand(B, -1, -1, -1).contiguous()
|
| 955 |
-
|
| 956 |
pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
|
| 957 |
-
hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
|
| 958 |
-
logits_BTV = lm_head(hidden_BTC, self.text)
|
| 959 |
|
| 960 |
-
|
|
|
|
|
|
|
|
|
|
| 961 |
last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
|
| 962 |
last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
|
| 963 |
|
| 964 |
if temperature == 0.0:
|
| 965 |
-
next_token = last_logits.argmax(dim=-1, keepdim=True)
|
| 966 |
else:
|
| 967 |
probs = torch.softmax(last_logits / temperature, dim=-1)
|
| 968 |
probs = self._apply_top_p(probs, top_p)
|
| 969 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 970 |
|
| 971 |
-
pos_end = int(pos + T)
|
| 972 |
return last_hidden, next_token, pos_end
|
| 973 |
|
| 974 |
|
| 975 |
|
|
|
|
| 976 |
def _generate_points_batched(
|
| 977 |
self,
|
| 978 |
hidden, # (B,1,C)
|
|
@@ -989,11 +969,11 @@ class MoondreamModel(nn.Module):
|
|
| 989 |
eos_id = self.config.tokenizer.eos_id
|
| 990 |
max_ctx = self.config.text.max_context
|
| 991 |
|
| 992 |
-
# Normalize pos to a scalar int
|
| 993 |
if torch.is_tensor(pos):
|
| 994 |
pos = int(pos.max().item())
|
| 995 |
|
| 996 |
-
# 4-D mask: (B,
|
| 997 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 998 |
if pos > 0:
|
| 999 |
mask[:, :, :, :pos] = True
|
|
@@ -1004,32 +984,48 @@ class MoondreamModel(nn.Module):
|
|
| 1004 |
|
| 1005 |
def _argmax01(logits: torch.Tensor) -> torch.Tensor:
|
| 1006 |
"""
|
| 1007 |
-
logits: (..., bins) ->
|
| 1008 |
-
Accepts (B,1,bins), (B,bins), or (bins,)
|
| 1009 |
"""
|
| 1010 |
-
# Canonicalize to (B,
|
| 1011 |
-
if logits.dim() == 3:
|
| 1012 |
logits = logits.squeeze(1)
|
| 1013 |
-
elif logits.dim() == 1:
|
| 1014 |
logits = logits.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
| 1015 |
|
| 1016 |
if use_soft_argmax:
|
| 1017 |
probs = torch.softmax(logits, dim=-1)
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
expbin = (probs * bins_idx).sum(dim=-1)
|
| 1021 |
return expbin / float(probs.size(-1) - 1)
|
| 1022 |
else:
|
| 1023 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1024 |
return idx / float(logits.size(-1) - 1)
|
| 1025 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1026 |
with torch.inference_mode():
|
| 1027 |
while alive.any() and (counts < max_objects).any():
|
| 1028 |
# ---- x ------------------------------------------------------
|
| 1029 |
-
x_logits = decode_coordinate(hidden, self.region)
|
| 1030 |
-
x_center = _argmax01(x_logits)
|
| 1031 |
-
|
| 1032 |
-
|
|
|
|
| 1033 |
|
| 1034 |
mask[alive, :, :, pos] = True
|
| 1035 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
|
@@ -1037,9 +1033,10 @@ class MoondreamModel(nn.Module):
|
|
| 1037 |
pos += 1
|
| 1038 |
|
| 1039 |
# ---- y ------------------------------------------------------
|
| 1040 |
-
y_logits = decode_coordinate(hidden, self.region)
|
| 1041 |
-
y_center = _argmax01(y_logits)
|
| 1042 |
-
|
|
|
|
| 1043 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1044 |
|
| 1045 |
mask[alive, :, :, pos] = True
|
|
@@ -1049,22 +1046,24 @@ class MoondreamModel(nn.Module):
|
|
| 1049 |
|
| 1050 |
if include_size:
|
| 1051 |
# ---- size ----------------------------------------------
|
| 1052 |
-
size_logits = decode_size(hidden, self.region)
|
| 1053 |
w_logits, h_logits = size_logits
|
| 1054 |
|
| 1055 |
-
# Canonicalize to (B,
|
| 1056 |
if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
|
| 1057 |
if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
|
| 1058 |
if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
|
| 1059 |
if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
|
|
|
|
|
|
|
| 1060 |
|
| 1061 |
if use_soft_argmax:
|
| 1062 |
w_probs = torch.softmax(w_logits, dim=-1)
|
| 1063 |
h_probs = torch.softmax(h_logits, dim=-1)
|
| 1064 |
-
|
| 1065 |
-
|
| 1066 |
-
w_bin = (w_probs *
|
| 1067 |
-
h_bin = (h_probs *
|
| 1068 |
else:
|
| 1069 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1070 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
|
@@ -1075,8 +1074,11 @@ class MoondreamModel(nn.Module):
|
|
| 1075 |
w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
|
| 1076 |
h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
|
| 1077 |
|
| 1078 |
-
|
| 1079 |
-
|
|
|
|
|
|
|
|
|
|
| 1080 |
|
| 1081 |
# record boxes only for alive rows
|
| 1082 |
for i in range(B):
|
|
@@ -1097,11 +1099,11 @@ class MoondreamModel(nn.Module):
|
|
| 1097 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1098 |
pos_ids[alive, 0] += 1
|
| 1099 |
pos += 1
|
| 1100 |
-
next_tok = logits.argmax(dim=-1).squeeze(-1)
|
| 1101 |
else:
|
| 1102 |
for i in range(B):
|
| 1103 |
if alive[i]:
|
| 1104 |
-
out[i].append({"x": x_center[i]
|
| 1105 |
mask[alive, :, :, pos] = True
|
| 1106 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1107 |
pos_ids[alive, 0] += 1
|
|
@@ -1120,16 +1122,8 @@ class MoondreamModel(nn.Module):
|
|
| 1120 |
|
| 1121 |
|
| 1122 |
|
|
|
|
| 1123 |
def detect_multi(self, image, objects, settings=None):
|
| 1124 |
-
"""
|
| 1125 |
-
Parallel multi-label detection.
|
| 1126 |
-
Args:
|
| 1127 |
-
image: PIL.Image or EncodedImage
|
| 1128 |
-
objects: list[str], e.g. ["person", "car"]
|
| 1129 |
-
settings: Optional[ObjectSamplingSettings], honors "max_objects" and "variant"
|
| 1130 |
-
Returns:
|
| 1131 |
-
{"objects": {label: [box_dict, ...]}}
|
| 1132 |
-
"""
|
| 1133 |
if self.config.tokenizer.templates["detect"] is None:
|
| 1134 |
raise NotImplementedError("Model does not support object detection.")
|
| 1135 |
settings = settings or {}
|
|
@@ -1160,9 +1154,8 @@ class MoondreamModel(nn.Module):
|
|
| 1160 |
d["label"] = lab
|
| 1161 |
res[lab] = lst
|
| 1162 |
|
| 1163 |
-
# IMPORTANT: restore caches to B=1 so future calls
|
| 1164 |
self._reset_kv_caches(1)
|
| 1165 |
-
|
| 1166 |
return {"objects": res}
|
| 1167 |
|
| 1168 |
|
|
|
|
| 64 |
pos: int
|
| 65 |
caches: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 66 |
|
|
|
|
| 67 |
class KVCache(nn.Module):
|
|
|
|
| 68 |
def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
|
| 69 |
super().__init__()
|
| 70 |
cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
|
| 71 |
+
self.register_buffer("k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
|
| 72 |
+
self.register_buffer("v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def update(self, pos_ids, k, v):
|
| 75 |
"""
|
| 76 |
Supports:
|
| 77 |
• Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
|
| 78 |
• 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
|
| 79 |
+
• Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar
|
| 80 |
Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
|
| 81 |
"""
|
| 82 |
kout, vout = self.k_cache, self.v_cache
|
| 83 |
+
|
| 84 |
+
# Normalize pos_ids
|
| 85 |
if not torch.is_tensor(pos_ids):
|
| 86 |
pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
|
| 87 |
else:
|
| 88 |
pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
|
| 89 |
+
|
| 90 |
if k.dim() != 4 or v.dim() != 4:
|
| 91 |
+
raise RuntimeError(f"KV update expects 4D k,v. Got k={tuple(k.shape)} v={tuple(v.shape)}")
|
| 92 |
+
|
| 93 |
B, Hkv, q_len, D = k.shape
|
| 94 |
+
|
| 95 |
+
# Ensure cache batch matches B (expand-from-1 allowed)
|
| 96 |
if kout.size(0) != B:
|
| 97 |
if kout.size(0) == 1:
|
| 98 |
self.k_cache = kout.expand(B, -1, -1, -1).clone()
|
|
|
|
| 100 |
kout, vout = self.k_cache, self.v_cache
|
| 101 |
else:
|
| 102 |
raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
|
| 103 |
+
|
| 104 |
+
# Case A: PREFILL — vector of length q_len (same for all B rows)
|
| 105 |
if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
|
| 106 |
for i in range(B):
|
| 107 |
kout[i, :, pos_ids, :] = k[i]
|
| 108 |
vout[i, :, pos_ids, :] = v[i]
|
| 109 |
return kout, vout
|
| 110 |
+
|
| 111 |
+
# Case B: 1-STEP — q_len == 1 with (B,) or (B,1) per-row positions
|
| 112 |
+
if q_len == 1 and (pos_ids.numel() == B):
|
| 113 |
pos_ids = pos_ids.view(B)
|
| 114 |
for i in range(B):
|
| 115 |
pi = int(pos_ids[i].item())
|
| 116 |
kout[i, :, pi, :] = k[i, :, 0, :]
|
| 117 |
vout[i, :, pi, :] = v[i, :, 0, :]
|
| 118 |
return kout, vout
|
| 119 |
+
|
| 120 |
+
# Case C: scalar + 1-step
|
| 121 |
if pos_ids.dim() == 0 and q_len == 1:
|
| 122 |
pi = int(pos_ids.item())
|
| 123 |
kout[:, :, pi, :] = k[:, :, 0, :]
|
| 124 |
vout[:, :, pi, :] = v[:, :, 0, :]
|
| 125 |
return kout, vout
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
+
raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
|
| 128 |
|
| 129 |
|
| 130 |
|
|
|
|
| 207 |
head_dim = c.dim // c.n_heads
|
| 208 |
for blk in self.text.blocks:
|
| 209 |
device = blk.kv_cache.k_cache.device
|
| 210 |
+
dtype = blk.kv_cache.k_cache.dtype
|
| 211 |
+
shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
|
| 212 |
blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 213 |
blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
|
| 214 |
|
| 215 |
+
|
| 216 |
|
| 217 |
|
| 218 |
def _setup_caches(self):
|
|
|
|
| 569 |
image: Union[Image.Image, EncodedImage],
|
| 570 |
settings: Optional[ImageEncodingSettings] = None,
|
| 571 |
) -> EncodedImage:
|
| 572 |
+
# Always start from single-row caches; avoids leftovers from batched runs
|
| 573 |
self._setup_caches()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
for blk in self.text.blocks:
|
| 575 |
if blk.kv_cache.k_cache.size(0) != 1:
|
| 576 |
blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
|
| 577 |
blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
|
|
|
|
|
|
|
| 578 |
|
| 579 |
+
if isinstance(image, EncodedImage):
|
| 580 |
+
return image
|
| 581 |
+
if not isinstance(image, Image.Image):
|
| 582 |
+
raise ValueError("image must be a PIL Image or EncodedImage")
|
| 583 |
+
|
| 584 |
+
lora = (variant_state_dict(settings["variant"], device=self.device)
|
| 585 |
+
if settings is not None and "variant" in settings else None)
|
| 586 |
|
| 587 |
with torch.inference_mode():
|
| 588 |
img_emb = self._run_vision_encoder(image)
|
| 589 |
+
bos_emb = text_encoder(torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text)
|
|
|
|
|
|
|
| 590 |
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 591 |
+
mask = self.attn_mask[:, :, :inputs_embeds.size(1), :]
|
| 592 |
+
pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long, device=self.device)
|
|
|
|
| 593 |
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
|
|
|
|
|
|
| 594 |
|
| 595 |
return EncodedImage(
|
| 596 |
pos=inputs_embeds.size(1),
|
| 597 |
caches=[
|
| 598 |
(
|
| 599 |
+
b.kv_cache.k_cache[:, :, :inputs_embeds.size(1), :].clone(),
|
| 600 |
+
b.kv_cache.v_cache[:, :, :inputs_embeds.size(1), :].clone(),
|
| 601 |
)
|
| 602 |
for b in self.text.blocks
|
| 603 |
],
|
| 604 |
)
|
| 605 |
|
| 606 |
+
|
| 607 |
def query(
|
| 608 |
self,
|
| 609 |
image: Optional[Union[Image.Image, EncodedImage]] = None,
|
|
|
|
| 896 |
|
| 897 |
|
| 898 |
def _load_encoded_image_batched(self, encoded_image, batch_size: int):
|
|
|
|
|
|
|
|
|
|
| 899 |
for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
|
| 900 |
T = k.size(2)
|
|
|
|
| 901 |
if b.kv_cache.k_cache.size(0) != batch_size:
|
| 902 |
new_k = b.kv_cache.k_cache.new_zeros((batch_size,) + b.kv_cache.k_cache.shape[1:])
|
| 903 |
new_v = b.kv_cache.v_cache.new_zeros((batch_size,) + b.kv_cache.v_cache.shape[1:])
|
| 904 |
b.kv_cache.k_cache = new_k
|
| 905 |
b.kv_cache.v_cache = new_v
|
|
|
|
| 906 |
b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
|
| 907 |
b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
|
| 908 |
|
| 909 |
|
| 910 |
+
|
| 911 |
def _prefill_prompt_batched(self, labels, pos: int, lora=None,
|
| 912 |
temperature: float = 0.0, top_p: float = 0.0):
|
| 913 |
tpl = self.config.tokenizer.templates["detect"]
|
|
|
|
| 924 |
|
| 925 |
prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
|
| 926 |
for i, ids in enumerate(rows):
|
| 927 |
+
prompt_ids[i, : ids.numel()] = ids
|
| 928 |
|
| 929 |
+
prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
|
| 930 |
torch._dynamo.mark_dynamic(prompt_emb, 1)
|
| 931 |
|
| 932 |
+
base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,K)
|
| 933 |
+
mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
|
|
|
|
| 934 |
pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
|
|
|
|
|
|
|
| 935 |
|
| 936 |
+
hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
|
| 937 |
+
logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
|
| 938 |
+
|
| 939 |
+
idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
|
| 940 |
last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
|
| 941 |
last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
|
| 942 |
|
| 943 |
if temperature == 0.0:
|
| 944 |
+
next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
|
| 945 |
else:
|
| 946 |
probs = torch.softmax(last_logits / temperature, dim=-1)
|
| 947 |
probs = self._apply_top_p(probs, top_p)
|
| 948 |
+
next_token = torch.multinomial(probs, num_samples=1) # (B,1)
|
| 949 |
|
| 950 |
+
pos_end = int(pos + T) # shared next-slot
|
| 951 |
return last_hidden, next_token, pos_end
|
| 952 |
|
| 953 |
|
| 954 |
|
| 955 |
+
|
| 956 |
def _generate_points_batched(
|
| 957 |
self,
|
| 958 |
hidden, # (B,1,C)
|
|
|
|
| 969 |
eos_id = self.config.tokenizer.eos_id
|
| 970 |
max_ctx = self.config.text.max_context
|
| 971 |
|
| 972 |
+
# Normalize pos to a scalar int
|
| 973 |
if torch.is_tensor(pos):
|
| 974 |
pos = int(pos.max().item())
|
| 975 |
|
| 976 |
+
# 4-D mask: (B,1,1,K) and per-row pos ids (B,1)
|
| 977 |
mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
|
| 978 |
if pos > 0:
|
| 979 |
mask[:, :, :, :pos] = True
|
|
|
|
| 984 |
|
| 985 |
def _argmax01(logits: torch.Tensor) -> torch.Tensor:
|
| 986 |
"""
|
| 987 |
+
logits: (..., bins) -> (B,) in [0,1]
|
| 988 |
+
Accepts (B,1,bins), (B,bins), or (bins,)
|
| 989 |
"""
|
| 990 |
+
# Canonicalize to (B,bins)
|
| 991 |
+
if logits.dim() == 3: # (B,1,bins)
|
| 992 |
logits = logits.squeeze(1)
|
| 993 |
+
elif logits.dim() == 1: # (bins,)
|
| 994 |
logits = logits.unsqueeze(0)
|
| 995 |
+
# If batch accidentally collapsed to 1, expand to B so downstream indexing is safe.
|
| 996 |
+
if logits.size(0) == 1 and B > 1:
|
| 997 |
+
logits = logits.expand(B, -1)
|
| 998 |
|
| 999 |
if use_soft_argmax:
|
| 1000 |
probs = torch.softmax(logits, dim=-1)
|
| 1001 |
+
bins = torch.arange(probs.size(-1), device=probs.device, dtype=torch.float32)
|
| 1002 |
+
expbin = (probs * bins).sum(dim=-1)
|
|
|
|
| 1003 |
return expbin / float(probs.size(-1) - 1)
|
| 1004 |
else:
|
| 1005 |
idx = logits.argmax(dim=-1).to(torch.float32)
|
| 1006 |
return idx / float(logits.size(-1) - 1)
|
| 1007 |
|
| 1008 |
+
def _ensure_b(vec: torch.Tensor) -> torch.Tensor:
|
| 1009 |
+
"""
|
| 1010 |
+
Make sure 1D tensors are length-B for safe indexing.
|
| 1011 |
+
Accepts scalar/(), (1,), (B,), returns (B,)
|
| 1012 |
+
"""
|
| 1013 |
+
if vec.dim() == 0:
|
| 1014 |
+
return vec.repeat(B)
|
| 1015 |
+
if vec.dim() == 1 and vec.numel() == 1 and B > 1:
|
| 1016 |
+
return vec.repeat(B)
|
| 1017 |
+
if vec.dim() == 1 and vec.numel() == B:
|
| 1018 |
+
return vec
|
| 1019 |
+
raise RuntimeError(f"Expected (B,) vec, got shape {tuple(vec.shape)} for B={B}")
|
| 1020 |
+
|
| 1021 |
with torch.inference_mode():
|
| 1022 |
while alive.any() and (counts < max_objects).any():
|
| 1023 |
# ---- x ------------------------------------------------------
|
| 1024 |
+
x_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b) or (b,)
|
| 1025 |
+
x_center = _argmax01(x_logits) # (B,)
|
| 1026 |
+
x_center = _ensure_b(x_center) # force len B
|
| 1027 |
+
x_in = x_center.to(dtype=hidden.dtype).unsqueeze(-1) # (B,1)
|
| 1028 |
+
x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
|
| 1029 |
|
| 1030 |
mask[alive, :, :, pos] = True
|
| 1031 |
logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
|
|
|
|
| 1033 |
pos += 1
|
| 1034 |
|
| 1035 |
# ---- y ------------------------------------------------------
|
| 1036 |
+
y_logits = decode_coordinate(hidden, self.region)
|
| 1037 |
+
y_center = _argmax01(y_logits) # (B,)
|
| 1038 |
+
y_center = _ensure_b(y_center)
|
| 1039 |
+
y_in = y_center.to(dtype=hidden.dtype).unsqueeze(-1)
|
| 1040 |
y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
|
| 1041 |
|
| 1042 |
mask[alive, :, :, pos] = True
|
|
|
|
| 1046 |
|
| 1047 |
if include_size:
|
| 1048 |
# ---- size ----------------------------------------------
|
| 1049 |
+
size_logits = decode_size(hidden, self.region) # tuple: (w_logits, h_logits)
|
| 1050 |
w_logits, h_logits = size_logits
|
| 1051 |
|
| 1052 |
+
# Canonicalize to (B,bins); expand if batch collapsed
|
| 1053 |
if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
|
| 1054 |
if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
|
| 1055 |
if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
|
| 1056 |
if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
|
| 1057 |
+
if w_logits.size(0) == 1 and B > 1: w_logits = w_logits.expand(B, -1)
|
| 1058 |
+
if h_logits.size(0) == 1 and B > 1: h_logits = h_logits.expand(B, -1)
|
| 1059 |
|
| 1060 |
if use_soft_argmax:
|
| 1061 |
w_probs = torch.softmax(w_logits, dim=-1)
|
| 1062 |
h_probs = torch.softmax(h_logits, dim=-1)
|
| 1063 |
+
bins_w = torch.arange(w_probs.size(-1), device=device, dtype=torch.float32)
|
| 1064 |
+
bins_h = torch.arange(h_probs.size(-1), device=device, dtype=torch.float32)
|
| 1065 |
+
w_bin = (w_probs * bins_w).sum(dim=-1) # (B,)
|
| 1066 |
+
h_bin = (h_probs * bins_h).sum(dim=-1) # (B,)
|
| 1067 |
else:
|
| 1068 |
w_bin = w_logits.argmax(dim=-1).to(torch.float32)
|
| 1069 |
h_bin = h_logits.argmax(dim=-1).to(torch.float32)
|
|
|
|
| 1074 |
w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
|
| 1075 |
h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
|
| 1076 |
|
| 1077 |
+
# enforce (B,)
|
| 1078 |
+
w = _ensure_b(w); h = _ensure_b(h)
|
| 1079 |
+
|
| 1080 |
+
size_in = torch.stack([w, h], dim=1).to(dtype=hidden.dtype) # (B,2)
|
| 1081 |
+
size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
|
| 1082 |
|
| 1083 |
# record boxes only for alive rows
|
| 1084 |
for i in range(B):
|
|
|
|
| 1099 |
logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
|
| 1100 |
pos_ids[alive, 0] += 1
|
| 1101 |
pos += 1
|
| 1102 |
+
next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
|
| 1103 |
else:
|
| 1104 |
for i in range(B):
|
| 1105 |
if alive[i]:
|
| 1106 |
+
out[i].append({"x": float(x_center[i]), "y": float(y_center[i])})
|
| 1107 |
mask[alive, :, :, pos] = True
|
| 1108 |
logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
|
| 1109 |
pos_ids[alive, 0] += 1
|
|
|
|
| 1122 |
|
| 1123 |
|
| 1124 |
|
| 1125 |
+
|
| 1126 |
def detect_multi(self, image, objects, settings=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1127 |
if self.config.tokenizer.templates["detect"] is None:
|
| 1128 |
raise NotImplementedError("Model does not support object detection.")
|
| 1129 |
settings = settings or {}
|
|
|
|
| 1154 |
d["label"] = lab
|
| 1155 |
res[lab] = lst
|
| 1156 |
|
| 1157 |
+
# IMPORTANT: restore caches to B=1 so future calls are safe
|
| 1158 |
self._reset_kv_caches(1)
|
|
|
|
| 1159 |
return {"objects": res}
|
| 1160 |
|
| 1161 |
|