File size: 3,198 Bytes
7bec6b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from typing import Any, List, Dict

def _split_system_prompt(system_prompt: str) -> list[dict[str, str]]:
    """Split raw prompt text into text chunks, dropping a ``[THINK]...[/THINK]`` section.

    Args:
        system_prompt: Full text of the system prompt file.

    Returns:
        Two text chunks (before / after the think section) when both markers are
        present in order; otherwise a single chunk holding the whole prompt.
    """
    begin_tag, end_tag = "[THINK]", "[/THINK]"
    index_begin_think = system_prompt.find(begin_tag)
    # Only accept a closing tag that appears after the opening one.
    index_end_think = (
        system_prompt.find(end_tag, index_begin_think + len(begin_tag))
        if index_begin_think != -1
        else -1
    )
    if index_end_think == -1:
        # str.find returns -1 for a missing marker; slicing with it would drop the
        # last character and duplicate leading text, so keep the prompt intact.
        return [{"type": "text", "text": system_prompt}]
    # NOTE(review): the text between the markers is discarded here; the upstream
    # Magistral example may instead pass it as a "thinking" chunk — confirm intent.
    return [
        {"type": "text", "text": system_prompt[:index_begin_think]},
        {"type": "text", "text": system_prompt[index_end_think + len(end_tag):]},
    ]


def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
    """Download a system prompt from the Hub and wrap it as a chat system message.

    Any ``[THINK]...[/THINK]`` section in the file is removed, and the text
    before and after it becomes two separate text chunks. If the markers are
    absent or out of order, the whole file becomes a single text chunk.

    Args:
        repo_id: Hugging Face Hub repository that hosts the prompt file.
        filename: Path of the prompt file inside that repository.

    Returns:
        A ``{"role": "system", "content": [...]}`` message dict whose content is
        a list of ``{"type": "text", "text": ...}`` chunks.
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # Explicit UTF-8: the locale default encoding is platform-dependent.
    with open(file_path, "r", encoding="utf-8") as file:
        system_prompt = file.read()

    return {"role": "system", "content": _split_system_prompt(system_prompt)}

# Hub repository of the checkpoint served by this app; also hosts SYSTEM_PROMPT.txt.
model_id = "mistralai/Magistral-Small-2509"

# tokenizer_type="mistral" selects the Mistral-specific tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")

# Load bfloat16 weights and let device_map="auto" decide placement.
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()  # inference only: switch off training-mode layers such as dropout

# System message prepended to every conversation built in `predict`.
SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt")

@spaces.GPU(duration=120)
def predict(message: str, image) -> str:
    """Generate a Magistral reply to a text message and an optional image.

    The ``spaces.GPU`` decorator requests a ZeroGPU allocation with a
    ``duration`` of 120 seconds for each call.

    Args:
        message: User text from the Gradio textbox.
        image: Local file path of the uploaded image (the Gradio ``Image``
            input uses ``type="filepath"``), or a falsy value when no image
            was uploaded.

    Returns:
        The decoded completion. Prompt tokens are removed, and the final token
        is stripped when it equals the tokenizer's EOS id.
    """
    user_content: list[dict[str, Any]] = [{"type": "text", "text": message}]
    # Attach an image chunk only when an image was actually provided.
    if image:
        user_content.append({"type": "image_url", "image_url": {"url": image}})

    messages = [SYSTEM_PROMPT, {"role": "user", "content": user_content}]
    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)

    # unsqueeze(0) adds a leading batch dimension of size 1.
    generate_kwargs: dict[str, Any] = {
        "input_ids": torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0),
        "attention_mask": torch.tensor(
            tokenized.attention_mask, device="cuda"
        ).unsqueeze(0),
    }

    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:
        pixel_values = torch.tensor(
            tokenized.pixel_values[0], dtype=torch.bfloat16, device="cuda"
        ).unsqueeze(0)
        generate_kwargs["pixel_values"] = pixel_values
        # Last two dims of the pixel tensor — presumably (height, width); confirm.
        generate_kwargs["image_sizes"] = torch.tensor(
            pixel_values.shape[-2:], device="cuda"
        ).unsqueeze(0)

    output = model.generate(**generate_kwargs)[0]

    # Skip the echoed prompt tokens, and drop a trailing EOS token if present.
    end = -1 if output[-1] == tokenizer.eos_token_id else len(output)
    return tokenizer.decode(output[len(tokenized.input_ids):end])

# Gradio front end: a text prompt plus an optional image in, one text reply out.
demo = gr.Interface(
    fn=predict,
    title="Magistral Chat App",
    description=(
        "Chat with Magistral AI. Upload an image if relevant to your question.<br>"
        'Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" '
        'target="_blank">anycoder</a>'
    ),
    # Positional order must match predict(message, image).
    inputs=[
        gr.Textbox(label="Your Message", placeholder="Ask me anything..."),
        # type="filepath" hands predict a local path string rather than an array.
        gr.Image(label="Upload Image (Optional)", type="filepath"),
    ],
    outputs=gr.Textbox(label="Response"),
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()