FractalAIR committed on
Commit f14967b · verified · 1 Parent(s): e2eaf4a

Update app.py

Files changed (1)
  1. app.py +282 -94
app.py CHANGED
@@ -1,108 +1,296 @@
- # app.py
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
- from threading import Thread
  import gradio as gr
  import spaces
-
- MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"
-
- @spaces.GPU
- class Chatbot:
-     def __init__(self):
-         print("⏳ Loading model...")
-         self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
-         self.model = AutoModelForCausalLM.from_pretrained(
-             MODEL_NAME,
-             torch_dtype=torch.bfloat16,
-             device_map="auto",
-             trust_remote_code=True,
-         )
-         self.model.eval()
-         print("✅ Model loaded!")
-
-     def chat(self, messages, temperature, max_new_tokens, top_p, repetition_penalty):
-         # Format messages into prompt
-         prompt = self._format_messages(messages)
-         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
-
-         streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = dict(
-             input_ids=input_ids,
-             streamer=streamer,
-             max_new_tokens=max_new_tokens,
-             do_sample=True,
-             top_p=top_p,
-             temperature=temperature,
-             repetition_penalty=repetition_penalty,
-         )
-
-         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         response = ""
-         for token in streamer:
-             response += token
-             yield response
-
-     def _format_messages(self, messages):
-         prompt = ""
-         for msg in messages:
-             if msg["role"] == "user":
-                 prompt += f"<|user|>\n{msg['content'].strip()}\n"
-             elif msg["role"] == "assistant":
-                 prompt += f"<|assistant|>\n{msg['content'].strip()}\n"
-         prompt += "<|assistant|>\n"
-         return prompt
-
- chatbot = Chatbot()
-
- # Chat state management
- def user_submit(user_message, history):
-     history = history + [{"role": "user", "content": user_message}, {"role": "assistant", "content": ""}]
-     return "", history, gr.update(visible=True)
-
- def generate(history, temperature, max_new_tokens, top_p, repetition_penalty):
-     response_gen = chatbot.chat(
-         history,
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
      )
-     partial = ""
-     for chunk in response_gen:
-         partial = chunk
-         history[-1]["content"] = partial
-         yield history, history
-
- def reset():
-     return [], []
-
- with gr.Blocks(css="footer {display: none !important;}") as demo:
-     gr.Markdown("<h1 align='center'>🧠 Fathom R1 14B Chatbot</h1>")
-     chatbot_ui = gr.Chatbot([], elem_id="chatbot", height=500, bubble_full_width=False)
-     state = gr.State([])
  with gr.Row():
-     with gr.Column(scale=6):
-         txt = gr.Textbox(placeholder="Ask a math question...", label="Your Message")
      with gr.Column(scale=1):
-         submit = gr.Button("Submit", variant="primary")
-         clear = gr.Button("Clear")

-     with gr.Accordion("Advanced settings", open=False):
-         temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
-         max_new_tokens = gr.Slider(64, 2048, step=64, value=512, label="Max New Tokens")
-         top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
-         repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, label="Repetition Penalty")
-
-     submit.click(user_submit, [txt, state], [txt, state, chatbot_ui], queue=False)\
-           .then(generate, [state, temperature, max_new_tokens, top_p, repetition_penalty], [chatbot_ui, state])
-
-     txt.submit(user_submit, [txt, state], [txt, state, chatbot_ui], queue=False)\
-        .then(generate, [state, temperature, max_new_tokens, top_p, repetition_penalty], [chatbot_ui, state])
-
-     clear.click(reset, outputs=[chatbot_ui, state])
-
- demo.queue().launch()
+ # ---------------------------------------------------------------
+ # Fathom-R1-14B ZeroGPU chat-demo (Gradio Blocks)
+ # ---------------------------------------------------------------
+
  import gradio as gr
  import spaces
+ import torch, re, uuid, tiktoken
+ from transformers import (AutoModelForCausalLM,
+                           AutoTokenizer,
+                           TextIteratorStreamer)
+ from threading import Thread
+
+ # ────────────────────────────────────────────────────────────────
+ # 1. Load the model on the single GPU supplied by ZeroGPU
+ #    (4-bit to stay well below the 24 GB VRAM of an A10G)
+ # ────────────────────────────────────────────────────────────────
+ model_name = "FractalAIResearch/Fathom-R1-14B"
+
+ try:
+     # one-line 4-bit loading (needs bitsandbytes, already in the HF Space image)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         device_map="auto",
+         load_in_4bit=True,
+         trust_remote_code=True
+     )
+ except RuntimeError:
+     # fall back to fp16 if 4-bit isn't available
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ device = next(model.parameters()).device  # usually cuda:0
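A note on the 4-bit path: `load_in_4bit=True` is the legacy shortcut, and newer transformers releases spell the same thing through `BitsAndBytesConfig`. A minimal sketch of that variant (not part of this commit; assumes bitsandbytes is installed):

    from transformers import BitsAndBytesConfig
    # hypothetical equivalent of the load_in_4bit=True call above
    bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                    bnb_4bit_compute_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )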
+
+
+ # ────────────────────────────────────────────────────────────────
+ # 2. Helpers
+ # ────────────────────────────────────────────────────────────────
+ def format_math(text: str) -> str:
+     """Rewrite [...] as $$...$$ and \\(...\\) as $...$ for nicer math rendering."""
+     text = re.sub(r"\[(.*?)\]", r"$$\1$$", text, flags=re.DOTALL)
+     return text.replace(r"\(", "$").replace(r"\)", "$")
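For reference, a quick sketch of what `format_math` does on a made-up string:

    format_math(r"Area: [ \pi r^2 ] for \( r = 2 \)")   # hypothetical input
    # -> "Area: $$ \pi r^2 $$ for $ r = 2 $"  (as printed)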
+
+
+ def generate_conversation_id() -> str:
+     return str(uuid.uuid4())[:8]
+
+
+ # tiktoken – we just keep it to count tokens during streaming
+ enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+
+ # Build a prompt that Fathom-R1 understands
+ BOS, SEP, EOS = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
+
+ system_message = (
+     "Your role as an assistant involves thoroughly exploring questions "
+     "through a systematic thinking process before providing the final "
+     "precise and accurate solutions. …"  # same text you used before
+ )
+
+
+ def build_prompt(history, user_msg: str) -> str:
+     prompt = f"{BOS}system{SEP}{system_message}{EOS}"
+     for m in history:
+         role = m["role"]
+         prompt += f"{BOS}{role}{SEP}{m['content']}{EOS}"
+     prompt += f"{BOS}user{SEP}{user_msg}{EOS}{BOS}assistant{SEP}"
+     return prompt
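For a first turn (empty history) the string `build_prompt` produces looks like the following, on a single line, with the system text elided and QUESTION standing in for the user message:

    <|im_start|>system<|im_sep|>Your role as an assistant …<|im_end|><|im_start|>user<|im_sep|>QUESTION<|im_end|><|im_start|>assistant<|im_sep|>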
+
+
+ # ────────────────────────────────────────────────────────────────
+ # 3. Generation (runs on the GPU for 60 s max per call)
+ # ────────────────────────────────────────────────────────────────
+ @spaces.GPU(duration=60)
+ def generate_response(user_message,
+                       max_tokens,
+                       temperature,
+                       top_p,
+                       history_state):
+     """
+     Takes exactly the signature the rest of the UI expects:
+     yields (visible_chatbot, history_state).
+     """
+     if not user_message.strip():
+         # this is a generator function, so yield (not return) the unchanged state
+         yield history_state, history_state
+         return
+
+     prompt = build_prompt(history_state, user_message)
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     streamer = TextIteratorStreamer(tokenizer,
+                                     skip_prompt=True,
+                                     skip_special_tokens=True)
+
+     gen_kwargs = dict(
+         input_ids=inputs["input_ids"],
+         attention_mask=inputs["attention_mask"],
+         max_new_tokens=int(max_tokens),
+         temperature=float(temperature),
+         top_p=float(top_p),
+         do_sample=True,
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.eos_token_id,
+         streamer=streamer
      )
+
+     # run generate in a background thread – lets us stream tokens
+     Thread(target=model.generate, kwargs=gen_kwargs).start()
+
+     assistant_response = ""
+     new_history = history_state + [
+         {"role": "user", "content": user_message},
+         {"role": "assistant", "content": ""}
+     ]
+
+     # live-stream tokens to the UI
+     tokens_seen = 0
+     token_budget = int(max_tokens)
+
+     for new_tok in streamer:
+         assistant_response += new_tok
+         tokens_seen += len(enc.encode(new_tok))
+         new_history[-1]["content"] = format_math(assistant_response.strip())
+         yield new_history, new_history
+         if tokens_seen >= token_budget:
+             break
+
+     # final yield
+     yield new_history, new_history
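(A caveat on the budget check above: `enc` is a GPT-3.5 tokenizer, so `tokens_seen` only approximates the model's own token count; generation is really bounded by `max_new_tokens` in `gen_kwargs`, and the explicit `break` is just a safety cutoff for the UI loop.)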
+
+
+ # ────────────────────────────────────────────────────────────────
+ # 4. Demo UI – identical to your current one
+ # ────────────────────────────────────────────────────────────────
+ example_messages = {
+     "IIT-JEE 2024 Mathematics": (
+         "A student appears for a quiz consisting of only true-false type "
+         "questions and answers all the questions. …"
+     ),
+     "IIT-JEE 2025 Physics": (
+         "A person sitting inside an elevator performs a weighing experiment …"
+     ),
+     "Goldman Sachs Interview Puzzle": (
+         "Four friends need to cross a dangerous bridge at night …"
+     ),
+     "IIT-JEE 2025 Mathematics": (
+         "Let S be the set of all seven-digit numbers that can be formed …"
+     )
+ }
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     # session-scoped states
+     conversations_state = gr.State({})
+     current_convo_id = gr.State(generate_conversation_id())
+     history_state = gr.State([])
+
+     # Header
+     gr.HTML(
+         """
+         <div style="display:flex;align-items:center;gap:16px;margin-bottom:1em">
+           <div style="background-color:black;padding:6px;border-radius:8px">
+             <img src="https://framerusercontent.com/images/j0KjQQyrUfkFw4NwSaxQOLAoBU.png"
+                  style="height:48px">
+           </div>
+           <h1 style="margin:0;">Fathom R1 14B Chatbot</h1>
+         </div>
+         """
+     )
+
+     # Sidebar
+     with gr.Sidebar():
+         gr.Markdown("## Conversations")
+         conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
+         new_convo_button = gr.Button("New Conversation")
+
      with gr.Row():
          with gr.Column(scale=1):
+             # intro text
+             gr.Markdown(
+                 """
+                 Welcome to the Fathom R1 14B Chatbot, developed by **Fractal AI Research**!
+                 This model excels at reasoning tasks in mathematics and science …
+
+                 Once you close this demo window, all currently saved conversations will be lost.
+                 """
+             )
+
+             # Settings
+             gr.Markdown("### Settings")
+             max_tokens_slider = gr.Slider(6144, 32768, step=1024, value=16384, label="Max Tokens")
+             with gr.Accordion("Advanced Settings", open=True):
+                 temperature_slider = gr.Slider(0.1, 2.0, value=0.6, label="Temperature")
+                 top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
+
+             gr.Markdown(
+                 """
+                 We sincerely acknowledge [VIDraft](https://huggingface.co/VIDraft) …
+                 """
+             )
+
+         with gr.Column(scale=4):
+             chatbot = gr.Chatbot(label="Chat", type="messages", height=520)
+             with gr.Row():
+                 user_input = gr.Textbox(label="User Input",
+                                         placeholder="Type your question here…",
+                                         lines=3, scale=8)
+                 with gr.Column():
+                     submit_button = gr.Button("Send", variant="primary", scale=1)
+                     clear_button = gr.Button("Clear", scale=1)
+
+             # examples
+             gr.Markdown("**Try these examples:**")
+             with gr.Row():
+                 example1_button = gr.Button("IIT-JEE 2025 Mathematics")
+                 example2_button = gr.Button("IIT-JEE 2025 Physics")
+                 example3_button = gr.Button("Goldman Sachs Interview Puzzle")
+                 example4_button = gr.Button("IIT-JEE 2024 Mathematics")
+
+     # ───────── conversation-management helpers ──────────────────
+     def update_conversation_list(conversations):
+         return [conversations[cid]["title"] for cid in conversations]
+
+     def start_new_conversation(conversations):
+         new_id = generate_conversation_id()
+         conversations[new_id] = {"title": f"New Conversation {new_id}", "messages": []}
+         return new_id, [], gr.update(choices=update_conversation_list(conversations),
+                                      value=conversations[new_id]["title"]), conversations
+
+     def load_conversation(selected_title, conversations):
+         for cid, convo in conversations.items():
+             if convo["title"] == selected_title:
+                 return cid, convo["messages"], convo["messages"]
+         return current_convo_id.value, history_state.value, history_state.value
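For illustration, a hypothetical snapshot of the `conversations` dict these helpers maintain (the short ids come from `generate_conversation_id`, the title from the first five words of the opening message):

    # {"3fa2b41c": {"title": "Four friends need to cross",
    #               "messages": [{"role": "user", "content": "..."},
    #                            {"role": "assistant", "content": "..."}]}}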
+
+     # main "send" wrapper: keeps the conversations dict in sync
+     def send_message(user_message, max_tokens, temperature, top_p,
+                      convo_id, history, conversations):
+         if convo_id not in conversations:
+             title = " ".join(user_message.strip().split()[:5])
+             conversations[convo_id] = {"title": title, "messages": history}
+         if conversations[convo_id]["title"].startswith("New Conversation"):
+             conversations[convo_id]["title"] = " ".join(user_message.strip().split()[:5])
+
+         # call the streaming generator and forward its yields
+         for updated_history, new_history in generate_response(
+                 user_message, max_tokens, temperature, top_p, history):
+             conversations[convo_id]["messages"] = new_history
+             yield (updated_history, new_history,
+                    gr.update(choices=update_conversation_list(conversations),
+                              value=conversations[convo_id]["title"]),
+                    conversations)
+
+     # ───────── UI → functions wiring ────────────────────────────
+     submit_button.click(
+         fn=send_message,
+         inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider,
+                 current_convo_id, history_state, conversations_state],
+         outputs=[chatbot, history_state, conversation_selector, conversations_state],
+         concurrency_limit=16
+     ).then(
+         fn=lambda: gr.update(value=""),
+         inputs=None,
+         outputs=user_input
+     )
+
+     clear_button.click(fn=lambda: ([], []), inputs=None,
+                        outputs=[chatbot, history_state])
+
+     new_convo_button.click(fn=start_new_conversation,
+                            inputs=[conversations_state],
+                            outputs=[current_convo_id, history_state,
+                                     conversation_selector, conversations_state])
+
+     conversation_selector.change(fn=load_conversation,
+                                  inputs=[conversation_selector, conversations_state],
+                                  outputs=[current_convo_id, history_state, chatbot])
+
+     # example buttons
+     example1_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2025 Mathematics"]),
+                           None, user_input)
+     example2_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2025 Physics"]),
+                           None, user_input)
+     example3_button.click(lambda: gr.update(value=example_messages["Goldman Sachs Interview Puzzle"]),
+                           None, user_input)
+     example4_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2024 Mathematics"]),
+                           None, user_input)
+
+ # ────────────────────────────────────────────────────────────────
+ # 5. Launch
+ # ────────────────────────────────────────────────────────────────
+ if __name__ == "__main__":
+     demo.queue().launch(share=True, ssr_mode=False)
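For anyone reproducing this revision outside the Space, a rough sketch of the local setup (assumptions: a CUDA GPU is available, and the `spaces` decorator is designed to no-op when no ZeroGPU is attached):

    pip install torch transformers accelerate bitsandbytes tiktoken gradio spaces
    python app.py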