import os

from dataclasses import dataclass
from typing import List, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
from PIL import Image
from tokenizers import Tokenizer

from .config import MoondreamConfig
from .image_crops import reconstruct_from_crops
from .region import decode_coordinate, encode_coordinate
from .rope import RotaryEmbedding
from .text import build_text_model, lm_head, text_decoder, text_encoder
from .vision import build_vision_model, prepare_crops, vision_encoder, vision_projection

TextSamplingSettings = TypedDict(
    "TextSamplingSettings",
    {
        "max_tokens": int,
        "temperature": float,
        "top_p": float,
    },
    total=False,
)

ObjectSamplingSettings = TypedDict(
    "ObjectSamplingSettings",
    {"max_objects": int},
    total=False,
)

DEFAULT_MAX_TOKENS = 768
DEFAULT_TEMPERATURE = 0.5
DEFAULT_TOP_P = 0.3
DEFAULT_MAX_OBJECTS = 50
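
# Example (illustrative values): the settings types are plain dicts with all
# keys optional, so callers may pass any subset and the defaults above fill in
# the rest, e.g.
#
#   text_settings: TextSamplingSettings = {"temperature": 0.2, "top_p": 0.9}
#   object_settings: ObjectSamplingSettings = {"max_objects": 10}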


@dataclass(frozen=True)
class EncodedImage:
    pos: int
    caches: List[Tuple[torch.Tensor, torch.Tensor]]


class KVCache(nn.Module):
    """Per-block key/value cache backed by preallocated buffers."""

    def __init__(
        self,
        n_heads,
        n_kv_heads,
        max_context,
        dim,
        batch_size: int = 1,
        device=None,
        dtype=None,
    ):
        super().__init__()
        cache_shape = (batch_size, n_kv_heads, max_context, dim // n_heads)
        self.register_buffer(
            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )

    def update(self, pos_ids, k, v):
        # Write the new keys/values at their position ids, then return the full
        # buffers for attention over the whole context.
        kout, vout = self.k_cache, self.v_cache
        kout[:, :, pos_ids, :] = k
        vout[:, :, pos_ids, :] = v
        return kout, vout
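
# Illustrative sketch (hypothetical sizes, not part of the model) of how a
# block's KVCache is driven during decoding: each step writes one position and
# attends over the prefix written so far.
#
#   cache = KVCache(n_heads=32, n_kv_heads=32, max_context=2048, dim=2048)
#   k_all, v_all = cache.update(torch.tensor([pos]), k_new, v_new)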


class MoondreamModel(nn.Module):

    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
        super().__init__()
        self.config = config

        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.tokenizer = Tokenizer.from_file(
            os.path.join(current_dir, "tokenizer.json")
        )

        self.vision = build_vision_model(config.vision, dtype)
        self.text = build_text_model(config.text, dtype)
        self.rope = RotaryEmbedding(
            config.text.dim // config.text.n_heads, config.text.max_context
        )

        self.region = nn.ModuleDict(
            {
                "coord_encoder": nn.Linear(
                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
                ),
                "coord_decoder": nn.ModuleDict(
                    {
                        "fc1": nn.Linear(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
                        "fc2": nn.Linear(
                            config.region.inner_dim,
                            config.region.coord_out_dim,
                            dtype=dtype,
                        ),
                    }
                ),
                "size_encoder": nn.Linear(
                    config.region.size_feat_dim, config.region.dim, dtype=dtype
                ),
                "size_decoder": nn.ModuleDict(
                    {
                        "fc1": nn.Linear(
                            config.region.dim, config.region.inner_dim, dtype=dtype
                        ),
                        "fc2": nn.Linear(
                            config.region.inner_dim,
                            config.region.size_out_dim,
                            dtype=dtype,
                        ),
                    }
                ),
            }
        )
        self.region.coord_features = nn.Parameter(
            torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
        )
        self.region.size_features = nn.Parameter(
            torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
        )

        # Causal mask over the full context, except that the prefix (BOS token
        # plus image tokens) may attend bidirectionally within itself.
        attn_mask = torch.tril(
            torch.ones(
                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
            )
        )
        patch_w = config.vision.crop_size // config.vision.enc_patch_size
        prefix_attn_len = 1 + patch_w**2
        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        if setup_caches:
            self._setup_caches()

    def _setup_caches(self):
        c = self.config.text
        for b in self.text.blocks:
            b.kv_cache = KVCache(
                c.n_heads,
                c.n_kv_heads,
                c.max_context,
                c.dim,
                batch_size=2,
                device=self.device,
                dtype=self.vision.pos_emb.dtype,
            )

    def load_encoded_image(self, encoded_image: EncodedImage):
        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
            b.kv_cache.v_cache[:, :, : v.size(2), :] = v

    @property
    def device(self):
        return self.vision.pos_emb.device

    def _vis_enc(self, x: torch.Tensor):
        return vision_encoder(x, self.vision, self.config.vision)

    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
        return vision_projection(g, r, self.vision, self.config.vision)

    def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
        return text_decoder(x, self.text, attn_mask, self.config.text, self.rope, pos_ids)

    def _decode_one_tok(
        self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
    ):
        hidden = text_decoder(x, self.text, attn_mask, self.config.text, self.rope, pos_ids)
        logits = lm_head(hidden, self.text)
        return logits, hidden

    def compile(self):
        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
        self._prefill = torch.compile(self._prefill, fullgraph=True)
        self._decode_one_tok = torch.compile(
            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
        )
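
    # Note: _vis_enc is compiled with fullgraph=True while the number of crops
    # is marked dynamic in _run_vision_encoder, which should avoid recompiling
    # for every new crop count (an expectation, not a guarantee).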

    def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
        all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
        torch._dynamo.mark_dynamic(all_crops, 0)

        outputs = self._vis_enc(all_crops)

        global_features = outputs[0]
        # Reshape each local crop's token sequence back into its 2D patch grid
        # before reassembling the crops into a single feature map.
        patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
        local_features = outputs[1:].view(
            -1,
            patch_w,
            patch_w,
            self.config.vision.enc_dim,
        )

        reconstructed = reconstruct_from_crops(
            local_features,
            tiling,
            patch_size=1,
            overlap_margin=self.config.vision.overlap_margin,
        )

        return self._vis_proj(global_features, reconstructed)

    def encode_image(
        self, image: Union[Image.Image, EncodedImage, torch.Tensor]
    ) -> EncodedImage:
        if isinstance(image, EncodedImage):
            return image
        if not isinstance(image, (Image.Image, torch.Tensor)):
            raise ValueError(
                "image must be a PIL Image, a precomputed embedding tensor, "
                "or an EncodedImage"
            )

        with torch.inference_mode():
            bos = torch.tensor([[self.config.tokenizer.bos_id]], device=self.device)

            if isinstance(image, Image.Image):
                img_emb = self._run_vision_encoder(image)
            else:
                img_emb = image

            bos_emb = text_encoder(bos, self.text)
            bos_emb = bos_emb.expand(img_emb.size(0), -1, -1)
            inputs_embeds = torch.cat([bos_emb, img_emb], dim=1)
            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
            pos_ids = torch.arange(
                inputs_embeds.size(1), dtype=torch.int32, device=self.device
            )
            self._prefill(inputs_embeds, mask, pos_ids)

            # Snapshot the prefilled KV caches so this EncodedImage can be
            # reloaded later via load_encoded_image without re-encoding.
            return EncodedImage(
                pos=inputs_embeds.size(1),
                caches=[
                    (
                        b.kv_cache.k_cache[:1, :, : inputs_embeds.size(1), :].clone(),
                        b.kv_cache.v_cache[:1, :, : inputs_embeds.size(1), :].clone(),
                    )
                    for b in self.text.blocks
                ],
            )
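
    # Reuse sketch: because encode_image snapshots the KV caches, one image can
    # serve several queries without re-running the vision encoder, e.g.
    #
    #   encoded = model.encode_image(img)   # vision encoder + prefill
    #   model.point(encoded, "cat")         # reuses the cached prefill
    #   model.point(encoded, "dog")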

    def _apply_top_p(self, probs: torch.Tensor, top_p: float):
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Drop every token whose preceding cumulative mass already exceeds
        # top_p, renormalize the survivors, and scatter them back into the
        # original vocabulary order.
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_probs = torch.zeros_like(probs)
        next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
        return next_probs
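
    # Worked example (illustrative numbers): with sorted probs
    # [0.5, 0.3, 0.15, 0.05] and top_p = 0.3, the cumulative mass *before* each
    # token is [0.0, 0.5, 0.8, 0.95], so only the first token survives the mask
    # and is renormalized to 1.0.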

    def _prefill_prompt(
        self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
    ):
        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)
            torch._dynamo.mark_dynamic(prompt_emb, 1)
            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
            pos_ids = torch.arange(
                pos, pos + prompt_emb.size(1), dtype=torch.int32, device=self.device
            )
            hidden = self._prefill(prompt_emb, mask, pos_ids)
            logits = lm_head(hidden, self.text)

            # Greedy decoding when temperature == 0; otherwise temperature
            # scaling followed by nucleus (top-p) filtering.
            if temperature == 0:
                next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                probs = self._apply_top_p(probs, top_p)
                next_token = torch.multinomial(probs, num_samples=1)

        pos = pos + prompt_emb.size(1)
        return logits, hidden, next_token, pos

    def _generate_points(
        self,
        hidden: torch.Tensor,
        next_token: torch.Tensor,
        pos: int,
        max_objects: int = DEFAULT_MAX_OBJECTS,
    ):
        out = []
        # Attention mask sized to the text context window; positions up to the
        # current prefix are attendable.
        mask = torch.zeros(
            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
        )
        mask[:, :, :pos] = 1
        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.int32)

        with torch.inference_mode():
            while (
                next_token.item() != self.config.tokenizer.eos_id
                and len(out) < max_objects
            ):
                # Decode the x coordinate from the current hidden state, then
                # feed its embedding back in to predict the y coordinate.
                x_logits = decode_coordinate(hidden, self.region)
                x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
                next_emb = encode_coordinate(
                    x_center.to(dtype=x_logits.dtype), self.region
                ).unsqueeze(0)

                mask[:, :, pos], pos_ids[0] = 1, pos
                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                pos += 1

                y_logits = decode_coordinate(hidden, self.region)
                y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
                next_emb = encode_coordinate(
                    y_center.to(dtype=y_logits.dtype), self.region
                ).unsqueeze(0)

                out.append({"x": x_center.item(), "y": y_center.item()})

                # Advance one more step to decide whether another object
                # follows; EOS terminates the loop.
                mask[:, :, pos], pos_ids[0] = 1, pos
                logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                pos += 1
                next_token = torch.argmax(logits, dim=-1)

        return out

    def point(
        self,
        image: Union[Image.Image, EncodedImage, torch.Tensor],
        object: str,
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        if self.config.tokenizer.templates["point"] is None:
            raise NotImplementedError("Model does not support pointing.")

        image = self.encode_image(image)
        self.load_encoded_image(image)

        prompt_tokens = torch.tensor(
            [
                self.config.tokenizer.templates["point"]["prefix"]
                + self.tokenizer.encode(" " + object).ids
                + self.config.tokenizer.templates["point"]["suffix"]
            ],
            device=self.device,
        )

        _, hidden, next_token, pos = self._prefill_prompt(
            prompt_tokens, image.pos, temperature=0, top_p=0
        )
        hidden = hidden[:, -1:, :]

        max_objects = (
            settings.get("max_objects", DEFAULT_MAX_OBJECTS)
            if settings
            else DEFAULT_MAX_OBJECTS
        )
        objects = self._generate_points(
            hidden, next_token, pos, max_objects=max_objects
        )

        return {"points": objects}

    def forward(
        self,
        image: Union[Image.Image, EncodedImage, torch.Tensor],
        prompt: str,
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        return self.point(image, prompt, settings)
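

# Minimal end-to-end sketch, illustrative only: it assumes MoondreamConfig can
# be default-constructed and that weights are loaded separately (the loading
# helper named below is hypothetical, not part of this module).
if __name__ == "__main__":
    model = MoondreamModel(MoondreamConfig())
    # load_weights_into(model, "model.safetensors")  # hypothetical helper
    print(model.point(Image.open("photo.jpg"), "cat"))
    # e.g. {"points": [{"x": 0.42, "y": 0.37}]}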