akhaliq (HF Staff) committed
Commit 7bec6b8 · verified · 1 Parent(s): d4d6bcc

Upload app.py with huggingface_hub
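For context, a minimal sketch of how a file upload like this is typically done with the huggingface_hub client; the repo id below is a placeholder, and the exact call behind this commit is not shown on this page:

from huggingface_hub import HfApi

api = HfApi()  # authenticates via HF_TOKEN or a cached `huggingface-cli login`
api.upload_file(
    path_or_fileobj="app.py",            # local file to push
    path_in_repo="app.py",               # destination path inside the repo
    repo_id="your-username/your-space",  # placeholder Space id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)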

Files changed (1)
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
+import gradio as gr
+import spaces
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
+from typing import Any
+
+def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
+    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
+    with open(file_path, "r") as file:
+        system_prompt = file.read()
+
+    index_begin_think = system_prompt.find("[THINK]")  # the [THINK]...[/THINK] example block is dropped below
+    index_end_think = system_prompt.find("[/THINK]")
+
+    return {
+        "role": "system",
+        "content": [
+            {"type": "text", "text": system_prompt[:index_begin_think]},
+            {
+                "type": "text",
+                "text": system_prompt[index_end_think + len("[/THINK]") :],
+            },
+        ],
+    }
+
+model_id = "mistralai/Magistral-Small-2509"
+tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")
+model = Mistral3ForConditionalGeneration.from_pretrained(
+    model_id, torch_dtype=torch.bfloat16, device_map="auto"
+).eval()
+
+
+SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt")
+
+@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU for up to 120 s per call
+def predict(message: str, image) -> str:
+    messages = [
+        SYSTEM_PROMPT,
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": message},
+                {"type": "image_url", "image_url": {"url": image}} if image else {},
+            ],
+        },
+    ]
+
+    # Filter out empty image entries
+    messages[1]["content"] = [item for item in messages[1]["content"] if item]
+
+    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)
+
+    input_ids = torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0)
+    attention_mask = torch.tensor(tokenized.attention_mask, device="cuda").unsqueeze(0)
+
+    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:  # multimodal path
+        pixel_values = torch.tensor(
+            tokenized.pixel_values[0], dtype=torch.bfloat16, device="cuda"
+        ).unsqueeze(0)
+        image_sizes = torch.tensor(pixel_values.shape[-2:], device="cuda").unsqueeze(0)
+        output = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_sizes=image_sizes,
+        )[0]
+    else:  # text-only path
+        output = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )[0]
+
+    decoded_output = tokenizer.decode(  # decode only the newly generated tokens
+        output[
+            len(tokenized.input_ids) : (
+                -1 if output[-1] == tokenizer.eos_token_id else len(output)  # drop a trailing EOS
+            )
+        ]
+    )
+    return decoded_output
+
+demo = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.Textbox(label="Your Message", placeholder="Ask me anything..."),
+        gr.Image(label="Upload Image (Optional)", type="filepath"),
+    ],
+    outputs=gr.Textbox(label="Response"),
+    title="Magistral Chat App",
+    description='Chat with Magistral AI. Upload an image if relevant to your question.<br>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>',
+)
+
+if __name__ == "__main__":
+    demo.launch()
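Once the Space is running, the predict function above is reachable through Gradio's API as well as the web UI. A hedged sketch of a remote call with gradio_client, assuming a placeholder Space id and Gradio's default /predict endpoint for gr.Interface:

from gradio_client import Client, handle_file

client = Client("your-username/magistral-chat-app")  # placeholder Space id

# Text-only request: leave the optional image input empty
print(client.predict("Explain the Pythagorean theorem.", None, api_name="/predict"))

# Image + text request: handle_file uploads a local image to the Space
print(client.predict("What is in this picture?", handle_file("photo.jpg"), api_name="/predict"))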