Hugging Face Spaces page header (scrape artifact) — Space status: Sleeping
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Download (or load from cache) the tokenizer and causal-LM weights from the
# Hugging Face Hub. This runs once at import time, before the UI starts.
model_name = "TheDrummer/Gemmasutra-Mini-2B-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Pin the model to CPU explicitly (the Hugging Face Spaces free tier has no
# GPU, so this is the only device available).
device = torch.device("cpu")
model.to(device)
# Chatbot function
def chat_with_model(user_input: str, history):
    """Generate a bot reply for *user_input* given the conversation so far.

    Parameters
    ----------
    user_input : str
        The latest message typed by the user.
    history : list[tuple[str, str]] | None
        Previous ``(user, bot)`` exchanges; ``None`` is treated as empty.

    Returns
    -------
    tuple[str, list[tuple[str, str]]]
        The bot's reply and the updated history. When a list is passed in,
        it is appended to in place (same contract as the original code).
    """
    if history is None:
        history = []

    # Flatten the conversation into a plain "User:/Bot:" transcript prompt.
    prompt = ""
    for user_turn, bot_turn in history:
        prompt += f"User: {user_turn}\nBot: {bot_turn}\n"
    prompt += f"User: {user_input}\nBot: "

    # Tokenize and move the tensors to the target device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # inference_mode() disables autograd bookkeeping — generation needs no
    # gradients, and this saves memory/CPU on the free-tier host.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,                   # Limit response length
            do_sample=True,                       # Enable sampling for varied responses
            temperature=0.7,                      # Control creativity
            top_p=0.9,                            # Nucleus sampling
            pad_token_id=tokenizer.eos_token_id,  # Handle padding
        )

    # Decode ONLY the newly generated tokens. The original decoded the full
    # sequence and split on "Bot: ", which breaks whenever the user's message
    # (or any earlier turn) itself contains the literal string "Bot: ".
    prompt_len = inputs["input_ids"].shape[1]
    bot_response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # If the model keeps going and hallucinates the next "User:" turn,
    # truncate the reply at that point.
    bot_response = bot_response.split("\nUser:")[0].strip()

    # Record the exchange and return both the reply and the new history.
    history.append((user_input, bot_response))
    return bot_response, history
# Gradio Interface
with gr.Blocks(title="Grok-like Chatbot") as iface:
    gr.Markdown("## Chat with Gemmasutra-Mini-2B-v1")
    chatbot = gr.Chatbot(label="Conversation")
    msg = gr.Textbox(label="Your Message", placeholder="Type here...")
    submit_btn = gr.Button("Send")

    # Per-session conversation history: a list of (user, bot) pairs.
    state = gr.State(value=[])

    def submit_message(user_input, history):
        """Run the model and map the result onto the UI components.

        Returns (textbox value, new state, chatbot value): the textbox is
        cleared, and both the state and the chatbot display receive the
        updated history.
        """
        _, updated_history = chat_with_model(user_input, history)
        return "", updated_history, updated_history

    # Fix: the original listed `msg` TWICE in `outputs`
    # ([msg, state, chatbot, msg]) and returned four values — first writing
    # the bot reply into the textbox, then overwriting it with "". Each
    # component must appear once; returning "" for `msg` clears the input.
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, state],
        outputs=[msg, state, chatbot],
    )
    msg.submit(
        fn=submit_message,
        inputs=[msg, state],
        outputs=[msg, state, chatbot],
    )

if __name__ == "__main__":
    iface.launch()