| from vllm import LLM | |
| from vllm.sampling_params import SamplingParams | |
| import base64 | |
def encode_image(image_path: str) -> str:
    """Read the file at *image_path* and return its contents base64-encoded as a UTF-8 string."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
class Pixtral:
    """Thin wrapper around vLLM's chat API for the mistralai/Pixtral-12B-2409 vision model.

    Loads the model once in ``__init__`` (an expensive, GPU-backed operation) and
    exposes two single-turn helpers: text-only and text+image generation.
    """

    def __init__(self, max_model_len=4096, max_tokens=2048, gpu_memory_utilization=0.65, temperature=0.35):
        """Load the Pixtral model via vLLM.

        Args:
            max_model_len: Maximum context length (prompt + generation) in tokens.
            max_tokens: Maximum number of tokens to generate per request.
            gpu_memory_utilization: Fraction of GPU memory vLLM may reserve.
            temperature: Sampling temperature for generation.
        """
        self.model_name = "mistralai/Pixtral-12B-2409"
        self.sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)
        # Pixtral checkpoints ship in Mistral's native format, so tokenizer,
        # weights and config must all use the "mistral" loaders.
        self.llm = LLM(
            model=self.model_name,
            tokenizer_mode="mistral",
            gpu_memory_utilization=gpu_memory_utilization,
            load_format="mistral",
            config_format="mistral",
            max_model_len=max_model_len
        )

    def _chat(self, content):
        """Send a single-turn user message with the given content parts and return the generated text."""
        messages = [
            {
                "role": "user",
                "content": content,
            },
        ]
        outputs = self.llm.chat(messages, sampling_params=self.sampling_params)
        # vLLM returns one RequestOutput per prompt; take the first completion's text.
        return outputs[0].outputs[0].text

    def generate_message_from_image(self, prompt, image_path):
        """Generate a response to *prompt* grounded in the image at *image_path*.

        The image is inlined as a base64 data URL.
        NOTE(review): the MIME type is hard-coded to image/jpeg regardless of the
        actual file format — confirm whether non-JPEG inputs are ever passed.
        """
        base64_image = encode_image(image_path)
        return self._chat([
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
        ])

    def generate_message(self, prompt):
        """Generate a response to a text-only *prompt* and return the generated text."""
        return self._chat([
            {"type": "text", "text": prompt},
        ])