HV-Khurdula committed on
Commit
c7c8764
·
verified ·
1 Parent(s): 99ac5b4

Update moondream.py

Browse files

feat: add a warning to the code; please do not touch the cache_reset.

Files changed (1) hide show
  1. moondream.py +40 -43
moondream.py CHANGED
@@ -249,49 +249,6 @@ class MoondreamModel(nn.Module):
249
 
250
  return self._vis_proj(global_features, reconstructed)
251
 
252
-
253
- def encode_image(
254
- self,
255
- image: Union[Image.Image, EncodedImage],
256
- settings: Optional[ImageEncodingSettings] = None,
257
- ) -> EncodedImage:
258
- # Always start from single-row caches; avoids leftovers from batched runs.
259
- self._setup_caches()
260
-
261
- if isinstance(image, EncodedImage):
262
- return image
263
- elif not isinstance(image, Image.Image):
264
- raise ValueError("image must be a PIL Image or EncodedImage")
265
-
266
- lora = (
267
- variant_state_dict(settings["variant"], device=self.device)
268
- if settings is not None and "variant" in settings
269
- else None
270
- )
271
-
272
- with torch.inference_mode():
273
- img_emb = self._run_vision_encoder(image)
274
- bos_emb = text_encoder(
275
- torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
276
- )
277
- inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
278
- mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
279
- pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
280
- self._prefill(inputs_embeds, mask, pos_ids, lora)
281
-
282
- return EncodedImage(
283
- pos=inputs_embeds.size(1),
284
- caches=[
285
- (
286
- b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
287
- b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
288
- )
289
- for b in self.text.blocks
290
- ],
291
- )
292
-
293
-
294
-
295
  def _apply_top_p(self, probs: torch.Tensor, top_p: float):
296
  probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
297
  probs_sum = torch.cumsum(probs_sort, dim=-1)
@@ -563,6 +520,46 @@ class MoondreamModel(nn.Module):
563
 
564
  return generator(next_token, pos)
565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  def query(
567
  self,
568
  image: Optional[Union[Image.Image, EncodedImage]] = None,
 
249
 
250
  return self._vis_proj(global_features, reconstructed)
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  def _apply_top_p(self, probs: torch.Tensor, top_p: float):
253
  probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
254
  probs_sum = torch.cumsum(probs_sort, dim=-1)
 
520
 
521
  return generator(next_token, pos)
522
 
523
def encode_image(
    self,
    image: Union[Image.Image, EncodedImage],
    settings: Optional[ImageEncodingSettings] = None,
) -> EncodedImage:
    """Encode an image into reusable KV-cache state.

    Args:
        image: A PIL image to encode. An already-encoded ``EncodedImage``
            is returned unchanged.
        settings: Optional encoding settings; when a ``"variant"`` key is
            present, the corresponding LoRA state dict is loaded and passed
            to the prefill step.

    Returns:
        EncodedImage carrying the prefill length (``pos``) and a cloned
        per-block snapshot of the k/v caches.

    Raises:
        ValueError: If ``image`` is neither a PIL Image nor an EncodedImage.
    """
    # Always start from single-row caches; avoids leftovers from batched
    # runs. DO NOT TOUCH THIS!!!!!!!!!
    self._setup_caches()

    if isinstance(image, EncodedImage):
        return image
    if not isinstance(image, Image.Image):
        raise ValueError("image must be a PIL Image or EncodedImage")

    # Resolve an optional LoRA variant from the settings, if one was given.
    lora = None
    if settings is not None and "variant" in settings:
        lora = variant_state_dict(settings["variant"], device=self.device)

    with torch.inference_mode():
        # Sequence layout: [BOS token] followed by the image embeddings.
        bos_token = torch.tensor(
            [[self.config.tokenizer.bos_id]], device=self.device
        )
        bos_emb = text_encoder(bos_token, self.text)
        img_emb = self._run_vision_encoder(image)
        inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)

        seq_len = inputs_embeds.size(1)
        mask = self.attn_mask[:, :, 0:seq_len, :]
        pos_ids = torch.arange(seq_len, dtype=torch.long)
        self._prefill(inputs_embeds, mask, pos_ids, lora)

    # Clone the prefix of every block's cache so the snapshot survives
    # later generations that overwrite the live caches.
    cached_kv = [
        (
            block.kv_cache.k_cache[:, :, :seq_len, :].clone(),
            block.kv_cache.v_cache[:, :, :seq_len, :].clone(),
        )
        for block in self.text.blocks
    ]
    return EncodedImage(pos=seq_len, caches=cached_kv)
562
+
563
  def query(
564
  self,
565
  image: Optional[Union[Image.Image, EncodedImage]] = None,