# moondream / moondream2 / moondream.py
# johnmalek312
# broken change start of batching
# ded605e
import torch
import torch.nn as nn
import random
from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
from PIL import Image
from dataclasses import dataclass
from tokenizers import Tokenizer
from .config import MoondreamConfig
from .image_crops import reconstruct_from_crops
from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
from .text import build_text_model, text_encoder, lm_head, text_decoder
from .region import decode_coordinate, encode_coordinate
import os
from .rope import RotaryEmbedding
class TextSamplingSettings(TypedDict, total=False):
    """Optional overrides for text sampling; every key may be omitted."""

    max_tokens: int
    temperature: float
    top_p: float


class ObjectSamplingSettings(TypedDict, total=False):
    """Optional overrides for object/point generation; every key may be omitted."""

    max_objects: int


# Fallback values used when a settings dict does not supply the key.
DEFAULT_MAX_TOKENS = 768
DEFAULT_TEMPERATURE = 0.5
DEFAULT_TOP_P = 0.3
DEFAULT_MAX_OBJECTS = 50
@dataclass(frozen=True)
class EncodedImage:
pos: int
caches: List[Tuple[torch.Tensor, torch.Tensor]]
class KVCache(nn.Module):
    """Fixed-size key/value cache for a single attention block.

    Both buffers have shape ``(batch_size, n_kv_heads, max_context, dim // n_heads)``
    and start out zero-filled.
    """

    def __init__(
        self,
        n_heads,
        n_kv_heads,
        max_context,
        dim,
        batch_size: int = 1,
        device=None,
        dtype=None,
    ):
        super().__init__()
        head_dim = dim // n_heads
        shape = (batch_size, n_kv_heads, max_context, head_dim)
        # Each buffer gets its own zero tensor; registration order (k, then v) is kept.
        for buf_name in ("k_cache", "v_cache"):
            self.register_buffer(
                buf_name, torch.zeros(*shape, device=device, dtype=dtype)
            )

    def update(self, pos_ids, k, v):
        """Write ``k`` / ``v`` into the cache at sequence positions ``pos_ids`` (dim 2).

        Returns the full ``(k_cache, v_cache)`` buffers, not just the written slice.
        """
        self.k_cache[:, :, pos_ids, :] = k
        self.v_cache[:, :, pos_ids, :] = v
        return self.k_cache, self.v_cache
class MoondreamModel(nn.Module):
    """Moondream vision-language model: vision encoder, text decoder and region heads.

    Generation state lives in the per-block ``KVCache`` modules attached by
    ``_setup_caches``.  ``encode_image`` prefills the image prefix into those
    caches; ``point`` then writes prompt / decoded tokens after that prefix.
    """

    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
        super().__init__()
        self.config = config

        # tokenizer.json is loaded from the directory containing this module.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.tokenizer = Tokenizer.from_file(os.path.join(current_dir, "tokenizer.json"))
        self.vision = build_vision_model(config.vision, dtype)
        self.text = build_text_model(config.text, dtype)
        self.rope = RotaryEmbedding(config.text.dim // config.text.n_heads, config.text.max_context)

        # Region Model
        self.region = nn.ModuleDict(
            {
                "coord_encoder": nn.Linear(
                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
                ),
                "coord_decoder": nn.ModuleDict(
                    {
                        "fc1": nn.Linear(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
                        "fc2": nn.Linear(
                            config.region.inner_dim,
                            config.region.coord_out_dim,
                            dtype=dtype,
                        ),
                    }
                ),
                "size_encoder": nn.Linear(
                    config.region.size_feat_dim, config.region.dim, dtype=dtype
                ),
                "size_decoder": nn.ModuleDict(
                    {
                        "fc1": nn.Linear(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
                        "fc2": nn.Linear(
                            config.region.inner_dim,
                            config.region.size_out_dim,
                            dtype=dtype,
                        ),
                    }
                ),
            }
        )
        # Created uninitialized (torch.empty); presumably overwritten when a
        # checkpoint is loaded -- TODO confirm.
        self.region.coord_features = nn.Parameter(
            torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
        )
        self.region.size_features = nn.Parameter(
            torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
        )

        # Causal mask over the full context, except that the first
        # 1 + patch_w**2 positions (BOS followed by the image prefix; presumably
        # one position per vision patch) attend to each other bidirectionally.
        attn_mask = torch.tril(
            torch.ones(
                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
            )
        )
        patch_w = config.vision.crop_size // config.vision.enc_patch_size
        prefix_attn_len = 1 + patch_w**2
        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        # Initialize KV caches.
        if setup_caches:
            self._setup_caches()

    def _setup_caches(self, batch_size: int = 2):
        """Attach a fresh zero-filled ``KVCache`` to every text block.

        Args:
            batch_size: leading dimension of each cache.  The default of 2 is
                kept from the original code.  NOTE(review): ``point`` and
                ``_generate_points`` drive a single sequence (``.item()`` on the
                next token), so confirm ``text_decoder`` copes with a cache batch
                larger than the input batch -- TODO confirm.
        """
        c = self.config.text
        for b in self.text.blocks:
            b.kv_cache = KVCache(
                c.n_heads,
                c.n_kv_heads,
                c.max_context,
                c.dim,
                batch_size=batch_size,
                device=self.device,
                dtype=self.vision.pos_emb.dtype,
            )

    def load_encoded_image(self, encoded_image: EncodedImage):
        """Copy an ``EncodedImage``'s per-block (k, v) snapshots back into the KV caches.

        Only the first ``k.size(2)`` sequence positions of each cache are
        overwritten; later positions are left untouched.
        """
        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
            b.kv_cache.v_cache[:, :, : v.size(2), :] = v

    @property
    def device(self):
        """Device of the model, taken from the vision positional embedding."""
        return self.vision.pos_emb.device

    def _vis_enc(self, x: torch.Tensor):
        return vision_encoder(x, self.vision, self.config.vision)

    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
        return vision_projection(g, r, self.vision, self.config.vision)

    def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
        """Run the text decoder over ``x`` (writing KV caches) and return hidden states."""
        return text_decoder(x, self.text, attn_mask, self.config.text, self.rope, pos_ids)

    def _decode_one_tok(
        self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
    ):
        """Decode a step and return ``(logits, hidden)``."""
        hidden = text_decoder(x, self.text, attn_mask, self.config.text, self.rope, pos_ids)
        logits = lm_head(hidden, self.text)
        return logits, hidden

    def compile(self):
        """Replace the hot paths with ``torch.compile``d versions."""
        # TODO: vision_projection is not being compiled
        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
        self._prefill = torch.compile(self._prefill, fullgraph=True)
        self._decode_one_tok = torch.compile(
            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
        )

    def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
        """Encode an image's crops and project global + reconstructed local features."""
        all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
        # The number of crops varies per image.
        torch._dynamo.mark_dynamic(all_crops, 0)
        outputs = self._vis_enc(all_crops)

        # Crop 0 is used as the global view; the rest are local tiles.
        global_features = outputs[0]
        # Each local crop is reshaped onto its square patch grid.  The grid width
        # is crop_size // enc_patch_size (the same patch_w used for attn_mask);
        # the layer count (enc_n_layers) only matched it by coincidence.
        grid_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
        local_features = outputs[1:].view(
            -1,
            grid_w,
            grid_w,
            self.config.vision.enc_dim,
        )
        reconstructed = reconstruct_from_crops(
            local_features,
            tiling,
            patch_size=1,
            overlap_margin=self.config.vision.overlap_margin,
        )
        return self._vis_proj(global_features, reconstructed)

    def encode_image(self, image: Union[Image.Image, EncodedImage, torch.Tensor]) -> EncodedImage:
        """Prefill BOS + image embeddings into the KV caches and snapshot them.

        Args:
            image: a PIL image (run through the vision encoder), a tensor of
                precomputed image embeddings, or an ``EncodedImage`` (returned
                unchanged).

        Returns:
            An ``EncodedImage`` whose ``pos`` is the prefix length and whose
            ``caches`` hold a clone of each block's prefix (k, v), so the
            image can be restored later with ``load_encoded_image``.

        Raises:
            ValueError: if ``image`` is none of the supported types.
        """
        if isinstance(image, EncodedImage):
            return image
        if not isinstance(image, (Image.Image, torch.Tensor)):
            raise ValueError("image must be a PIL Image, torch.Tensor, or EncodedImage")

        # Run through text model in addition to the vision encoder, to minimize
        # re-computation if multiple queries are performed on this image.
        with torch.inference_mode():
            if isinstance(image, Image.Image):
                img_emb = self._run_vision_encoder(image)
            else:
                img_emb = image
            # bos_emb below is rank 3 and torch.cat needs equal ranks, so add a
            # batch dim when the embeddings arrive without one.
            if img_emb.dim() == 2:
                img_emb = img_emb.unsqueeze(0)

            bos = torch.tensor([[self.config.tokenizer.bos_id]], device=self.device)
            bos_emb = text_encoder(bos, self.text)
            bos_emb = bos_emb.expand(img_emb.size(0), -1, -1)
            inputs_embeds = torch.cat([bos_emb, img_emb], dim=1)

            prefix_len = inputs_embeds.size(1)
            mask = self.attn_mask[:, :, 0:prefix_len, :]
            pos_ids = torch.arange(prefix_len, dtype=torch.int32, device=self.device)
            self._prefill(inputs_embeds, mask, pos_ids)

            # Snapshot the prefix so this image can be reloaded after another
            # image has overwritten the caches.
            caches = [
                (
                    b.kv_cache.k_cache[:, :, :prefix_len, :].clone(),
                    b.kv_cache.v_cache[:, :, :prefix_len, :].clone(),
                )
                for b in self.text.blocks
            ]

        return EncodedImage(pos=prefix_len, caches=caches)

    def _apply_top_p(self, probs: torch.Tensor, top_p: float):
        """Nucleus filtering over the last dim.

        A token is dropped when the probability mass of the tokens ranked above
        it already exceeds ``top_p``. The survivors are renormalized and
        scattered back into the original vocabulary order.
        """
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_probs = torch.zeros_like(probs)
        next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
        return next_probs

    def _prefill_prompt(
        self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
    ):
        """Prefill prompt tokens starting at ``pos`` and choose the next token.

        ``temperature == 0`` selects greedily; otherwise sampling uses
        temperature scaling plus top-p filtering.

        Returns:
            ``(logits, hidden, next_token, pos)``, where ``pos`` has been advanced
            past the prompt.
        """
        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)
            # Prompt length varies between calls.
            torch._dynamo.mark_dynamic(prompt_emb, 1)
            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
            pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.int32, device=self.device)
            hidden = self._prefill(prompt_emb, mask, pos_ids)
            logits = lm_head(hidden, self.text)

            if temperature == 0:
                next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                probs = self._apply_top_p(probs, top_p)
                next_token = torch.multinomial(probs, num_samples=1)

        pos = pos + prompt_emb.size(1)
        return logits, hidden, next_token, pos

    def _generate_points(
        self,
        hidden: torch.Tensor,
        next_token: torch.Tensor,
        pos: int,
        max_objects: int = DEFAULT_MAX_OBJECTS,
    ):
        """Greedily decode (x, y) points until EOS or ``max_objects`` is reached.

        Each coordinate is the argmax bin of the region head divided by the
        bin count, so it lies in [0, 1).  Only a single sequence is supported:
        ``next_token`` must hold exactly one element (``.item()``).

        Returns:
            A list of ``{"x": float, "y": float}`` dicts.
        """
        out = []
        # Sized to the full context so it matches the KV caches and attn_mask.
        mask = torch.zeros(
            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
        )
        mask[:, :, :pos] = 1
        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.int32)

        with torch.inference_mode():
            while (
                next_token.item() != self.config.tokenizer.eos_id
                and len(out) < max_objects
            ):
                x_logits = decode_coordinate(hidden, self.region)
                x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
                next_emb = encode_coordinate(
                    x_center.to(dtype=x_logits.dtype), self.region
                ).unsqueeze(0)

                # Decode y-coordinate
                mask[:, :, pos], pos_ids[0] = 1, pos
                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                pos += 1
                y_logits = decode_coordinate(hidden, self.region)
                y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
                next_emb = encode_coordinate(
                    y_center.to(dtype=y_logits.dtype), self.region
                ).unsqueeze(0)

                out.append({"x": x_center.item(), "y": y_center.item()})

                # Decode next token (x-coordinate, or eos)
                mask[:, :, pos], pos_ids[0] = 1, pos
                logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                pos += 1
                next_token = torch.argmax(logits, dim=-1)

        return out

    def point(
        self,
        image: Union[Image.Image, EncodedImage, torch.Tensor],
        object: Union[str, List[str]],
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        """Locate instances of one or more objects in an image.

        Queries are processed one at a time.  ``_generate_points`` drives a
        single sequence, and padding prompts to a common length without a
        padding-aware attention mask would let pad tokens leak into the results.

        Args:
            image: PIL image, precomputed image-embedding tensor, or an
                ``EncodedImage`` from ``encode_image``.
            object: one object description, or a list of them.  (The name
                shadows the builtin but is kept for keyword compatibility.)
            settings: optional ``{"max_objects": int}`` cap on points per object.

        Returns:
            For a single string, ``{"points": [{"x": float, "y": float}, ...]}``.
            For a list, ``{"points": [...]}`` holding one such list per entry.

        Raises:
            NotImplementedError: if the config has no "point" prompt template.
        """
        template = self.config.tokenizer.templates["point"]
        if template is None:
            raise NotImplementedError("Model does not support pointing.")

        # A bare string is one query; iterating it would split it into characters.
        single_query = isinstance(object, str)
        queries = [object] if single_query else list(object)

        image = self.encode_image(image)
        # Make sure the KV caches hold *this* image's prefix, which matters when a
        # previously returned EncodedImage is passed back in.  Prompt and decode
        # writes start at image.pos, so the prefix survives across queries.
        self.load_encoded_image(image)

        max_objects = (
            settings.get("max_objects", DEFAULT_MAX_OBJECTS)
            if settings
            else DEFAULT_MAX_OBJECTS
        )

        results = []
        for query in queries:
            prompt_tokens = torch.tensor(
                [
                    template["prefix"]
                    + self.tokenizer.encode(" " + query).ids
                    + template["suffix"]
                ],
                device=self.device,
            )
            _, hidden, next_token, pos = self._prefill_prompt(
                prompt_tokens, image.pos, temperature=0, top_p=0
            )
            # Only the last prompt position seeds point decoding.
            hidden = hidden[:, -1:, :]
            results.append(
                self._generate_points(hidden, next_token, pos, max_objects=max_objects)
            )

        return {"points": results[0] if single_query else results}

    def forward(self, image: Union[Image.Image, EncodedImage, torch.Tensor], prompt: str, settings: Optional[ObjectSamplingSettings] = None):
        """Alias for ``point(image, prompt, settings)``."""
        return self.point(image, prompt, settings)