Spaces:
Sleeping
Sleeping
File size: 8,397 Bytes
9b6a25b c1cb000 e30ac3f c1cb000 f8246c7 f14967b f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 9b6a25b addf70e f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 55861c0 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 ae84060 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 ae84060 f8246c7 ae84060 f8246c7 c1cb000 f8246c7 c1cb000 f8246c7 c1cb000 ae84060 c1cb000 f8246c7 c1cb000 f8246c7 ae84060 c1cb000 f8246c7 c1cb000 f8246c7 ae84060 f8246c7 ae84060 f8246c7 c1cb000 f8246c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import os
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model and tokenizer configuration
MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"
# Device that tokenized inputs are moved to before generation (used by generate_response).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model once at import time so every request reuses them.
# Weight placement is handled by device_map="auto" below, so no explicit
# model.to(device) call is needed.
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,  # bf16 weights (original note: "Optimize for H200 GPUs")
        device_map="auto",  # let accelerate place the weights on the available device(s)
        # NOTE(review): this executes code from the model repo. It was originally
        # justified as "Required for Qwen2-based models"; recent transformers
        # releases support Qwen2 natively, so confirm it is still needed.
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    # logger.exception keeps the traceback; a bare raise re-raises unchanged.
    logger.exception("Error loading model or tokenizer: %s", e)
    raise
def _build_messages(user_message, history_state):
    """Build the chat-template message list for one request.

    Args:
        user_message: The new user turn.
        history_state: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (the ``gr.Chatbot(type="messages")`` format this app stores).

    Returns:
        A new list: the system prompt, a role/content-only copy of every prior
        turn, then the new user turn.
    """
    # System prompt for Fathom-R1-14B
    system_message = "You are a helpful assistant, specialising at math and STEM reasoning."
    messages = [{"role": "system", "content": system_message}]
    messages.extend(
        {"role": message["role"], "content": message["content"]}
        for message in history_state
    )
    messages.append({"role": "user", "content": user_message})
    return messages


def _generate_into_streamer(streamer, generation_kwargs, errors):
    """Run ``model.generate`` on a worker thread and report failures to the caller.

    When generation finishes normally, ``generate`` itself signals the end of
    the stream. If it raises, that signal never arrives and iterating the
    streamer would block forever. In that case the exception is appended to
    ``errors`` and the stream is closed explicitly.
    """
    try:
        model.generate(**generation_kwargs)
    except Exception as exc:  # re-raised on the request thread after streaming stops
        errors.append(exc)
        streamer.end()


@spaces.GPU(duration=300)
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    """Stream a model reply to ``user_message`` as Gradio chatbot updates.

    This is a generator. Each ``yield`` is a ``(chatbot_value, history_state)``
    pair for the ``[chatbot, history_state]`` outputs it is wired to.

    Args:
        user_message: Text from the input box; blank or missing input is ignored.
        max_tokens: Maximum number of new tokens to generate.
        temperature, top_k, top_p, repetition_penalty: Sampling controls
            forwarded to ``model.generate``.
        history_state: Prior conversation as role/content dicts.

    Yields:
        The conversation extended with the user turn and the partial assistant
        reply, once per streamed chunk, and a final update when generation ends.
        On blank input, the history is yielded unchanged. On failure, the chatbot
        shows an ``Error: ...`` assistant message, while the stored history is
        restored to its state before this request.
    """
    if not user_message or not user_message.strip():
        logger.info("Empty message received, returning history unchanged.")
        yield history_state, history_state
        return
    try:
        logger.info("Processing new message...")
        messages = _build_messages(user_message, history_state)
        # Apply Qwen2 chat template
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The template already emitted any BOS/special tokens, so the tokenizer
        # must not add them a second time.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        # Greedy decoding only when temperature == 1.0, top_k >= 100 and
        # top_p == 1.0; every other slider combination samples.
        do_sample = not (temperature == 1.0 and top_k >= 100 and top_p == 1.0)
        # skip_special_tokens keeps end-of-turn markers out of the displayed reply.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": int(max_tokens),
            "do_sample": do_sample,
            "temperature": temperature,
            "top_k": int(top_k),
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "streamer": streamer,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        # Generate on a worker thread so this generator can consume the stream.
        errors = []
        thread = Thread(
            target=_generate_into_streamer,
            args=(streamer, generation_kwargs, errors),
        )
        thread.start()

        assistant_response = ""
        new_history = history_state + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": ""},
        ]
        for new_token in streamer:
            assistant_response += new_token
            new_history[-1]["content"] = assistant_response.strip()
            yield new_history, new_history
        thread.join()
        if errors:
            raise errors[0]
        logger.info("Response generated successfully.")
        yield new_history, new_history
    except Exception as e:
        logger.exception("Error during inference: %s", e)
        # Show the error in the chat, but keep it out of the stored history so
        # the next request does not send it back to the model.
        error_history = history_state + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": f"Error: {str(e)}"},
        ]
        yield error_history, history_state
# Example prompts loaded into the input box by the example buttons. The keys
# must match the keys the example buttons look up in the Gradio interface.
example_messages = {
    "IIT-JEE 2025 Physics": "A person sitting inside an elevator performs a weighing experiment with an object of mass 50 kg. Suppose that the variation of the height y (in m) of the elevator, from the ground, with time t (in s) is given by y = 8[1 + sin(2πt/T)], where T = 40π s. Taking acceleration due to gravity, g = 10 m/s^2, the maximum variation of the object's weight (in N) as observed in the experiment is ?",
    "Goldman Sachs Interview Puzzle": "Four friends need to cross a dangerous bridge at night. Unfortunately, they have only one torch and the bridge is too dangerous to cross without one. The bridge is only strong enough to support two people at a time. Not all people take the same time to cross the bridge. Times for each person: 1 min, 2 mins, 7 mins and 10 mins. What is the shortest time needed for all four of them to cross the bridge?",
    "IIT-JEE 2025 Mathematics": "Let X be the set of all seven-digit numbers that can be formed using the digits 0, 1 and 2. For example, 2210222 is in X, but 0210222 is NOT in X. Then the number of elements x in X such that at least one of the digits 0 and 1 appears exactly twice in x, is ?",
}
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Fathom-R1-14B Chatbot
        Welcome to the Fathom-R1-14B Chatbot! This model excels at multi-step reasoning tasks in mathematics, logic, and science.
        The model specializes in math and STEM reasoning, providing detailed step-by-step solutions.
        Try the example problems below to see how the model breaks down complex reasoning problems.
        """
    )
    # Conversation as role/content dicts; generate_response updates it together with the chatbot.
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(
                minimum=8192,
                # Fathom's context window is 16K.
                # NOTE(review): max_new_tokens at this maximum leaves no room for the
                # prompt inside a 16K window; confirm the intended limit.
                maximum=16384,
                step=1024,
                value=16384,
                label="Max Tokens",
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    label="Temperature",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top-k",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    label="Top-p",
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.0,
                    label="Repetition Penalty",
                )
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    scale=3,
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("IIT-JEE 2025 Mathematics")
                example2_button = gr.Button("Goldman Sachs Interview Puzzle")
                example3_button = gr.Button("IIT-JEE 2025 Physics")

    # Stream the reply into the chatbot, then clear the input box.
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )
    # Reset both the visible chat and the stored conversation.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )

    def _fill_example(key):
        """Return a zero-argument handler that loads example_messages[key] into the input box."""
        # A factory (rather than a lambda in the loop) binds each key eagerly and
        # avoids Python's late-binding closure pitfall.
        return lambda: gr.update(value=example_messages[key])

    for example_button, example_key in (
        (example1_button, "IIT-JEE 2025 Mathematics"),
        (example2_button, "Goldman Sachs Interview Puzzle"),
        (example3_button, "IIT-JEE 2025 Physics"),
    ):
        example_button.click(
            fn=_fill_example(example_key),
            inputs=None,
            outputs=user_input,
        )

demo.launch(ssr_mode=False)