# NOTE(review): the lines below are Hugging Face web-UI scrape artifacts
# (file header/blame metadata), not Python source; commented out so the
# file parses. They can be deleted outright.
# akhaliq's picture
# akhaliq HF Staff
# Upload app.py with huggingface_hub
# 7bec6b8 verified
# raw
# history blame
# 3.2 kB
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from typing import Any, List, Dict
def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
    """Download the model's system prompt file and wrap it as a system message.

    The prompt file is expected to contain an example ``[THINK]...[/THINK]``
    block; that block is stripped, and the text before ``[THINK]`` and after
    ``[/THINK]`` is returned as two separate text parts.

    Args:
        repo_id: Hub repository to download the prompt file from.
        filename: Name of the prompt file within that repo.

    Returns:
        A ``{"role": "system", "content": [...]}`` message dict suitable for
        ``tokenizer.apply_chat_template``.
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r", encoding="utf-8") as file:
        system_prompt = file.read()
    index_begin_think = system_prompt.find("[THINK]")
    index_end_think = system_prompt.find("[/THINK]")
    # str.find returns -1 when a marker is missing; the slices below would then
    # silently drop/duplicate text, so fall back to the raw prompt instead.
    if index_begin_think == -1 or index_end_think == -1:
        return {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        }
    return {
        "role": "system",
        "content": [
            {"type": "text", "text": system_prompt[:index_begin_think]},
            {
                "type": "text",
                "text": system_prompt[index_end_think + len("[/THINK]") :],
            },
        ],
    }
# Model repo on the Hugging Face Hub used for both weights and tokenizer.
model_id = "mistralai/Magistral-Small-2509"
# Mistral-format tokenizer (the repo ships a Mistral tokenizer, not a HF one).
tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")
# Load the multimodal model in bf16; device_map="auto" places weights on the
# available accelerator. eval() disables dropout for inference.
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
# System message built from the repo's SYSTEM_PROMPT.txt, reused on every call.
SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt")
# Fix: the ZeroGPU decorator exported by the `spaces` package is `spaces.GPU`;
# `spaces.zero_gpu` does not exist and would raise AttributeError at import.
@spaces.GPU(duration=120)
def predict(message: str, image) -> str:
    """Generate a chat response, optionally conditioned on an uploaded image.

    Args:
        message: The user's text prompt.
        image: Optional image file path from the Gradio Image component
            (``type="filepath"``), or None/empty when no image was uploaded.

    Returns:
        The decoded model response text, with the prompt tokens and any
        trailing EOS token trimmed off.
    """
    user_content: list[dict] = [{"type": "text", "text": message}]
    if image:
        user_content.append({"type": "image_url", "image_url": {"url": image}})
    messages = [
        SYSTEM_PROMPT,
        {"role": "user", "content": user_content},
    ]
    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)
    input_ids = torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0)
    attention_mask = torch.tensor(
        tokenized.attention_mask, device="cuda"
    ).unsqueeze(0)
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
    # Only pass vision inputs when the chat template actually produced pixels.
    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:
        pixel_values = torch.tensor(
            tokenized.pixel_values[0], dtype=torch.bfloat16, device="cuda"
        ).unsqueeze(0)
        generate_kwargs["pixel_values"] = pixel_values
        # NOTE(review): image_sizes is derived from the (H, W) of the pixel
        # tensor, matching the original code — confirm against the model's
        # expected image_sizes format.
        generate_kwargs["image_sizes"] = torch.tensor(
            pixel_values.shape[-2:], device="cuda"
        ).unsqueeze(0)
    output = model.generate(**generate_kwargs)[0]
    # Skip the prompt tokens; drop the final token only if it is EOS.
    end = -1 if output[-1] == tokenizer.eos_token_id else len(output)
    return tokenizer.decode(output[len(tokenized.input_ids) : end])
# Gradio UI: a simple text+image in / text out interface around predict().
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Your Message", placeholder="Ask me anything..."),
        # type="filepath" hands predict() a local file path (or None).
        gr.Image(label="Upload Image (Optional)", type="filepath"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Magistral Chat App",
    description='Chat with Magistral AI. Upload an image if relevant to your question.<br>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>',
)
# Launch only when run as a script (Spaces imports this module directly).
if __name__ == "__main__":
    demo.launch()