import os
import random

import torch
import torch.nn as nn
from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
from PIL import Image
from dataclasses import dataclass
from tokenizers import Tokenizer

from .config import MoondreamConfig
from .image_crops import reconstruct_from_crops
from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
from .text import build_text_model, text_encoder, lm_head, text_decoder
from .region import decode_coordinate, encode_coordinate
from .rope import RotaryEmbedding

TextSamplingSettings = TypedDict(
    "TextSamplingSettings",
    {
        "max_tokens": int,
        "temperature": float,
        "top_p": float,
    },
    total=False,
)

ObjectSamplingSettings = TypedDict(
    "ObjectSamplingSettings",
    {"max_objects": int},
    total=False,
)

DEFAULT_MAX_TOKENS = 768
DEFAULT_TEMPERATURE = 0.5
DEFAULT_TOP_P = 0.3
DEFAULT_MAX_OBJECTS = 50


@dataclass(frozen=True)
class EncodedImage:
    pos: int
    caches: List[Tuple[torch.Tensor, torch.Tensor]]


class KVCache(nn.Module):
    def __init__(
        self,
        n_heads,
        n_kv_heads,
        max_context,
        dim,
        batch_size: int = 1,
        device=None,
        dtype=None,
    ):
        super().__init__()
        cache_shape = (batch_size, n_kv_heads, max_context, dim // n_heads)
        self.register_buffer(
            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )

    def update(self, pos_ids, k, v):
        # Write the new keys/values at the given positions and return the
        # full caches so attention can run over everything seen so far.
        kout, vout = self.k_cache, self.v_cache
        kout[:, :, pos_ids, :] = k
        vout[:, :, pos_ids, :] = v
        return kout, vout


class MoondreamModel(nn.Module):
    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
        super().__init__()
        self.config = config

        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.tokenizer = Tokenizer.from_file(
            os.path.join(current_dir, "tokenizer.json")
        )

        self.vision = build_vision_model(config.vision, dtype)
        self.text = build_text_model(config.text, dtype)
        self.rope = RotaryEmbedding(
            config.text.dim // config.text.n_heads, config.text.max_context
        )

        # Region model: encoders/decoders for coordinates and sizes, used for
        # pointing and detection.
        self.region = nn.ModuleDict(
            {
                "coord_encoder": nn.Linear(
                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
                ),
                "coord_decoder": nn.ModuleDict(
                    {
                        "fc1": nn.Linear(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
                        "fc2": nn.Linear(
                            config.region.inner_dim,
                            config.region.coord_out_dim,
                            dtype=dtype,
                        ),
                    }
                ),
                "size_encoder": nn.Linear(
                    config.region.size_feat_dim, config.region.dim, dtype=dtype
                ),
                "size_decoder": nn.ModuleDict(
                    {
                        "fc1": nn.Linear(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
                        "fc2": nn.Linear(
                            config.region.inner_dim,
                            config.region.size_out_dim,
                            dtype=dtype,
                        ),
                    }
                ),
            }
        )
        self.region.coord_features = nn.Parameter(
            torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
        )
        self.region.size_features = nn.Parameter(
            torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
        )

        # Causal attention mask, with full (bidirectional) attention over the
        # image prefix: the BOS token plus one embedding per vision patch.
        attn_mask = torch.tril(
            torch.ones(
                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
            )
        )
        patch_w = config.vision.crop_size // config.vision.enc_patch_size
        prefix_attn_len = 1 + patch_w**2
        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        # Initialize KV caches.
        if setup_caches:
            self._setup_caches()
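    # Rough cache-sizing sketch (illustrative comment, not executed): each
    # block in _setup_caches below allocates two
    # (batch_size, n_kv_heads, max_context, dim // n_heads) buffers. With
    # assumed example values batch_size=2, n_kv_heads=8, max_context=2048,
    # head_dim=64 in float16, that is 2 buffers * 2 * 8 * 2048 * 64 elements
    # * 2 bytes, roughly 8 MiB per block. These numbers are assumptions for
    # illustration, not values read from MoondreamConfig.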
    def _setup_caches(self):
        c = self.config.text
        for b in self.text.blocks:
            b.kv_cache = KVCache(
                c.n_heads,
                c.n_kv_heads,
                c.max_context,
                c.dim,
                batch_size=2,
                device=self.device,
                dtype=self.vision.pos_emb.dtype,
            )

    def load_encoded_image(self, encoded_image: EncodedImage):
        # Restore the image-prefix portion of each block's KV cache from a
        # previously encoded image, so repeated queries skip re-encoding.
        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
            b.kv_cache.v_cache[:, :, : v.size(2), :] = v

    @property
    def device(self):
        return self.vision.pos_emb.device

    def _vis_enc(self, x: torch.Tensor):
        return vision_encoder(x, self.vision, self.config.vision)

    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
        return vision_projection(g, r, self.vision, self.config.vision)

    def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
        return text_decoder(x, self.text, attn_mask, self.config.text, self.rope, pos_ids)

    def _decode_one_tok(
        self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
    ):
        hidden = text_decoder(x, self.text, attn_mask, self.config.text, self.rope, pos_ids)
        logits = lm_head(hidden, self.text)
        return logits, hidden

    def compile(self):
        # TODO: vision_projection is not being compiled
        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
        self._prefill = torch.compile(self._prefill, fullgraph=True)
        self._decode_one_tok = torch.compile(
            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
        )

    def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
        all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
        torch._dynamo.mark_dynamic(all_crops, 0)

        outputs = self._vis_enc(all_crops)

        # The first output row is the global crop; the remaining rows are
        # local crops, reshaped onto a square spatial grid before being
        # stitched back together.
        global_features = outputs[0]
        local_features = outputs[1:].view(
            -1,
            self.config.vision.enc_n_layers,
            self.config.vision.enc_n_layers,
            self.config.vision.enc_dim,
        )
        reconstructed = reconstruct_from_crops(
            local_features,
            tiling,
            patch_size=1,
            overlap_margin=self.config.vision.overlap_margin,
        )

        return self._vis_proj(global_features, reconstructed)
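    # Usage sketch (illustrative comment, not executed): encode_image runs
    # the vision encoder and prefills the text model's KV caches once, so an
    # EncodedImage can be reused across queries without re-encoding, e.g.
    #
    #   enc = model.encode_image(Image.open("photo.jpg"))  # hypothetical path
    #   model.point(enc, ["cat"])
    #   model.point(enc, ["dog"])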
    def encode_image(
        self, image: Union[Image.Image, EncodedImage, torch.Tensor]
    ) -> EncodedImage:
        if isinstance(image, EncodedImage):
            return image
        elif isinstance(image, torch.Tensor):
            pass
        elif not isinstance(image, Image.Image):
            raise ValueError("image must be a PIL Image, EncodedImage, or torch.Tensor")

        # Run through the text model in addition to the vision encoder, to
        # minimize re-computation if multiple queries are performed on this
        # image.
        with torch.inference_mode():
            bos = torch.tensor([[self.config.tokenizer.bos_id]], device=self.device)

            if isinstance(image, Image.Image):
                img_emb = self._run_vision_encoder(image)
            else:
                img_emb = image

            bos_emb = text_encoder(bos, self.text)
            bos_emb = bos_emb.expand(img_emb.size(0), -1, -1)

            inputs_embeds = torch.cat([bos_emb, img_emb], dim=1)
            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
            pos_ids = torch.arange(
                inputs_embeds.size(1), dtype=torch.int32, device=self.device
            )
            self._prefill(inputs_embeds, mask, pos_ids)

            # Snapshot the prefix of each block's KV cache so the encoded
            # image can be restored later via load_encoded_image.
            return EncodedImage(
                pos=inputs_embeds.size(1),
                caches=[
                    (
                        b.kv_cache.k_cache[:1, :, : inputs_embeds.size(1), :].clone(),
                        b.kv_cache.v_cache[:1, :, : inputs_embeds.size(1), :].clone(),
                    )
                    for b in self.text.blocks
                ],
            )

    def _apply_top_p(self, probs: torch.Tensor, top_p: float):
        # Nucleus sampling: keep the smallest set of tokens whose cumulative
        # probability exceeds top_p, then renormalize the survivors.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_probs = torch.zeros_like(probs)
        next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
        return next_probs

    def _prefill_prompt(
        self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
    ):
        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)
            torch._dynamo.mark_dynamic(prompt_emb, 1)
            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
            pos_ids = torch.arange(
                pos, pos + prompt_emb.size(1), dtype=torch.int32, device=self.device
            )
            hidden = self._prefill(prompt_emb, mask, pos_ids)
            logits = lm_head(hidden, self.text)

            if temperature == 0:
                # Greedy decoding.
                next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                probs = self._apply_top_p(probs, top_p)
                next_token = torch.multinomial(probs, num_samples=1)

            pos = pos + prompt_emb.size(1)
            return logits, hidden, next_token, pos

    def _generate_points(
        self,
        hidden: torch.Tensor,
        next_token: torch.Tensor,
        pos: int,
        max_objects: int = DEFAULT_MAX_OBJECTS,
    ):
        out = []
        mask = torch.zeros(
            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
        )
        mask[:, :, :pos] = 1
        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.int32)

        with torch.inference_mode():
            # NOTE: this loop decodes a single sequence; next_token.item()
            # assumes batch size 1.
            while (
                next_token.item() != self.config.tokenizer.eos_id
                and len(out) < max_objects
            ):
                # Decode x-coordinate from the current hidden state.
                x_logits = decode_coordinate(hidden, self.region)
                x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
                next_emb = encode_coordinate(
                    x_center.to(dtype=x_logits.dtype), self.region
                ).unsqueeze(0)

                # Decode y-coordinate
                mask[:, :, pos] = 1
                pos_ids[0] = pos
                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                pos += 1
                y_logits = decode_coordinate(hidden, self.region)
                y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
                next_emb = encode_coordinate(
                    y_center.to(dtype=y_logits.dtype), self.region
                ).unsqueeze(0)

                out.append({"x": x_center.item(), "y": y_center.item()})

                # Decode next token (x-coordinate, or eos)
                mask[:, :, pos] = 1
                pos_ids[0] = pos
                logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                pos += 1
                next_token = torch.argmax(logits, dim=-1)

        return out
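    # Decoding-order sketch for _generate_points (illustrative comment): the
    # model emits objects as an alternating stream of coordinate embeddings,
    #
    #   x_1, y_1, x_2, y_2, ..., eos
    #
    # where each coordinate is the argmax bin of the region head, normalized
    # by the number of bins to a value in [0, 1).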
    def point(
        self,
        image: Union[Image.Image, EncodedImage, torch.Tensor],
        object: Union[str, List[str]],
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        if self.config.tokenizer.templates["point"] is None:
            raise NotImplementedError("Model does not support pointing.")

        if isinstance(object, str):
            object = [object]

        image = self.encode_image(image)
        self.load_encoded_image(image)

        # Tokenize the input batch: wrap each object name in the point
        # template.
        prompt_tokens = [
            self.config.tokenizer.templates["point"]["prefix"]
            + self.tokenizer.encode(" " + obj).ids
            + self.config.tokenizer.templates["point"]["suffix"]
            for obj in object
        ]

        # Pad with the eos token to the length of the longest sequence.
        max_len = max(len(t) for t in prompt_tokens)
        eos_id = self.config.tokenizer.eos_id
        prompt_tokens = torch.tensor(
            [t + [eos_id] * (max_len - len(t)) for t in prompt_tokens],
            device=self.device,
        )

        _, hidden, next_token, pos = self._prefill_prompt(
            prompt_tokens, image.pos, temperature=0, top_p=0
        )
        hidden = hidden[:, -1:, :]

        max_objects = (
            settings.get("max_objects", DEFAULT_MAX_OBJECTS)
            if settings
            else DEFAULT_MAX_OBJECTS
        )
        objects = self._generate_points(
            hidden, next_token, pos, max_objects=max_objects
        )

        return {"points": objects}

    def forward(
        self,
        image: Union[Image.Image, EncodedImage, torch.Tensor],
        prompt: Union[str, List[str]],
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        return self.point(image, prompt, settings)
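
# Minimal usage sketch, guarded so importing this module has no side effects.
# It assumes MoondreamConfig() is default-constructible and that trained
# weights are loaded separately (the checkpoint and image paths below are
# placeholders, not real files shipped with the repo).
if __name__ == "__main__":
    model = MoondreamModel(MoondreamConfig())
    # model.load_state_dict(torch.load("moondream_weights.pt"))  # hypothetical path
    demo_image = Image.open("example.jpg")  # hypothetical image path
    print(model.point(demo_image, ["person", "car"])["points"])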