FIRE.DOCS / app.py
DSDUDEd's picture
Update app.py
60cffd8 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# 1️⃣ Load the model
MODEL_REPO = "DSDUDEd/firebase"  # Hugging Face Hub repo the tokenizer and weights are loaded from

# Pick the compute device up front: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and causal-LM weights come from the same repo. nn.Module.to()
# moves the weights onto `device` and returns the same module object.
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO).to(device)

# 2️⃣ Chat history
# Transcript lines ("You: ..." / "AI: ...") accumulated by chat_with_ai across calls.
chat_history = []
# 3️⃣ Function to generate AI response
def chat_with_ai(user_input):
    """Record the user's message, generate the AI's reply, and return the transcript.

    The conversation so far is rendered as "You: ..." / "AI: ..." lines and
    passed to the causal LM as a prompt ending in "AI:". The model's
    continuation, cut off where it starts writing another turn, becomes the
    next "AI:" line.

    Args:
        user_input: Text typed into the message box. Empty or whitespace-only
            input is ignored: no turn is recorded and nothing is generated.

    Returns:
        The full transcript as a single newline-joined string.
    """
    # NOTE(review): chat_history is a module-level global, so every visitor of
    # the app shares one conversation; per-session history would need gr.State.
    global chat_history

    # Ignore blank submissions instead of recording an empty "You:" turn.
    if not (user_input or "").strip():
        return "\n".join(chat_history)

    chat_history.append(f"You: {user_input}")

    max_new_tokens = 150

    # Prepare input for the model: the transcript plus the cue for the AI's turn.
    input_text = "\n".join(chat_history) + "\nAI:"
    encoded = tokenizer(input_text, return_tensors="pt")

    # Prompt + generated tokens must fit in the model's position embeddings
    # (n_positions, 1024 for GPT-2); without this, generation fails once the
    # history grows long. Drop the oldest tokens so the latest turns survive.
    context_len = (
        getattr(model.config, "max_position_embeddings", None)
        or tokenizer.model_max_length
    )
    max_prompt_tokens = max(context_len - max_new_tokens, 1)
    if encoded["input_ids"].shape[1] > max_prompt_tokens:
        encoded = {name: t[:, -max_prompt_tokens:] for name, t in encoded.items()}
    inputs = {name: t.to(device) for name, t in encoded.items()}
    prompt_len = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (generate() returns prompt +
    # continuation). Splitting the whole decoded text on the last "AI:"
    # returned the wrong fragment whenever the model itself wrote "AI:".
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # The model often keeps writing the dialogue, including invented "You:"
    # turns; keep only its first reply so those never enter the history.
    ai_response = generated.split("\nYou:")[0].split("\nAI:")[0].strip()

    chat_history.append(f"AI: {ai_response}")
    # Display the whole conversation in the output box.
    return "\n".join(chat_history)
# 4️⃣ Gradio interface
with gr.Blocks() as demo:
    # Heading emoji restored: the literal had been mis-decoded into "πŸ€–".
    gr.Markdown("## 🤖 Custom GPT-2 AI Chat")
    # NOTE: despite its name, `chatbot` is the user's message input box.
    chatbot = gr.Textbox(label="Your Message", placeholder="Type here...", lines=2)
    # Read-only box that displays the string returned by chat_with_ai.
    output = gr.Textbox(label="Chat Output", interactive=False, lines=15)
    send_button = gr.Button("Send")
    # On click, the message box's text is passed to chat_with_ai and its
    # return value is written into the output box.
    send_button.click(fn=chat_with_ai, inputs=chatbot, outputs=output)
# 5️⃣ Launch the Space
demo.launch()