| import torch |
| from .vision_encoder import VisionEncoder |
| from .configuration_moondream import MoondreamConfig |
| from transformers import PreTrainedModel |
|
|
| from .modeling_phi import PhiForCausalLM |
| from .configuration_moondream import PhiConfig |
|
|
class Moondream(PreTrainedModel):
    """Vision-language model pairing a ``VisionEncoder`` with a Phi causal LM.

    Image embeddings are spliced into the text embedding sequence at the
    ``<image>`` placeholder of a prompt, and the combined sequence is handed
    to ``self.text_model.generate``.
    """

    config_class = MoondreamConfig
    _supports_flash_attn_2 = True

    # Generation budget used by answer_question / batch_answer unless the
    # caller overrides ``max_new_tokens``.
    _ANSWER_MAX_NEW_TOKENS = 512

    def __init__(self, config):
        """Build the vision encoder and the Phi text model from ``config``.

        ``config.phi_config`` may be a mapping of ``PhiConfig`` keyword
        arguments or an already-constructed config object; in the latter
        case it is passed to ``PhiForCausalLM`` unchanged.
        """
        super().__init__(config)
        self.vision_encoder = VisionEncoder()

        # isinstance (rather than an exact type comparison) so that dict
        # subclasses are expanded into a PhiConfig as well.
        if isinstance(config.phi_config, dict):
            phi_config = PhiConfig(
                **config.phi_config,
                attn_implementation=config._attn_implementation,
            )
        else:
            phi_config = config.phi_config
        self.text_model = PhiForCausalLM(phi_config)

    @property
    def device(self):
        """Device of the text model; token ids and embeddings are moved here."""
        return self.text_model.device

    def encode_image(self, image):
        """Return the vision encoder's output for ``image``."""
        return self.vision_encoder(image)

    @staticmethod
    def _generation_config(tokenizer, max_new_tokens, overrides):
        """Assemble keyword arguments for ``text_model.generate``.

        Entries in ``overrides`` take precedence over the defaults built here.
        """
        return {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            # The BOS id doubles as the padding id.
            "pad_token_id": tokenizer.bos_token_id,
            "max_new_tokens": max_new_tokens,
            **overrides,
        }

    @staticmethod
    def _question_prompt(question, chat_history=""):
        """Format the question/answer prompt, prefixed by the image placeholder."""
        return f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"

    def input_embeds(self, prompt: str, image_embeds, tokenizer) -> torch.Tensor:
        """Build the input embedding sequence for ``prompt``.

        The sequence always starts with the embedding of the tokenizer's BOS
        token. If ``prompt`` contains the ``<image>`` placeholder, the text
        before and after it is embedded separately and ``image_embeds`` is
        inserted between the two parts; otherwise the whole prompt is embedded
        as text and ``image_embeds`` is not used.

        Args:
            prompt: Prompt text, containing ``<image>`` at most once.
            image_embeds: Embeddings inserted at the placeholder; moved to
                ``self.device`` before concatenation.
            tokenizer: Callable tokenizer exposing ``bos_token_id``.

        Returns:
            All parts concatenated along dimension 1.

        Raises:
            ValueError: If ``<image>`` occurs more than once in ``prompt``.
        """
        placeholder = "<image>"
        text_emb = self.text_model.get_input_embeddings()

        def _embed_text(text):
            token_ids = tokenizer(
                text, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)
            return text_emb(token_ids)

        bos_ids = torch.tensor([[tokenizer.bos_token_id]], device=self.device)
        embeds = [text_emb(bos_ids)]

        if placeholder not in prompt:
            embeds.append(_embed_text(prompt))
        else:
            # Explicit check instead of ``assert`` so the validation is not
            # stripped when Python runs with -O.
            occurrences = prompt.count(placeholder)
            if occurrences != 1:
                raise ValueError(
                    f"prompt must contain {placeholder!r} at most once, "
                    f"found {occurrences} occurrences"
                )
            before, after = prompt.split(placeholder)
            # Empty fragments are skipped rather than tokenized.
            if before:
                embeds.append(_embed_text(before))
            embeds.append(image_embeds.to(self.device))
            if after:
                embeds.append(_embed_text(after))

        return torch.cat(embeds, dim=1)

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens: int = 128,
        **kwargs,
    ):
        """Generate text for a single prompt.

        Args:
            image_embeds: Embeddings substituted for ``<image>`` in ``prompt``.
            prompt: Prompt text (see ``input_embeds``).
            tokenizer: Tokenizer used for encoding the prompt and decoding
                the output.
            max_new_tokens: Upper bound forwarded to ``text_model.generate``.
            **kwargs: Extra generation arguments; they override the defaults
                for the eos/bos/pad token ids.

        Returns:
            The result of ``tokenizer.batch_decode`` on the generated ids,
            with special tokens skipped.
        """
        generate_config = self._generation_config(tokenizer, max_new_tokens, kwargs)

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds, **generate_config
            )

        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    ):
        """Answer ``question`` about an already-encoded image.

        Args:
            image_embeds: Embeddings substituted for the image placeholder.
            question: Question text inserted into the prompt template.
            tokenizer: Tokenizer forwarded to ``generate``.
            chat_history: Text placed in the prompt before the question.
            result_queue: Optional object with a ``put`` method. When given,
                the answer is put on it and the method returns ``None``.
            **kwargs: Forwarded to ``generate``. ``max_new_tokens`` defaults
                to 512 here but may be overridden.

        Returns:
            The stripped answer string, or ``None`` when ``result_queue`` is
            provided.
        """
        prompt = self._question_prompt(question, chat_history)
        # setdefault (rather than passing max_new_tokens explicitly) so a
        # caller-supplied value does not collide with ours.
        kwargs.setdefault("max_new_tokens", self._ANSWER_MAX_NEW_TOKENS)
        answer = self.generate(
            image_embeds,
            prompt,
            tokenizer=tokenizer,
            **kwargs,
        )[0]
        cleaned_answer = answer.strip()

        # Identity check: an empty queue-like object that defines __len__
        # would be falsy, yet must still receive the answer.
        if result_queue is not None:
            result_queue.put(cleaned_answer)
            return None
        return cleaned_answer

    def batch_answer(
        self,
        images,
        prompts,
        tokenizer,
        **kwargs,
    ):
        """Answer one question per image in a single generation call.

        Prompts and image embeddings are paired with ``zip``, so surplus
        items on either side are ignored. Shorter prompt sequences are
        left-padded with the BOS embedding, and the padded positions are
        zeroed in the attention mask.

        Args:
            images: Input passed to ``encode_image``; iterating its result
                must yield one embedding per prompt.
            prompts: Question strings, one per image.
            tokenizer: Tokenizer used for encoding and decoding.
            **kwargs: Extra generation arguments; they override the defaults
                (including ``max_new_tokens``, which defaults to 512).

        Returns:
            A list of stripped answer strings; empty when ``prompts`` is
            empty.
        """
        prompts = list(prompts)
        if not prompts:
            # Nothing to answer; avoids indexing into an empty batch below.
            return []

        generate_config = self._generation_config(
            tokenizer, self._ANSWER_MAX_NEW_TOKENS, kwargs
        )

        # Only decoded strings leave this method, so no part of the
        # computation needs gradient tracking.
        with torch.no_grad():
            image_embeds = self.encode_image(images)

            prompt_embs = [
                self.input_embeds(
                    self._question_prompt(prompt),
                    image_embed.unsqueeze(0),
                    tokenizer,
                )[0]
                for prompt, image_embed in zip(prompts, image_embeds)
            ]

            # Every sequence starts with the BOS embedding (see input_embeds);
            # reuse it as the padding vector.
            bos_emb = prompt_embs[0][0]
            max_len = max(p.shape[0] for p in prompt_embs)

            padded = []
            masks = []
            for p in prompt_embs:
                seq_len = p.shape[0]
                pad_len = max_len - seq_len
                # Padding goes in front so real tokens sit at the end of
                # each row.
                padded.append(torch.cat([bos_emb.repeat(pad_len, 1), p]))
                masks.append(
                    torch.cat(
                        [
                            torch.zeros(
                                pad_len, device=self.device, dtype=torch.long
                            ),
                            torch.ones(
                                seq_len, device=self.device, dtype=torch.long
                            ),
                        ]
                    )
                )
            inputs_embeds = torch.stack(padded)
            attention_mask = torch.stack(masks)

            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            x.strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]
|
|