example_test / app.py
Wenye He
Create app.py
84a76f5 verified
raw
history blame
2.19 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Checkpoint identifier shared by the tokenizer and the model. GPT-2 is only
# an example here and can be replaced by another causal language model.
model_name = "gpt2"
# Both objects are created once at module load and reused by every call to
# generate_response.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_response(user_input, chat_history):
    """Generate the model's reply to ``user_input`` and record the exchange.

    Earlier exchanges are rendered as alternating ``User:`` / ``AI:`` lines,
    followed by the new message and a bare ``AI:`` cue. The prompt is fed to
    the module-level ``model`` and the first line of its continuation is used
    as the reply.

    Args:
        user_input: Text submitted from the input box.
        chat_history: List of ``(user_message, ai_message)`` pairs, or
            ``None`` for a fresh conversation. The list is extended in
            place, so a caller holding a reference sees the new exchange.

    Returns:
        A ``("", chat_history)`` tuple: the empty string clears the input
        box and the history is the list of message pairs to display.
    """
    max_new_tokens = 50  # Upper bound on the length of a single reply.

    if chat_history is None:
        chat_history = []

    # A blank submission has nothing to answer; leave the history untouched.
    if not user_input or not user_input.strip():
        return "", chat_history

    # Render the completed exchanges, then the new message and the "AI:" cue
    # that signals the model to produce the AI's turn.
    turns = [f"User: {user}\nAI: {ai}\n" for user, ai in chat_history]
    turns.append(f"User: {user_input}\nAI:")
    conversation = "".join(turns)

    encoded = tokenizer(conversation, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep only the most recent tokens when the prompt would leave no room
    # for the reply within the tokenizer's advertised input limit.
    max_prompt_tokens = tokenizer.model_max_length - max_new_tokens
    if 0 < max_prompt_tokens < input_ids.shape[1]:
        input_ids = input_ids[:, -max_prompt_tokens:]
        attention_mask = attention_mask[:, -max_prompt_tokens:]

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # len(conversation) is unreliable because decoding may not reproduce the
    # prompt character-for-character.
    new_token_ids = output_ids[0][input_ids.shape[1]:]
    continuation = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    ai_response = continuation.strip().split("\n")[0]

    # Store the exchange as a (user, ai) pair, the shape a pair-style chat
    # display expects, so each message is attributed to the right side.
    chat_history.append((user_input, ai_response))

    # The empty string clears the input box.
    return "", chat_history
# Assemble the UI with Blocks so the widgets can be laid out freely.
with gr.Blocks() as demo:
    gr.Markdown("# Local LLM Chatbot")
    # Widget that renders the conversation.
    chatbot = gr.Chatbot()
    # Hidden state that holds the conversation history between submissions.
    state = gr.State([])
    # Input box for the user's message.
    txt = gr.Textbox(placeholder="Enter your message and press Enter")
    # Submitting the textbox runs generate_response; its two return values
    # clear the textbox and refresh the chatbot display.
    # NOTE(review): `state` is not listed in outputs, so the history only
    # persists through generate_response mutating the list in place.
    txt.submit(
        fn=generate_response,
        inputs=[txt, state],
        outputs=[txt, chatbot],
    )

# Start serving the interface.
demo.launch()