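"""LLaVA-style vision-language model: a SigLIP vision encoder, a two-layer
MLP projector, and a causal language model. The projector maps vision
features into the text model's embedding space so image patches can be
consumed as ordinary token embeddings."""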
import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModel,
    SiglipImageProcessor,
)

from .configuration_llamavision import LlamavisionConfig


class ProjectionModule(nn.Module):
    """Two-layer MLP that projects vision features into the text model's
    embedding space."""

    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super().__init__()

        # mm_hidden_size: feature width of the vision encoder;
        # hidden_size: embedding width of the text model.
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)
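# Shape sketch (assuming the defaults above, e.g. a SigLIP-SO400M-style
# encoder with hidden size 1152 and a Llama-3-8B-style text model with
# hidden size 4096): the projector maps (batch, num_patches, 1152)
# to (batch, num_patches, 4096).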
class Llamavision(PreTrainedModel):
    config_class = LlamavisionConfig

    def __init__(self, config):
        super().__init__(config)

        # Sub-models are built from the composite config: a SigLIP-style
        # vision tower, a causal LM, and the multimodal projector that
        # bridges their hidden sizes.
        self.vision_model = AutoModel.from_config(self.config.vision_config)
        self.text_model = AutoModelForCausalLM.from_config(self.config.text_config)
        self.processor = SiglipImageProcessor()
        self.mm_projector = ProjectionModule(
            mm_hidden_size=config.vision_config.hidden_size,
            hidden_size=config.text_config.hidden_size,
        )

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image):
        """Preprocess a PIL image and return patch features from the
        penultimate layer of the vision encoder."""
        image = image.convert("RGB")
        pixel_values = self.processor(
            images=image,
            return_tensors="pt",
            do_resize=True,
            size={"height": 378, "width": 378},
        )["pixel_values"].to(
            device=self.vision_model.device, dtype=self.vision_model.dtype
        )
        with torch.no_grad():
            # Use the second-to-last hidden state as the image
            # representation, as in LLaVA-style models.
            return self.vision_model(
                pixel_values, output_hidden_states=True
            ).hidden_states[-2]
    def input_embeds(self, prompt, image_embeds, tokenizer):
        """Build the input embedding sequence: [BOS] + projected image
        embeddings + embedded prompt tokens."""

        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()

        embeds = []

        # Prepend BOS manually if the prompt does not already start with
        # it, since tokenization above adds no special tokens.
        tokenized_prompt = _tokenize(prompt)
        if (
            tokenizer.bos_token_id is not None
            and tokenized_prompt[0][0] != tokenizer.bos_token_id
        ):
            embeds.append(
                text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
            )

        # Project image features into the text embedding space and splice
        # them in before the prompt.
        projected_image_embeds = self.mm_projector(image_embeds.to(self.device))
        embeds.append(projected_image_embeds)

        embeds.append(text_emb(tokenized_prompt))

        return torch.cat(embeds, dim=1)
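    # Illustrative layout of the assembled sequence (the image-token count
    # is an assumption for a 378x378 input with a patch-14 SigLIP encoder,
    # i.e. (378 / 14)^2 = 729 patches):
    #   [BOS] [img_1 ... img_729] [prompt tokens ...]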
    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens=128,
        **kwargs,
    ):
        generate_config = {
            # Stop on either the standard EOS token or Llama-3's
            # end-of-turn token <|eot_id|>.
            "eos_token_id": [
                tokenizer.eos_token_id,
                tokenizer.convert_tokens_to_ids("<|eot_id|>"),
            ],
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)

            attention_mask = torch.ones(
                inputs_embeds.shape[:2],
                dtype=torch.long,
                device=inputs_embeds.device,
            )

            # When only inputs_embeds are passed (no input_ids), recent
            # transformers versions return just the newly generated ids,
            # so the decode below yields only the answer.
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    def answer_question(self, image, question, tokenizer, **kwargs):
        """Encode an image, wrap the question in the tokenizer's chat
        template, and return the decoded answer string."""
        image_embeds = self.encode_image(image)

        chat = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant that can see images and answer questions about them.",
            },
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        with torch.no_grad():
            output = self.generate(
                image_embeds=image_embeds,
                prompt=prompt,
                tokenizer=tokenizer,
                **kwargs,
            )[0]

        return output.strip()
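# Usage sketch (illustrative; the checkpoint id and image path are
# placeholders, and loading assumes the checkpoint ships this class
# via trust_remote_code):
#
#   from PIL import Image
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "your-org/llamavision", trust_remote_code=True
#   )
#   tokenizer = AutoTokenizer.from_pretrained("your-org/llamavision")
#   image = Image.open("example.jpg")
#   print(model.answer_question(image, "What is in this image?", tokenizer))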