| | import torch |
| | import torch.nn as nn |
| |
|
| | from transformers import PreTrainedModel, PretrainedConfig |
| | from typing import Union |
| |
|
| | from .config import MoondreamConfig |
| | from .moondream import MoondreamModel |
| |
|
| | |
| | from .image_crops import * |
| | from .vision import * |
| | from .text import * |
| | from .region import * |
| | from .utils import * |
| |
|
| |
|
| | def extract_question(text): |
| | prefix = "<image>\n\nQuestion: " |
| | suffix = "\n\nAnswer:" |
| |
|
| | if text.startswith(prefix) and text.endswith(suffix): |
| | return text[len(prefix) : -len(suffix)] |
| | else: |
| | return None |
| |
|
| |
|
| | class HfConfig(PretrainedConfig): |
| | _auto_class = "AutoConfig" |
| | model_type = "moondream1" |
| |
|
| | def __init__(self, **kwargs): |
| | super().__init__(**kwargs) |
| | self.config = {} |
| |
|
| |
|
| | class HfMoondream(PreTrainedModel): |
| | _auto_class = "AutoModelForCausalLM" |
| | config_class = HfConfig |
| |
|
| | def __init__(self, config): |
| | super().__init__(config) |
| | self.model = MoondreamModel( |
| | MoondreamConfig.from_dict(config.config), setup_caches=False |
| | ) |
| | self._is_kv_cache_setup = False |
| |
|
| | def _setup_caches(self): |
| | if not self._is_kv_cache_setup: |
| | self.model._setup_caches() |
| | self._is_kv_cache_setup = True |
| |
|
| | @property |
| | def encode_image(self): |
| | self._setup_caches() |
| | return self.model.encode_image |
| |
|
| | @property |
| | def query(self): |
| | self._setup_caches() |
| | return self.model.query |
| |
|
| | @property |
| | def caption(self): |
| | self._setup_caches() |
| | return self.model.caption |
| |
|
| | @property |
| | def detect(self): |
| | self._setup_caches() |
| | return self.model.detect |
| |
|
| | @property |
| | def point(self): |
| | self._setup_caches() |
| | return self.model.point |
| |
|
| | @property |
| | def detect_gaze(self): |
| | self._setup_caches() |
| | return self.model.detect_gaze |
| |
|
| | def answer_question( |
| | self, |
| | image_embeds, |
| | question, |
| | tokenizer=None, |
| | chat_history="", |
| | result_queue=None, |
| | max_new_tokens=256, |
| | **kwargs |
| | ): |
| | answer = self.query(image_embeds, question)["answer"].strip() |
| |
|
| | if result_queue is not None: |
| | result_queue.put(answer) |
| | return answer |
| |
|
| | def batch_answer(self, images, prompts, tokenizer=None, **kwargs): |
| | answers = [] |
| | for image, prompt in zip(images, prompts): |
| | answers.append(self.query(image, prompt)["answer"].strip()) |
| | return answers |
| |
|
| | def _unsupported_exception(self): |
| | raise NotImplementedError( |
| | "This method is not supported in the latest version of moondream. " |
| | "Consider upgrading to the updated API spec, or alternately pin " |
| | "to 'revision=2024-08-26'." |
| | ) |
| |
|
| | def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs): |
| | """ |
| | Function definition remains unchanged for backwards compatibility. |
| | Be aware that tokenizer, max_new_takens, and kwargs are ignored. |
| | """ |
| | prompt_extracted = extract_question(prompt) |
| | if prompt_extracted is not None: |
| | answer = self.model.query( |
| | image=image_embeds, question=prompt_extracted, stream=False |
| | )["answer"] |
| | else: |
| | image_embeds = self.encode_image(image_embeds) |
| | prompt_tokens = torch.tensor( |
| | [self.model.tokenizer.encode(prompt).ids], |
| | device=self.device, |
| | ) |
| |
|
| | def generator(): |
| | for token in self.model._generate_answer( |
| | prompt_tokens, |
| | image_embeds.kv_cache, |
| | image_embeds.pos, |
| | max_new_tokens, |
| | ): |
| | yield token |
| |
|
| | answer = "".join(list(generator())) |
| |
|
| | return [answer] |
| |
|
| | def get_input_embeddings(self) -> nn.Embedding: |
| | """ |
| | Lazily wrap the raw parameter `self.model.text.wte` in a real |
| | `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper |
| | **shares** the weight tensor—no copy is made. |
| | """ |
| | if not hasattr(self, "_input_embeddings"): |
| | self._input_embeddings = nn.Embedding.from_pretrained( |
| | self.model.text.wte, |
| | freeze=True, |
| | ) |
| | return self._input_embeddings |
| |
|
| | def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None: |
| | """ |
| | Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the |
| | embeddings and keeps everything tied to `self.model.text.wte`. |
| | """ |
| | |
| | self.model.text.wte = value.weight |
| | |
| | self._input_embeddings = value |
| |
|
| | def input_embeds( |
| | self, |
| | input_ids: Union[torch.LongTensor, list, tuple], |
| | *, |
| | device: torch.device | None = None |
| | ) -> torch.FloatTensor: |
| | """ |
| | Back-compat wrapper that turns token IDs into embeddings. |
| | |
| | Example: |
| | ids = torch.tensor([[1, 2, 3]]) |
| | embeds = model.input_embeds(ids) # (1, 3, hidden_dim) |
| | """ |
| | if not torch.is_tensor(input_ids): |
| | input_ids = torch.as_tensor(input_ids) |
| | if device is not None: |
| | input_ids = input_ids.to(device) |
| |
|
| | return self.get_input_embeddings()(input_ids) |
| |
|