# Hugging Face Spaces status header ("Spaces: Paused") — page-scrape residue, kept as a comment.
| import gradio as gr | |
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
# Hugging Face model repo to serve.
MODEL_ID = "FractalAIResearch/Fathom-R1-14B"

# Process-wide cache. Loading a 14B checkpoint takes a long time and a lot of
# memory, so do it once on first use instead of on every chat turn (the
# original reloaded tokenizer AND model for each message).
_TOKENIZER = None
_MODEL = None

# NOTE(review): the file imports `spaces` but never uses it — on a ZeroGPU
# Space this function probably needs the `@spaces.GPU` decorator to get a GPU
# allocated; confirm against the Space's hardware settings.


def _load_model():
    """Load and cache the tokenizer/model on first call; return (tokenizer, model)."""
    global _TOKENIZER, _MODEL
    if _MODEL is None:
        print("🔥 GPU allocated, loading model...")
        _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Explicitly move the weights onto the GPU.
        _MODEL = model.cuda()
        # Some causal-LM tokenizers ship without a pad token; reuse EOS.
        if _TOKENIZER.pad_token is None:
            _TOKENIZER.pad_token = _TOKENIZER.eos_token
        print(f"✅ Model loaded on device: {_MODEL.device}")
        print(f"🔥 GPU available: {torch.cuda.is_available()}")
        print(f"🔥 GPU device count: {torch.cuda.device_count()}")
    return _TOKENIZER, _MODEL


def chat_with_model(message, history, max_tokens, temperature):
    """Generate an assistant reply for *message* and append it to *history*.

    Parameters
    ----------
    message : str
        The user turn to answer.
    history : list
        Prior ``[user, assistant]`` pairs; mutated in place.
    max_tokens : int
        ``max_new_tokens`` for generation.
    temperature : float
        Sampling temperature (sampling is always enabled).

    Returns ``(history, history, "")`` — chatbot value, state value, and an
    empty string to clear the input textbox. On any failure the error text is
    appended as the assistant reply instead of raising, so the UI keeps working.
    """
    try:
        tokenizer, model = _load_model()
        # Simple single-turn prompt format; prior history is NOT replayed
        # into the prompt (matches the original behavior).
        prompt = f"User: {message}\nAssistant:"
        # Tokenize and move the input tensors to the GPU.
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.cuda() for k, v in inputs.items()}
        print(f"✅ Inputs moved to: {inputs['input_ids'].device}")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tail (drop the echoed prompt tokens).
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
        )
        print(f"✅ Generated response: {response[:100]}...")
        history.append([message, response])
        return history, history, ""
    except Exception as e:  # surface the error in the chat instead of crashing the UI
        error_msg = f"❌ Error: {e}"
        print(error_msg)
        history.append([message, error_msg])
        return history, history, ""
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(title="Fathom R1 14B Chatbot") as demo:
    gr.HTML("<h1>🤖 Fathom R1 14B Chatbot</h1>")

    with gr.Row():
        # Left column: conversation view, message box, send/clear controls.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, label="Conversation")
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Message",
                    lines=3,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            clear_btn = gr.Button("Clear Chat")

        # Right column: generation settings and example prompts.
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens = gr.Slider(
                minimum=50,
                maximum=2048,
                value=512,
                step=50,
                label="Max Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            gr.Markdown("### Examples")
            gr.Examples(
                examples=[
                    "Solve: 2x + 5 = 15",
                    "Explain quantum mechanics simply",
                    "What is the derivative of x²?",
                ],
                inputs=msg,
            )

    # Per-session chat history: list of [user, assistant] pairs.
    history = gr.State([])

    def user_submit(message, hist):
        # Optimistically show the user's turn with an empty (None) assistant
        # slot, and clear the textbox. bot_respond fills the slot afterwards.
        return hist + [[message, None]], hist + [[message, None]], ""

    def bot_respond(hist, max_tok, temp):
        # Fill in the pending assistant slot appended by user_submit.
        if hist and hist[-1][1] is None:
            message = hist[-1][0]
            _, updated_hist, _ = chat_with_model(message, hist[:-1], max_tok, temp)
            return updated_hist, updated_hist
        return hist, hist

    # Enter key and Send button trigger the same two-step chain:
    # show the user turn immediately, then generate the reply.
    msg.submit(
        user_submit,
        [msg, history],
        [chatbot, history, msg],
    ).then(
        bot_respond,
        [history, max_tokens, temperature],
        [chatbot, history],
    )
    send_btn.click(
        user_submit,
        [msg, history],
        [chatbot, history, msg],
    ).then(
        bot_respond,
        [history, max_tokens, temperature],
        [chatbot, history],
    )

    # Reset both the visible chat and the stored state.
    clear_btn.click(
        lambda: ([], []),
        outputs=[chatbot, history],
    )

if __name__ == "__main__":
    demo.launch()