FractalAIR committed
Commit e2eaf4a · verified · 1 Parent(s): 6a093cd

Update app.py

Files changed (1)
  1. app.py +97 -46
app.py CHANGED
@@ -1,57 +1,108 @@
 
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from threading import Thread
 import gradio as gr
 
-model_id = "FractalAIResearch/Fathom-R1-14B"  # or your HF repo path
-
-def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        trust_remote_code=True,
-        low_cpu_mem_usage=True
-    )
-    return model, tokenizer
-
-model, tokenizer = load_model()
-
-def generate_response(message, history):
-    prompt = ""
-    for user, bot in history:
-        prompt += f"<|user|>\n{user.strip()}\n<|assistant|>\n{bot.strip()}\n"
-    prompt += f"<|user|>\n{message.strip()}\n<|assistant|>\n"
-
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        **inputs,
-        streamer=streamer,
-        max_new_tokens=16384,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
     )
-    thread = Thread(target=model.generate, kwargs=generate_kwargs)
-    thread.start()
-    for new_text in streamer:
-        yield new_text
-
-with gr.Blocks() as demo:
-    gr.Markdown("## 🧠 Chat with Fathom-R1 14B")
-    chatbot = gr.Chatbot(show_copy_button=True)
-    msg = gr.Textbox(placeholder="Ask me anything...", container=False)
     state = gr.State([])
 
-    def user_submit(message, history):
-        history = history + [[message, ""]]
-        return "", history
 
-    msg.submit(user_submit, [msg, state], [msg, state]).then(
-        generate_response, [msg, state], chatbot
-    )
 
-if __name__ == "__main__":
-    demo.queue().launch()
 
+# app.py
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from threading import Thread
 import gradio as gr
+import spaces
 
+MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"
+
+class Chatbot:
+    def __init__(self):
+        print("⏳ Loading model...")
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        self.model.eval()
+        print("✅ Model loaded!")
+
+    def chat(self, messages, temperature, max_new_tokens, top_p, repetition_penalty):
+        # Format messages into prompt
+        prompt = self._format_messages(messages)
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
+
+        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+        )
+
+        # Generate in a background thread so tokens can be streamed as they arrive
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        response = ""
+        for token in streamer:
+            response += token
+            yield response
+
+    def _format_messages(self, messages):
+        prompt = ""
+        for msg in messages:
+            if msg["role"] == "user":
+                prompt += f"<|user|>\n{msg['content'].strip()}\n"
+            # Skip empty assistant placeholders so the trailing tag below is not doubled
+            elif msg["role"] == "assistant" and msg["content"]:
+                prompt += f"<|assistant|>\n{msg['content'].strip()}\n"
+        prompt += "<|assistant|>\n"
+        return prompt
+
+chatbot = Chatbot()
+
+# Chat state management
+def user_submit(user_message, history):
+    history = history + [{"role": "user", "content": user_message}, {"role": "assistant", "content": ""}]
+    # Clear the textbox, update state, and show the user turn immediately
+    return "", history, history
+
+@spaces.GPU  # request a GPU only for the duration of generation
+def generate(history, temperature, max_new_tokens, top_p, repetition_penalty):
+    response_gen = chatbot.chat(
+        history,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
     )
+    for chunk in response_gen:
+        history[-1]["content"] = chunk
+        yield history, history
+
+def reset():
+    return [], []
+
+with gr.Blocks(css="footer {display: none !important;}") as demo:
+    gr.Markdown("<h1 align='center'>🧠 Fathom R1 14B Chatbot</h1>")
+    chatbot_ui = gr.Chatbot([], elem_id="chatbot", height=500, type="messages", bubble_full_width=False)
     state = gr.State([])
 
+    with gr.Row():
+        with gr.Column(scale=6):
+            txt = gr.Textbox(placeholder="Ask a math question...", label="Your Message")
+        with gr.Column(scale=1):
+            submit = gr.Button("Submit", variant="primary")
+            clear = gr.Button("Clear")
 
+    with gr.Accordion("Advanced settings", open=False):
+        temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
+        max_new_tokens = gr.Slider(64, 2048, step=64, value=512, label="Max New Tokens")
+        top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
+        repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, label="Repetition Penalty")
+
+    submit.click(user_submit, [txt, state], [txt, state, chatbot_ui], queue=False)\
+          .then(generate, [state, temperature, max_new_tokens, top_p, repetition_penalty], [chatbot_ui, state])
+
+    txt.submit(user_submit, [txt, state], [txt, state, chatbot_ui], queue=False)\
+       .then(generate, [state, temperature, max_new_tokens, top_p, repetition_penalty], [chatbot_ui, state])
+
+    clear.click(reset, outputs=[chatbot_ui, state])
 
+demo.queue().launch()
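
For reference, a minimal standalone sketch of the prompt layout that _format_messages builds (hypothetical demo script, not part of this commit; it copies only the template logic above and needs no model or GPU):

# format_demo.py: hypothetical sketch mirroring _format_messages above
def format_messages(messages):
    prompt = ""
    for msg in messages:
        if msg["role"] == "user":
            prompt += f"<|user|>\n{msg['content'].strip()}\n"
        elif msg["role"] == "assistant" and msg["content"]:
            prompt += f"<|assistant|>\n{msg['content'].strip()}\n"
    prompt += "<|assistant|>\n"  # trailing tag cues the model to answer next
    return prompt

history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And what is 4 squared?"},
]
print(format_messages(history), end="")
# Output:
# <|user|>
# What is 2 + 2?
# <|assistant|>
# 4
# <|user|>
# And what is 4 squared?
# <|assistant|>

The trailing <|assistant|> tag is what prompts the model to produce the next reply; empty assistant placeholders from the UI state are skipped so that tag is not doubled.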