# Fathom / app.py — Hugging Face Space by FractalAIR
# Commit c34c8d5 (verified), "Update app.py", 4.67 kB.
# (The lines above were the HF file-viewer chrome — "raw / history / blame" —
# captured when this file was scraped; preserved here as a comment header.)
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face Hub repo ID of the 14B Fathom-R1 checkpoint served by this Space.
MODEL_ID = "FractalAIResearch/Fathom-R1-14B"
# Lazy process-level cache so the 14B checkpoint and tokenizer are loaded
# once instead of on every chat turn (the original reloaded the model per
# message, which dominated latency).
# NOTE(review): under ZeroGPU the `spaces` package manages CUDA attach/detach
# across @spaces.GPU calls; keeping the .cuda() model cached is the documented
# pattern, but confirm against the Space's runtime if behavior looks off.
_MODEL_CACHE = {}


def _load_model():
    """Return (model, tokenizer), loading them onto the GPU on first use."""
    if "model" not in _MODEL_CACHE:
        print("🔥 GPU allocated, loading model...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Explicitly move the model to the GPU.
        model = model.cuda()
        print(f"✅ Model loaded on device: {model.device}")
        print(f"🔥 GPU available: {torch.cuda.is_available()}")
        print(f"🔥 GPU device count: {torch.cuda.device_count()}")
        if tokenizer.pad_token is None:
            # Causal-LM tokenizers often ship without a pad token; reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token
        _MODEL_CACHE["model"] = model
        _MODEL_CACHE["tokenizer"] = tokenizer
    return _MODEL_CACHE["model"], _MODEL_CACHE["tokenizer"]


@spaces.GPU
def chat_with_model(message, history, max_tokens, temperature):
    """Generate one assistant reply and append the turn to `history`.

    Args:
        message: the user's message text.
        history: list of [user, assistant] pairs (mutated in place).
        max_tokens: max_new_tokens for generation.
        temperature: sampling temperature (sampling is always on).

    Returns:
        (history, history, "") — the history twice, for the (chatbot, state)
        outputs, plus an empty string to clear the textbox. On any failure the
        error text is appended as the assistant reply instead of raising, so
        the UI never crashes.
    """
    try:
        model, tokenizer = _load_model()

        # Simple prompt format — no chat template is applied.
        prompt = f"User: {message}\nAssistant:"

        # Tokenize and move every tensor to the GPU.
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.cuda() for k, v in inputs.items()}
        print(f"✅ Inputs moved to: {inputs['input_ids'].device}")

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Slice off the prompt tokens; decode only the newly generated tail.
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
        )
        print(f"✅ Generated response: {response[:100]}...")

        history.append([message, response])
        return history, history, ""
    except Exception as e:
        # Top-level boundary: surface the error inside the chat rather than
        # letting the Gradio event fail silently.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        history.append([message, error_msg])
        return history, history, ""
# --- Gradio interface ------------------------------------------------------
with gr.Blocks(title="Fathom R1 14B Chatbot") as demo:
    gr.HTML("<h1>🤖 Fathom R1 14B Chatbot</h1>")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, label="Conversation")
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Message",
                    lines=3,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            clear_btn = gr.Button("Clear Chat")

        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens = gr.Slider(
                minimum=50,
                maximum=2048,
                value=512,
                step=50,
                label="Max Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            gr.Markdown("### Examples")
            gr.Examples(
                examples=[
                    "Solve: 2x + 5 = 15",
                    "Explain quantum mechanics simply",
                    "What is the derivative of x²?",
                ],
                inputs=msg,
            )

    # Per-session chat history ([user, assistant] pairs) kept in gr.State so
    # concurrent browser sessions do not share a conversation.
    history = gr.State([])

    def user_submit(message, hist):
        """Echo the user's message into the chat with a pending (None) reply.

        Returns (chatbot, history, textbox) — the textbox is cleared.
        Blank or whitespace-only messages are ignored.
        """
        if not message or not message.strip():
            return hist, hist, ""
        pending = hist + [[message, None]]
        return pending, pending, ""

    def bot_respond(hist, max_tok, temp):
        """Fill in the pending reply slot, if one exists, by running the model."""
        if hist and hist[-1][1] is None:
            message = hist[-1][0]
            # chat_with_model appends [message, response] itself, so pass the
            # history WITHOUT the pending placeholder to avoid a duplicate row.
            _, updated_hist, _ = chat_with_model(message, hist[:-1], max_tok, temp)
            return updated_hist, updated_hist
        return hist, hist

    # The Enter key and the Send button run the same two-step pipeline:
    # 1) show the user's message immediately, 2) generate the reply.
    for trigger in (msg.submit, send_btn.click):
        trigger(
            user_submit,
            [msg, history],
            [chatbot, history, msg],
        ).then(
            bot_respond,
            [history, max_tokens, temperature],
            [chatbot, history],
        )

    # Reset both the visible chat and the stored history.
    clear_btn.click(
        lambda: ([], []),
        outputs=[chatbot, history],
    )

if __name__ == "__main__":
    demo.launch()