FractalAIR committed
Commit 9e474b9 · verified · 1 Parent(s): 3670181

Update app.py

Files changed (1): app.py (+171, -176)
app.py CHANGED
@@ -1,226 +1,221 @@
- # app.py – Gradio chatbot for FractalAIResearch/Fathom-R1-14B
- # ---------------------------------------------------------------------
  import gradio as gr
  import spaces
  import torch
  from threading import Thread
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
- )

- # ---------------------------------------------------------------------
- # 1. Model & tokenizer
- # ---------------------------------------------------------------------
- MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"
-
- print("⏳ Loading model … (this may take a couple of minutes)")
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     device_map="auto",        # dispatch across any available device(s)
-     trust_remote_code=True,   # Fathom uses custom modelling code
-     low_cpu_mem_usage=True,
- )
- tokenizer = AutoTokenizer.from_pretrained(
-     MODEL_NAME,
-     trust_remote_code=True,
- )

- print(" Model loaded")

- # ---------------------------------------------------------------------
- # 2. Helper: build a prompt with the tokenizer’s chat_template
- # ---------------------------------------------------------------------
- def build_chat_prompt(history, user_message, system_message):
-     """
-     history        : list[dict(role, content)]
-     user_message   : str
-     system_message : str
-     returns a single prompt string (not tokenised)
-     """
-     msgs = []
-     if system_message:
-         msgs.append({"role": "system", "content": system_message})
-     msgs.extend(history)
-     msgs.append({"role": "user", "content": user_message})
-
-     return tokenizer.apply_chat_template(
-         msgs,
-         tokenize=False,               # return pure text
-         add_generation_prompt=True,
-     )
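
For context, the removed helper delegates all role formatting to the tokenizer's bundled chat template. A minimal standalone sketch of that call (the message list is illustrative, not from the commit):

```python
from transformers import AutoTokenizer

# Illustrative only: any chat-tuned checkpoint with a chat_template works here.
tokenizer = AutoTokenizer.from_pretrained(
    "FractalAIResearch/Fathom-R1-14B", trust_remote_code=True
)

msgs = [
    {"role": "system", "content": "You are a careful math assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# Returns one prompt string using the model's own role tags,
# ending with the assistant header so generation continues from there.
prompt = tokenizer.apply_chat_template(
    msgs, tokenize=False, add_generation_prompt=True
)
print(prompt)
```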

- # ---------------------------------------------------------------------
- # 3. Generation endpoint
- # ---------------------------------------------------------------------
- @spaces.GPU(duration=60)          # short GPU reservation if available
- def generate_response(
-     user_message,
-     max_tokens,
-     temperature,
-     top_k,
-     top_p,
-     repetition_penalty,
-     history_state,
- ):
-     # Empty input → nothing to do
-     if not user_message.strip():
-         return history_state, history_state

-     # System prompt (kept from your Phi-4 version)
-     system_message = (
-         "Your role as an assistant involves thoroughly exploring questions through a "
-         "systematic thinking process before providing the final precise and accurate "
-         "solutions. Please structure your response into two main sections: "
-         "<think> … </think> and Solution."
-     )

-     prompt = build_chat_prompt(history_state, user_message, system_message)
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-     # Stream tokens as they come
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
-
-     generation_kwargs = dict(
-         input_ids=inputs["input_ids"],
-         attention_mask=inputs["attention_mask"],
-         max_new_tokens=int(max_tokens),
-         do_sample=True,
-         temperature=float(temperature),
-         top_k=int(top_k),
-         top_p=float(top_p),
-         repetition_penalty=float(repetition_penalty),
-         streamer=streamer,
-     )

-     # Run generate in a background thread so the UI stays responsive
-     Thread(target=model.generate, kwargs=generation_kwargs).start()
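
The comment above names the key pattern: `model.generate` blocks until generation finishes, while `TextIteratorStreamer` hands out text as it is produced, so the two must run concurrently. A minimal self-contained sketch of that pairing (function and names are illustrative, not part of the commit):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, prompt, **gen_kwargs):
    """Yield decoded text chunks while generation runs in a background thread."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, **gen_kwargs),
    )
    thread.start()
    for chunk in streamer:  # blocks until the next decoded chunk is ready
        yield chunk
    thread.join()  # the streamer is exhausted once generate() returns
```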

      assistant_response = ""
      new_history = history_state + [
          {"role": "user", "content": user_message},
-         {"role": "assistant", "content": ""},
      ]

-     for token in streamer:
-         # strip any stray special tokens the model may output
-         cleaned = (
-             token.replace("<|im_start|>", "")
-             .replace("<|im_end|>", "")
-             .replace("<|im_sep|>", "")
-         )
-         assistant_response += cleaned
-         new_history[-1]["content"] = assistant_response.strip()
-         yield new_history, new_history

      yield new_history, new_history

- # ---------------------------------------------------------------------
- # 4. Example questions (unchanged)
- # ---------------------------------------------------------------------
  example_messages = {
-     "Math reasoning": "If a rectangular prism has a length of 6 cm, a width of 4 cm, and a height of 5 cm, what is the length of the longest line segment that can be drawn from one vertex to another?",
-     "Logic puzzle": "Four people (Alex, Blake, Casey, and Dana) each have a different favorite color (red, blue, green, yellow) and a different favorite fruit (apple, banana, cherry, date). Given the following clues: 1) The person who likes red doesn't like dates. 2) Alex likes yellow. 3) The person who likes blue likes cherries. 4) Blake doesn't like apples or bananas. 5) Casey doesn't like yellow or green. Who likes what color and what fruit?",
-     "Physics problem": "A ball is thrown upward with an initial velocity of 15 m/s from a height of 2 meters above the ground. Assuming the acceleration due to gravity is 9.8 m/s², determine: 1) The maximum height the ball reaches. 2) The total time the ball is in the air before hitting the ground. 3) The velocity with which the ball hits the ground.",
  }

- # ---------------------------------------------------------------------
- # 5. Gradio UI (identical to the original, just lower default max_tokens)
- # ---------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         """
-         # Fathom-R1-14B Chatbot
-         The model excels at multi-step reasoning in mathematics, logic, and science.
-
-         It returns two sections:\n
-         1. **<think>** detailed chain-of-thought (reasoning)\n
-         2. **Solution** – concise, final answer
-         """
-     )

      history_state = gr.State([])

      with gr.Row():
-         # Settings panel
          with gr.Column(scale=1):
-             gr.Markdown("### Settings")
-             max_tokens_slider = gr.Slider(
-                 minimum=64, maximum=4096, step=256, value=1024, label="Max Tokens"
              )
-             with gr.Accordion("Advanced Settings", open=False):
-                 temperature_slider = gr.Slider(
-                     minimum=0.1, maximum=2.0, value=0.8, label="Temperature"
-                 )
-                 top_k_slider = gr.Slider(
-                     minimum=1, maximum=100, step=1, value=50, label="Top-k"
-                 )
-                 top_p_slider = gr.Slider(
-                     minimum=0.1, maximum=1.0, value=0.95, label="Top-p"
-                 )
-                 repetition_penalty_slider = gr.Slider(
-                     minimum=1.0, maximum=2.0, value=1.0, label="Repetition Penalty"
-                 )
-
-         # Chat area
          with gr.Column(scale=4):
-             chatbot = gr.Chatbot(label="Chat", type="messages")
              with gr.Row():
-                 user_input = gr.Textbox(
-                     label="Your message", placeholder="Type your message here…", scale=3
-                 )
-                 submit_button = gr.Button("Send", variant="primary", scale=1)
-                 clear_button = gr.Button("Clear", scale=1)
              gr.Markdown("**Try these examples:**")
              with gr.Row():
-                 example1_button = gr.Button("Math reasoning")
-                 example2_button = gr.Button("Logic puzzle")
-                 example3_button = gr.Button("Physics problem")

-     # Button wiring
      submit_button.click(
-         fn=generate_response,
-         inputs=[
-             user_input,
-             max_tokens_slider,
-             temperature_slider,
-             top_k_slider,
-             top_p_slider,
-             repetition_penalty_slider,
-             history_state,
-         ],
-         outputs=[chatbot, history_state],
      ).then(
          fn=lambda: gr.update(value=""),
          inputs=None,
-         outputs=user_input,
      )

      clear_button.click(
          fn=lambda: ([], []),
          inputs=None,
-         outputs=[chatbot, history_state],
      )

-     example1_button.click(
-         fn=lambda: gr.update(value=example_messages["Math reasoning"]),
-         inputs=None,
-         outputs=user_input,
-     )
-     example2_button.click(
-         fn=lambda: gr.update(value=example_messages["Logic puzzle"]),
          inputs=None,
-         outputs=user_input,
      )
-     example3_button.click(
-         fn=lambda: gr.update(value=example_messages["Physics problem"]),
-         inputs=None,
-         outputs=user_input,
      )

- # ---------------------------------------------------------------------
- # 6. Launch
- # ---------------------------------------------------------------------
- if __name__ == "__main__":
-     demo.launch(ssr_mode=False)

  import gradio as gr
  import spaces
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import torch
  from threading import Thread
+ import re
+ import uuid

+ # Load model and tokenizer
+ our_model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"

+ our_model = AutoModelForCausalLM.from_pretrained(our_model_path, device_map="auto", torch_dtype="auto")
+ our_tokenizer = AutoTokenizer.from_pretrained(our_model_path)
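
A caveat on the added loading code: with `device_map="auto"` the checkpoint may be dispatched across several devices, so the hard-coded `device = "cuda:0"` used later for the inputs is not guaranteed to be where the embeddings live. A safer variant (a sketch, not the committed code) asks the model directly:

```python
# Sketch: place inputs on the model's first device instead of assuming cuda:0.
inputs = our_tokenizer(prompt, return_tensors="pt").to(our_model.device)
```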

+ def format_math(text):
+     text = re.sub(r"\[(.*?)\]", r"$$\1$$", text, flags=re.DOTALL)
+     text = text.replace(r"\(", "$").replace(r"\)", "$")
+     return text
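
`format_math` rewrites LaTeX bracket delimiters into the `$`/`$$` forms that Gradio's Markdown renderer understands; note that nothing else in this diff appears to call it. Expected behaviour, assuming the definition above:

```python
# [ ... ] blocks become display math, \( ... \) becomes inline math.
print(format_math(r"Let \(x\) satisfy [x^2 - 1 = 0]."))
# -> Let $x$ satisfy $$x^2 - 1 = 0$$.
```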

+ # Global dictionary to store all conversations: {id: {"title": str, "messages": list}}
+ conversations = {}

+ def generate_conversation_id():
+     return str(uuid.uuid4())[:8]
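
An eight-character UUID4 prefix is more than enough for a single demo session, though collisions become conceivable once many conversations accumulate. A defensive variant (hypothetical helper, not in the commit):

```python
def generate_unique_conversation_id():
    # Re-draw on the unlikely chance the 8-character prefix is already taken.
    while True:
        cid = str(uuid.uuid4())[:8]
        if cid not in conversations:
            return cid
```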

+ @spaces.GPU(duration=60)
+ def generate_response(user_message, max_tokens, temperature, top_p, history_state):
+     if not user_message.strip():
+         return history_state, history_state

+     model = our_model
+     tokenizer = our_tokenizer
+     start_tag = "<|im_start|>"
+     sep_tag = "<|im_sep|>"
+     end_tag = "<|im_end|>"
+
+     system_message = "Your role as an assistant..."
+     prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
+     for message in history_state:
+         if message["role"] == "user":
+             prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
+         elif message["role"] == "assistant" and message["content"]:
+             prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
+     prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
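
These `<|im_start|>`/`<|im_sep|>`/`<|im_end|>` tags follow the Phi-4 convention of the acknowledged starting-point Space. DeepSeek-R1-Distill-Qwen checkpoints ship their own chat template, so letting the tokenizer format the conversation would likely match the model's training format more closely; a sketch of that alternative (not what the commit does):

```python
# Sketch: defer role formatting to the checkpoint's own chat template.
msgs = [{"role": "system", "content": system_message},
        *history_state,
        {"role": "user", "content": user_message}]
prompt = tokenizer.apply_chat_template(
    msgs, tokenize=False, add_generation_prompt=True
)
```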
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     generation_kwargs = {
+         "input_ids": inputs["input_ids"],
+         "attention_mask": inputs["attention_mask"],
+         "max_new_tokens": int(max_tokens),
+         "do_sample": True,
+         "temperature": temperature,
+         "top_k": 50,
+         "top_p": top_p,
+         "repetition_penalty": 1.0,
+         "pad_token_id": tokenizer.eos_token_id,
+         "streamer": streamer,
+     }
+
+     try:
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+     except Exception:
+         yield history_state + [{"role": "user", "content": user_message}, {"role": "assistant", "content": "⚠️ Generation failed."}], history_state
+         return

      assistant_response = ""
      new_history = history_state + [
          {"role": "user", "content": user_message},
+         {"role": "assistant", "content": ""}
      ]

+     try:
+         for new_token in streamer:
+             if "<|end" in new_token:
+                 continue
+             cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
+             assistant_response += cleaned_token
+             new_history[-1]["content"] = assistant_response.strip()
+             yield new_history, new_history
+     except Exception:
+         pass
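
Since the streamer is constructed with `skip_special_tokens=True`, the manual tag-stripping in the loop is mostly defensive; the bare `except Exception: pass`, on the other hand, hides real mid-generation failures. A sketch that keeps the UI alive but still surfaces the error (`logging` setup assumed; variable names reuse the committed code):

```python
import logging

logger = logging.getLogger(__name__)

try:
    for new_token in streamer:
        ...  # accumulate and yield as above
except Exception:
    # Record the failure and show it in the transcript instead of dropping it.
    logger.exception("Streaming failed mid-generation")
    new_history[-1]["content"] = assistant_response + "\n\n⚠️ Generation was interrupted."
```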

      yield new_history, new_history

+
  example_messages = {
+     "JEE Main 2025 Combinatorics": "From all the English alphabets, five letters are chosen and are arranged in alphabetical order. The total number of ways, in which the middle letter is 'M', is?",
+     "JEE Main 2025 Coordinate Geometry": "A circle \\(C\\) of radius 2 lies in the second quadrant and touches both the coordinate axes. Let \\(r\\) be the radius of a circle that has centre at the point \\((2, 5)\\) and intersects the circle \\(C\\) at exactly two points. If the set of all possible values of \\(r\\) is the interval \\((\\alpha, \\beta)\\), then \\(3\\beta - 2\\alpha\\) is?",
+     "JEE Main 2025 Probability & Statistics": "A coin is tossed three times. Let \\(X\\) denote the number of times a tail follows a head. If \\(\\mu\\) and \\(\\sigma^2\\) denote the mean and variance of \\(X\\), then the value of \\(64(\\mu + \\sigma^2)\\) is?",
+     "JEE Main 2025 Laws of Motion": "A massless spring gets elongated by amount x_1 under a tension of 5 N. Its elongation is x_2 under the tension of 7 N. For the elongation of 5x_1 - 2x_2, the tension in the spring will be?"
  }

  with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     # Global heading stays at top
+     #gr.Markdown("# Ramanujan Ganit R1 14B V1 Chatbot")
+     gr.HTML(
+         """
+         <div style="display: flex; align-items: center; gap: 16px; margin-bottom: 1em;">
+             <div style="background-color: black; padding: 6px; border-radius: 8px;">
+                 <img src="https://framerusercontent.com/images/j0KjQQyrUfkFw4NwSaxQOLAoBU.png" alt="Fractal AI Logo" style="height: 48px;">
+             </div>
+             <h1 style="margin: 0;">Ramanujan Ganit R1 14B V1 Chatbot</h1>
+         </div>
+         """
+     )
+
+     with gr.Sidebar():
+         gr.Markdown("## Conversations")
+         conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
+         new_convo_button = gr.Button("New Conversation ➕")

+     current_convo_id = gr.State(generate_conversation_id())
      history_state = gr.State([])

      with gr.Row():
          with gr.Column(scale=1):
+             # INTRO TEXT MOVED HERE
+             gr.Markdown(
+                 """
+                 Welcome to the Ramanujan Ganit R1 14B V1 Chatbot, developed by Fractal AI Research!
+
+                 Our model excels at reasoning tasks in mathematics and science.
+
+                 Try the example problems below from JEE Main 2025, or type in your own problems to see how our model breaks down complex reasoning problems.
+
+                 Please note that once you close this demo window, all currently saved conversations will be lost.
+                 """
              )
+
+             gr.Markdown("### Settings")
+             max_tokens_slider = gr.Slider(minimum=6144, maximum=32768, step=1024, value=16384, label="Max Tokens")
+             with gr.Accordion("Advanced Settings", open=True):
+                 temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.6, label="Temperature")
+                 top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
+
+             # New acknowledgment line at bottom
+             gr.Markdown("""
+             We sincerely acknowledge [VIDraft](https://huggingface.co/VIDraft) for their Phi 4 Reasoning Plus [space](https://huggingface.co/spaces/VIDraft/phi-4-reasoning-plus), which served as the starting point for this demo.
+             """
+             )
+
          with gr.Column(scale=4):
+             #chatbot = gr.Chatbot(label="Chat", type="messages")
+             chatbot = gr.Chatbot(label="Chat", type="messages", height=520)
              with gr.Row():
+                 user_input = gr.Textbox(label="User Input", placeholder="Type your question here...", lines=3, scale=8)
+                 with gr.Column():
+                     submit_button = gr.Button("Send", variant="primary", scale=1)
+                     clear_button = gr.Button("Clear", scale=1)

              gr.Markdown("**Try these examples:**")
              with gr.Row():
+                 example1_button = gr.Button("JEE Main 2025\nCombinatorics")
+                 example2_button = gr.Button("JEE Main 2025\nCoordinate Geometry")
+                 example3_button = gr.Button("JEE Main 2025\nProbability & Statistics")
+                 example4_button = gr.Button("JEE Main 2025\nLaws of Motion")
+
+     def update_conversation_list():
+         return [conversations[cid]["title"] for cid in conversations]
+
+     def start_new_conversation():
+         new_id = generate_conversation_id()
+         conversations[new_id] = {"title": f"New Conversation {new_id}", "messages": []}
+         return new_id, [], gr.update(choices=update_conversation_list(), value=conversations[new_id]["title"])
+
+     def load_conversation(selected_title):
+         for cid, convo in conversations.items():
+             if convo["title"] == selected_title:
+                 return cid, convo["messages"], convo["messages"]
+         return current_convo_id.value, history_state.value, history_state.value
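
One caveat on the fallback above: reading `gr.State.value` inside a callback returns the value the state was constructed with, not the current session's value, so a non-matching title effectively resets to the initial conversation. The conventional pattern (a sketch, not the committed code) routes the live state through the event's inputs:

```python
# Sketch: receive session state as inputs instead of touching .value.
def load_conversation(selected_title, convo_id, history):
    for cid, convo in conversations.items():
        if convo["title"] == selected_title:
            return cid, convo["messages"], convo["messages"]
    return convo_id, history, history  # leave session state unchanged

# conversation_selector.change(
#     fn=load_conversation,
#     inputs=[conversation_selector, current_convo_id, history_state],
#     outputs=[current_convo_id, history_state, chatbot],
# )
```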
175
+
176
+ def send_message(user_message, max_tokens, temperature, top_p, convo_id, history):
177
+ if convo_id not in conversations:
178
+ #title = user_message.strip().split("\n")[0][:40]
179
+ title = " ".join(user_message.strip().split()[:5])
180
+ conversations[convo_id] = {"title": title, "messages": history}
181
+ if conversations[convo_id]["title"].startswith("New Conversation"):
182
+ #conversations[convo_id]["title"] = user_message.strip().split("\n")[0][:40]
183
+ conversations[convo_id]["title"] = " ".join(user_message.strip().split()[:5])
184
+ for updated_history, new_history in generate_response(user_message, max_tokens, temperature, top_p, history):
185
+ conversations[convo_id]["messages"] = new_history
186
+ yield updated_history, new_history, gr.update(choices=update_conversation_list(), value=conversations[convo_id]["title"])

      submit_button.click(
+         fn=send_message,
+         inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider, current_convo_id, history_state],
+         outputs=[chatbot, history_state, conversation_selector]
      ).then(
          fn=lambda: gr.update(value=""),
          inputs=None,
+         outputs=user_input
      )

      clear_button.click(
          fn=lambda: ([], []),
          inputs=None,
+         outputs=[chatbot, history_state]
      )

+     new_convo_button.click(
+         fn=start_new_conversation,
          inputs=None,
+         outputs=[current_convo_id, history_state, conversation_selector]
      )
+
+     conversation_selector.change(
+         fn=load_conversation,
+         inputs=conversation_selector,
+         outputs=[current_convo_id, history_state, chatbot]
      )

+     example1_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Combinatorics"]), inputs=None, outputs=user_input)
+     example2_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Coordinate Geometry"]), inputs=None, outputs=user_input)
+     example3_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Probability & Statistics"]), inputs=None, outputs=user_input)
+     example4_button.click(fn=lambda: gr.update(value=example_messages["JEE Main 2025 Laws of Motion"]), inputs=None, outputs=user_input)
+
+ demo.launch(share=True, ssr_mode=False)
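
Two launch-time notes, offered as suggestions rather than fixes: streamed generator outputs benefit from an explicit event queue, and `share=True` has no effect when the app already runs on a Hugging Face Space (Gradio warns and ignores it there). A sketch:

```python
# Sketch: enable the event queue for smoother streaming; drop share=True on Spaces.
demo.queue().launch(ssr_mode=False)
```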