Jooju2872 commited on
Commit
c3c60cc
·
verified ·
1 Parent(s): a5ea795

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. moondream.py +54 -2
moondream.py CHANGED
@@ -216,6 +216,7 @@ class MoondreamModel(nn.Module):
216
  def _prefill_prompt(
217
  self, kv_cache: torch.Tensor, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
218
  ):
 
219
  with torch.no_grad():
220
  prompt_emb = text_encoder(prompt_tokens, self.text)
221
  hidden = self.ops["prefill"](
@@ -257,15 +258,40 @@ class MoondreamModel(nn.Module):
257
  )
258
 
259
  def generator(next_token, pos):
 
 
 
260
  generated_tokens = 0
261
 
 
 
 
262
  while (
263
  next_token_id := next_token.item()
264
  ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
265
- yield self.tokenizer.decode([next_token_id])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  with torch.no_grad():
268
  next_emb = text_encoder(next_token, self.text)
 
269
  logits, _, kv_cache_update = self.ops["decode_one_token"](
270
  next_emb, kv_cache, pos, self.text, self.config.text
271
  )
@@ -273,8 +299,22 @@ class MoondreamModel(nn.Module):
273
  kv_cache_update
274
  )
275
  pos += 1
276
- next_token = torch.argmax(logits, dim=-1)
 
 
 
 
 
 
 
277
  generated_tokens += 1
 
 
 
 
 
 
 
278
 
279
  return generator(next_token, pos)
280
 
@@ -617,3 +657,15 @@ class MoondreamModel(nn.Module):
617
  )
618
 
619
  return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def _prefill_prompt(
217
  self, kv_cache: torch.Tensor, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
218
  ):
219
+
220
  with torch.no_grad():
221
  prompt_emb = text_encoder(prompt_tokens, self.text)
222
  hidden = self.ops["prefill"](
 
258
  )
259
 
260
  def generator(next_token, pos):
261
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
262
+ mask[:, :, :pos] = 1
263
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
264
  generated_tokens = 0
265
 
266
+ token_cache = []
267
+ print_len = 0
268
+
269
  while (
270
  next_token_id := next_token.item()
271
  ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
272
+ # Add token to our cache
273
+ token_cache.append(next_token_id)
274
+
275
+ # Decode all tokens collected so far
276
+ text = self.tokenizer.decode(token_cache)
277
+
278
+ # After a newline, we flush the cache completely
279
+ if text.endswith("\n"):
280
+ printable_text = text[print_len:]
281
+ token_cache = []
282
+ print_len = 0
283
+ if printable_text:
284
+ yield printable_text
285
+ # If the last token is a CJK character, we can safely print it
286
+ elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
287
+ printable_text = text[print_len:]
288
+ print_len += len(printable_text)
289
+ if printable_text:
290
+ yield printable_text
291
 
292
  with torch.no_grad():
293
  next_emb = text_encoder(next_token, self.text)
294
+ mask[:, :, pos], pos_ids[0] = 1, pos
295
  logits, _, kv_cache_update = self.ops["decode_one_token"](
296
  next_emb, kv_cache, pos, self.text, self.config.text
297
  )
 
299
  kv_cache_update
300
  )
301
  pos += 1
302
+
303
+ if temperature == 0:
304
+ next_token = torch.argmax(logits, dim=-1) # (1, 1)
305
+ else:
306
+ probs = torch.softmax(logits / temperature, dim=-1) # (1, V)
307
+ probs = self._apply_top_p(probs, top_p)
308
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1, 1)
309
+
310
  generated_tokens += 1
311
+
312
+ # Flush any remaining text in the cache
313
+ if token_cache:
314
+ text = self.tokenizer.decode(token_cache)
315
+ printable_text = text[print_len:]
316
+ if printable_text:
317
+ yield printable_text
318
 
319
  return generator(next_token, pos)
320
 
 
657
  )
658
 
659
  return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
660
+
661
+ def _is_cjk_char(cp):
662
+ """Checks whether CP is the codepoint of a CJK character."""
663
+ # This defines a "chinese character" as anything in the CJK Unicode block:
664
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
665
+ if (
666
+ (cp >= 0x4E00 and cp <= 0x9FFF)
667
+ or (cp >= 0x3400 and cp <= 0x4DBF)
668
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
669
+ ):
670
+ return True
671
+ return False