Update moondream.py
feat: adding warning to code, please do not touch the cache_reset.
- moondream.py +40 -43
moondream.py
CHANGED
|
@@ -249,49 +249,6 @@ class MoondreamModel(nn.Module):
|
|
| 249 |
|
| 250 |
return self._vis_proj(global_features, reconstructed)
|
| 251 |
|
| 252 |
-
|
| 253 |
-
def encode_image(
|
| 254 |
-
self,
|
| 255 |
-
image: Union[Image.Image, EncodedImage],
|
| 256 |
-
settings: Optional[ImageEncodingSettings] = None,
|
| 257 |
-
) -> EncodedImage:
|
| 258 |
-
# Always start from single-row caches; avoids leftovers from batched runs.
|
| 259 |
-
self._setup_caches()
|
| 260 |
-
|
| 261 |
-
if isinstance(image, EncodedImage):
|
| 262 |
-
return image
|
| 263 |
-
elif not isinstance(image, Image.Image):
|
| 264 |
-
raise ValueError("image must be a PIL Image or EncodedImage")
|
| 265 |
-
|
| 266 |
-
lora = (
|
| 267 |
-
variant_state_dict(settings["variant"], device=self.device)
|
| 268 |
-
if settings is not None and "variant" in settings
|
| 269 |
-
else None
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
with torch.inference_mode():
|
| 273 |
-
img_emb = self._run_vision_encoder(image)
|
| 274 |
-
bos_emb = text_encoder(
|
| 275 |
-
torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
|
| 276 |
-
)
|
| 277 |
-
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 278 |
-
mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
|
| 279 |
-
pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
|
| 280 |
-
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
| 281 |
-
|
| 282 |
-
return EncodedImage(
|
| 283 |
-
pos=inputs_embeds.size(1),
|
| 284 |
-
caches=[
|
| 285 |
-
(
|
| 286 |
-
b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
|
| 287 |
-
b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
|
| 288 |
-
)
|
| 289 |
-
for b in self.text.blocks
|
| 290 |
-
],
|
| 291 |
-
)
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
def _apply_top_p(self, probs: torch.Tensor, top_p: float):
|
| 296 |
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 297 |
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
|
@@ -563,6 +520,46 @@ class MoondreamModel(nn.Module):
|
|
| 563 |
|
| 564 |
return generator(next_token, pos)
|
| 565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
def query(
|
| 567 |
self,
|
| 568 |
image: Optional[Union[Image.Image, EncodedImage]] = None,
|
|
|
|
| 249 |
|
| 250 |
return self._vis_proj(global_features, reconstructed)
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
def _apply_top_p(self, probs: torch.Tensor, top_p: float):
|
| 253 |
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 254 |
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
|
|
|
| 520 |
|
| 521 |
return generator(next_token, pos)
|
| 522 |
|
| 523 |
+
def encode_image(
    self,
    image: Union[Image.Image, EncodedImage],
    settings: Optional[ImageEncodingSettings] = None,
) -> EncodedImage:
    """Encode an image into per-block KV-cache snapshots for later generation.

    Args:
        image: A PIL image to encode, or an already-encoded ``EncodedImage``
            (returned unchanged as a fast path).
        settings: Optional encoding settings; if a ``"variant"`` key is
            present, its LoRA state dict is loaded and applied during prefill.

    Returns:
        An ``EncodedImage`` holding the sequence length (``pos``) and a cloned
        (k_cache, v_cache) prefix for every text block.

    Raises:
        ValueError: If ``image`` is neither a PIL Image nor an EncodedImage.
    """
    # Always start from single-row caches; avoids leftovers from batched runs. DO NOT TOUCH THIS!!!!!!!!!
    self._setup_caches()

    if isinstance(image, EncodedImage):
        # Already encoded — nothing to do.
        return image
    elif not isinstance(image, Image.Image):
        raise ValueError("image must be a PIL Image or EncodedImage")

    # Optional LoRA variant weights; loaded only when the caller requested one.
    lora = (
        variant_state_dict(settings["variant"], device=self.device)
        if settings is not None and "variant" in settings
        else None
    )

    with torch.inference_mode():
        img_emb = self._run_vision_encoder(image)
        bos_emb = text_encoder(
            torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
        )
        # Sequence layout is [BOS, image tokens]; prefill fills each block's
        # KV cache for this prefix.
        inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
        mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
        # NOTE(review): pos_ids is created without an explicit device —
        # presumably _prefill moves it to self.device; confirm on CUDA.
        pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
        self._prefill(inputs_embeds, mask, pos_ids, lora)

    # Clone the cache prefix so the encoded image survives subsequent
    # cache resets / other prefills on this model instance.
    return EncodedImage(
        pos=inputs_embeds.size(1),
        caches=[
            (
                b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
                b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
            )
            for b in self.text.blocks
        ],
    )
|
| 562 |
+
|
| 563 |
def query(
|
| 564 |
self,
|
| 565 |
image: Optional[Union[Image.Image, EncodedImage]] = None,
|