| | from transformers import PreTrainedModel, PretrainedConfig |
| |
|
| | from .config import MoondreamConfig |
| | from .moondream import MoondreamModel |
| |
|
| | |
| | from .image_crops import * |
| | from .vision import * |
| | from .text import * |
| | from .region import * |
| | from .utils import * |
| |
|
| |
|
def extract_question(text):
    """Pull the question out of a legacy-format prompt.

    Returns the text between the ``"<image>\n\nQuestion: "`` prefix and the
    ``"\n\nAnswer:"`` suffix, or ``None`` when the prompt does not match
    that exact template.
    """
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    if not (text.startswith(prefix) and text.endswith(suffix)):
        return None
    return text.removeprefix(prefix).removesuffix(suffix)
| |
|
| |
|
class HfConfig(PretrainedConfig):
    """Thin Hugging Face config shim.

    The actual Moondream configuration is carried as a plain dict in
    ``self.config`` and deserialized by ``MoondreamConfig.from_dict`` when
    the model is built.
    """

    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Raw Moondream config payload; populated when a checkpoint is loaded.
        self.config = dict()
| |
|
| |
|
class HfMoondream(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``MoondreamModel``.

    Exposes the current Moondream API (``query``, ``caption``, ``detect``,
    ``point``, ``detect_gaze``, ``encode_image``) as pass-through properties,
    and keeps legacy entry points (``answer_question``, ``batch_answer``,
    ``generate``) for backwards compatibility with older revisions.
    """

    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        # `config.config` is the raw Moondream config dict stored on HfConfig.
        self.model = MoondreamModel(MoondreamConfig.from_dict(config.config))

    # Pass-through properties: expose the bound methods of the wrapped model
    # so callers can use `hf_model.query(...)` etc. directly.
    @property
    def encode_image(self):
        return self.model.encode_image

    @property
    def query(self):
        return self.model.query

    @property
    def caption(self):
        return self.model.caption

    @property
    def detect(self):
        return self.model.detect

    @property
    def point(self):
        return self.model.point

    @property
    def detect_gaze(self):
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        """Legacy QA entry point.

        Only `image_embeds` and `question` are used; `tokenizer`,
        `chat_history`, `max_new_tokens`, and `kwargs` are accepted solely
        for signature compatibility and are ignored. If `result_queue` is
        given, the answer is also pushed onto it.
        """
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        """Legacy batched QA: answer each (image, prompt) pair sequentially.

        `tokenizer` and `kwargs` are ignored; kept for signature
        compatibility.
        """
        return [
            self.query(image, prompt)["answer"].strip()
            for image, prompt in zip(images, prompts)
        ]

    def _unsupported_exception(self):
        # Shared failure path for API methods removed in this revision.
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """Legacy generate entry point; signature kept for compatibility.

        `tokenizer` and `kwargs` are ignored. If `prompt` matches the legacy
        "<image>\\n\\nQuestion: ...\\n\\nAnswer:" template, the extracted
        question is routed through the modern `query` API (in which case
        `max_new_tokens` is also ignored); otherwise the prompt is tokenized
        and decoded directly, honoring `max_new_tokens`.

        Returns a single-element list containing the generated string.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            image_embeds = self.encode_image(image_embeds)
            # NOTE(review): `torch` is presumably brought in by the star
            # imports above (.vision/.text/...) — confirm it is in scope.
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )
            # _generate_text yields token strings; join them directly.
            answer = "".join(
                self.model._generate_text(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                )
            )

        return [answer]

    def get_input_embeddings(self):
        # Defer to the HF base class implementation.
        return super().get_input_embeddings()

    def input_embeds(self, *args, **kwargs):
        """Removed in the current API; always raises NotImplementedError."""
        self._unsupported_exception()
|