# Hugging Face Space status: Running
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Hugging Face Hub repo id of the causal language model served by this app.
model_name = "DSDUDEd/firebase"

# Load tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 8-bit (bitsandbytes) weights need a CUDA GPU. On CPU-only hardware, fall back
# to default-precision weights instead of failing at startup. Passing
# `load_in_8bit=` straight to from_pretrained is deprecated; the supported way
# is a BitsAndBytesConfig passed as `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # place weights on the available device(s) automatically
    quantization_config=(
        BitsAndBytesConfig(load_in_8bit=True) if torch.cuda.is_available() else None
    ),
)
def _build_prompt(history):
    """Render turns as ``User: ...`` / ``AI: ...`` lines ending with an ``AI:`` cue.

    The trailing cue makes the model's continuation the AI's reply, instead of
    leaving it to guess whose turn comes next.
    """
    lines = [
        f"{'User' if turn['role'] == 'user' else 'AI'}: {turn['content']}"
        for turn in history
    ]
    lines.append("AI:")
    return "\n".join(lines)


def _extract_reply(generated):
    """Return the AI's reply from the newly generated text.

    The model may keep going and invent the next ``User:`` turn. Everything from
    that point on is dropped.
    """
    return generated.split("\nUser:", 1)[0].strip()


def _to_gradio_pairs(history):
    """Convert turn dicts into ``(user_message, ai_message)`` rows for ``gr.Chatbot``.

    Each user message shares a row with the AI reply that follows it. A side
    with no message is ``None`` (Gradio's tuple format for "no message")
    instead of an empty string.
    """
    pairs = []
    for turn in history:
        if turn["role"] == "user":
            pairs.append([turn["content"], None])
        elif pairs and pairs[-1][1] is None:
            pairs[-1][1] = turn["content"]
        else:
            pairs.append([None, turn["content"]])
    return [tuple(pair) for pair in pairs]


def chat_with_model(user_input, chat_history=None):
    """Generate one AI reply and return the updated conversation.

    Args:
        user_input: The message the user just submitted.
        chat_history: Prior turns as a list of ``{"role": ..., "content": ...}``
            dicts, where role is ``"user"`` or ``"ai"``. ``None`` starts a new
            conversation. The list passed in is not modified.

    Returns:
        A ``(chat_for_gradio, chat_history)`` tuple. ``chat_for_gradio`` is a
        list of ``(user_message, ai_message)`` rows for ``gr.Chatbot``.
        ``chat_history`` is the new list of turn dicts, to be stored back in
        ``gr.State``.
    """
    # Work on a copy. The old mutable default ([]) was shared by every call,
    # and appending in place would also mutate the caller's list.
    history = list(chat_history) if chat_history else []

    # Nothing to answer for a blank submission; just re-render the chat.
    if not user_input or not user_input.strip():
        return _to_gradio_pairs(history), history

    history.append({"role": "user", "content": user_input})

    prompt = _build_prompt(history)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )

    # A causal LM's output starts with the prompt tokens. Decode only the new
    # ones so earlier turns can never leak into the reply.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    response_text = _extract_reply(generated)

    history.append({"role": "ai", "content": response_text})
    return _to_gradio_pairs(history), history
# Build the Gradio interface: chat display, message box, and Send button.
with gr.Blocks() as demo:
    # Per-session list of {"role", "content"} turn dicts.
    chat_history_state = gr.State([])
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter your message")
    submit = gr.Button("Send")

    # Both the button and pressing Enter in the textbox send the message.
    # Afterwards the textbox is cleared so the next message starts fresh.
    for trigger in (submit.click, msg.submit):
        trigger(
            chat_with_model,
            inputs=[msg, chat_history_state],
            outputs=[chatbot, chat_history_state],
        ).then(lambda: "", None, msg)

demo.launch()