from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    PreTrainedModel,
)
import torch

from .configuration_flosmolv import FloSmolVConfig


class FloSmolV(PreTrainedModel):
    config_class = FloSmolVConfig
    # NOTE: this class attribute shadows PreTrainedModel's `device` property
    # and pins every submodule to one device chosen at class-definition time.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    def __init__(self, config: FloSmolVConfig):
        super().__init__(config)
        # Vision stage: Florence-2 model and processor, used only for
        # captioning, so the model is loaded in eval mode.
        self.florence2_model = AutoModelForCausalLM.from_pretrained(
            self.config.vision_config["_name_or_path"],
            trust_remote_code=True,
        ).eval().to(self.device)
        self.florence2_processor = AutoProcessor.from_pretrained(
            self.config.vision_config["_name_or_path"],
            trust_remote_code=True,
        )
        # Language stage: SmolLM causal LM and tokenizer for answering the query.
        self.smollm_model = AutoModelForCausalLM.from_pretrained(
            self.config.llm_config["_name_or_path"],
            trust_remote_code=True,
        ).to(self.device)
        self.smollm_tokenizer = AutoTokenizer.from_pretrained(
            self.config.llm_config["_name_or_path"],
            trust_remote_code=True,
        )

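    # End-to-end pipeline: Florence-2 first writes a detailed caption of the
    # image, then SmolLM answers `query` against that caption. `image` is
    # expected to be a PIL image (forward reads image.width / image.height).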
    def forward(self, image, query: str):
        # Stage 1: caption the image with Florence-2 using its
        # <MORE_DETAILED_CAPTION> task prompt.
        prompt = "<MORE_DETAILED_CAPTION>"
        vision_inputs = self.florence2_processor(text=prompt, images=image, return_tensors="pt")
        generated_ids = self.florence2_model.generate(
            input_ids=vision_inputs["input_ids"].to(torch.int64).to(self.device),
            pixel_values=vision_inputs["pixel_values"].to(self.device),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        generated_text = self.florence2_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = self.florence2_processor.post_process_generation(
            generated_text,
            task=prompt,
            image_size=(image.width, image.height),
        )

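        # Stage 2: hand the caption and the user's question to SmolLM via its
        # chat template, then sample a short answer.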
        messages = [
            {
                "role": "user",
                "content": (
                    "You are an expert AI assistant. Based on the CONTENT, "
                    "you should answer the QUESTION in short.\n\n"
                    f"CONTENT:{parsed_answer[prompt]}\n\n"
                    f"QUESTION:{query}\n"
                ),
            }
        ]
        # add_generation_prompt=True appends the assistant turn header, so the
        # split below reliably isolates the generated answer.
        input_text = self.smollm_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        llm_inputs = self.smollm_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        outputs = self.smollm_model.generate(
            llm_inputs,
            max_new_tokens=50,
            temperature=0.2,
            do_sample=True,
        )
        response = self.smollm_tokenizer.decode(outputs[0])
        # Keep only the assistant turn and drop the trailing end-of-turn token.
        cleaned_text = response.split("assistant\n", 1)[-1].strip()
        return cleaned_text.removesuffix("<|im_end|>")
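

# Minimal usage sketch (an illustration, not part of the model code above):
# it assumes FloSmolVConfig accepts vision_config/llm_config dicts with
# `_name_or_path` entries, and the checkpoint names and image path below are
# placeholders. Run it from a context where the package import resolves.
if __name__ == "__main__":
    from PIL import Image

    config = FloSmolVConfig(
        vision_config={"_name_or_path": "microsoft/Florence-2-base"},  # placeholder checkpoint
        llm_config={"_name_or_path": "HuggingFaceTB/SmolLM2-360M-Instruct"},  # placeholder checkpoint
    )
    model = FloSmolV(config)
    image = Image.open("example.jpg")  # hypothetical local image
    print(model(image, "What is happening in this picture?"))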