Upload folder using huggingface_hub
moondream.py +54 -2
moondream.py CHANGED
@@ -216,6 +216,7 @@ class MoondreamModel(nn.Module):
     def _prefill_prompt(
         self, kv_cache: torch.Tensor, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
     ):
+
         with torch.no_grad():
             prompt_emb = text_encoder(prompt_tokens, self.text)
             hidden = self.ops["prefill"](
@@ -257,15 +258,40 @@ class MoondreamModel(nn.Module):
         )
 
         def generator(next_token, pos):
+            mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+            mask[:, :, :pos] = 1
+            pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
             generated_tokens = 0
 
+            token_cache = []
+            print_len = 0
+
             while (
                 next_token_id := next_token.item()
             ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
-
+                # Add token to our cache
+                token_cache.append(next_token_id)
+
+                # Decode all tokens collected so far
+                text = self.tokenizer.decode(token_cache)
+
+                # After a newline, we flush the cache completely
+                if text.endswith("\n"):
+                    printable_text = text[print_len:]
+                    token_cache = []
+                    print_len = 0
+                    if printable_text:
+                        yield printable_text
+                # If the last token is a CJK character, we can safely print it
+                elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
+                    printable_text = text[print_len:]
+                    print_len += len(printable_text)
+                    if printable_text:
+                        yield printable_text
 
                 with torch.no_grad():
                     next_emb = text_encoder(next_token, self.text)
+                    mask[:, :, pos], pos_ids[0] = 1, pos
                     logits, _, kv_cache_update = self.ops["decode_one_token"](
                         next_emb, kv_cache, pos, self.text, self.config.text
                     )
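Why does the hunk above buffer token ids in token_cache instead of yielding each token as it is decoded? With byte-level BPE vocabularies a single token can end in the middle of a multi-byte UTF-8 character, so decoding tokens one at a time can emit U+FFFD replacement characters. The hunk instead decodes the whole cache every iteration and slices off only the prefix that was already emitted. Below is a minimal self-contained sketch of the same cache-and-slice idea; StubTokenizer is an assumption for illustration, not Moondream's tokenizer, and it flushes whenever the tail decodes cleanly rather than on the hunk's newline/CJK boundaries.

class StubTokenizer:
    # Stand-in byte-level tokenizer: each "token" is one UTF-8 byte.
    def decode(self, token_ids):
        return bytes(token_ids).decode("utf-8", errors="replace")

tok = StubTokenizer()
tokens = list("né".encode("utf-8"))  # [110, 195, 169]; 'é' spans two tokens

# Naive per-token decoding garbles the multi-byte character:
print("".join(tok.decode([t]) for t in tokens))
# prints "n" followed by two U+FFFD replacement characters

# Decoding the growing cache and slicing off what was already emitted does not:
cache, printed = [], 0
for t in tokens:
    cache.append(t)
    text = tok.decode(cache)
    if not text.endswith("\ufffd"):  # flush only once the tail decodes cleanly
        print(text[printed:], end="")
        printed = len(text)
# prints "né"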
@@ -273,8 +299,22 @@ class MoondreamModel(nn.Module):
                     kv_cache_update
                 )
                 pos += 1
-
+
+                if temperature == 0:
+                    next_token = torch.argmax(logits, dim=-1)  # (1, 1)
+                else:
+                    probs = torch.softmax(logits / temperature, dim=-1)  # (1, V)
+                    probs = self._apply_top_p(probs, top_p)
+                    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)  # (1, 1)
+
                 generated_tokens += 1
 
+            # Flush any remaining text in the cache
+            if token_cache:
+                text = self.tokenizer.decode(token_cache)
+                printable_text = text[print_len:]
+                if printable_text:
+                    yield printable_text
 
         return generator(next_token, pos)
 
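The sampling branch above calls self._apply_top_p, which this diff does not show. A common nucleus-sampling filter that it could plausibly resemble, written here as a free function (a sketch under that assumption, not the repository's actual implementation):

import torch

def _apply_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort probabilities in descending order and accumulate mass.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens outside the smallest set whose mass exceeds top_p;
    # subtracting sorted_probs keeps at least the most likely token.
    sorted_probs[cum_probs - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)  # renormalize
    # Scatter the filtered probabilities back to vocabulary order.
    return torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)

Keeping at least the top token (via the cum_probs - sorted_probs shift) guarantees torch.multinomial always has nonzero mass to sample from.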
@@ -617,3 +657,15 @@ class MoondreamModel(nn.Module):
         )
 
         return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
+
+def _is_cjk_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    if (
+        (cp >= 0x4E00 and cp <= 0x9FFF)
+        or (cp >= 0x3400 and cp <= 0x4DBF)
+        or (cp >= 0x2F800 and cp <= 0x2FA1F)
+    ):
+        return True
+    return False
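Taken together, the commit turns generation into a stream of display-safe chunks rather than one final string. A usage sketch, assuming the model's public entry point exposes a stream flag that returns this generator (the query call and its signature are illustrative, not confirmed by this diff):

# Hypothetical call; "query" and "stream" are assumptions about the public API.
for chunk in model.query(image, "Describe this image.", stream=True)["answer"]:
    # Each chunk is final text that later tokens cannot retroactively change,
    # so it can be written straight to a terminal.
    print(chunk, end="", flush=True)
print()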
|