# experiment/app.py — Gradio chat demo for openai-community/gpt2
# (Hugging Face Space: liumi-model-spaces, commit 5d0852f, verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
# 1. Load the model and tokenizer.
# 'openai-community/gpt2' is the classic 124M-parameter GPT-2 checkpoint,
# pulled from (or cached by) the Hugging Face hub on first run.
model_id = "openai-community/gpt2"
print(f"Loading {model_id}...")
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 2. Define the Chat Logic
def chat_response(message, history):
    """Stream a GPT-2 completion for the latest user message.

    GPT-2 is a raw completion model with no chat training, so the history is
    rendered as a "User: ... / AI: ..." script and the model is asked to
    complete the next AI line.

    Args:
        message: The user's newest message (str).
        history: List of (user_msg, ai_msg) tuples from previous turns
                 (Gradio tuple-style chat history).

    Yields:
        The accumulated reply text, chunk by chunk, for Gradio's streaming UI.
    """
    # Render prior turns as a script so GPT-2 has conversational context.
    conversation = ""
    for user_msg, ai_msg in history:
        conversation += f"User: {user_msg}\nAI: {ai_msg}\n"
    # Append the current message and leave the "AI:" line open to complete.
    prompt = conversation + f"User: {message}\nAI:"

    inputs = tokenizer(prompt, return_tensors="pt")

    # The streamer produces decoded text incrementally ('typing' effect).
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs=inputs.input_ids,
        # BUGFIX: pass the attention mask explicitly; relying on generate()
        # to infer it triggers a warning and can mishandle padding (GPT-2's
        # pad token is its EOS token, so inference is ambiguous).
        attention_mask=inputs.attention_mask,
        max_new_tokens=100,        # Limit response length
        temperature=0.7,           # Controls randomness (0.7 is a good balance)
        top_p=0.9,                 # Nucleus sampling for better quality
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        no_repeat_ngram_size=2,    # Prevents the model from repeating phrases
    )

    # Run generation in a separate thread so the GUI doesn't freeze.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the text chunk by chunk to the Gradio interface.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # BUGFIX: as a completion model, GPT-2 will happily keep writing the
        # *user's* next turn of the script. Cut the reply at the first
        # "\nUser:" marker so hallucinated user lines never reach the UI.
        if "\nUser:" in generated_text:
            generated_text = generated_text.split("\nUser:", 1)[0]
            yield generated_text
            break
        yield generated_text

    # Don't leave the generation thread dangling after the stream ends.
    thread.join()
# 3. Create the GUI.
# A customized Soft theme gives the app a more polished look.
theme = gr.themes.Soft(primary_hue="purple", secondary_hue="cyan").set(
    body_background_fill="*neutral_50",
)
with gr.Blocks(theme=theme) as demo:
    # Header / usage notes rendered above the chat widget.
    gr.Markdown(
        """
        # 🤖 Classic GPT-2 Chat
        ### Powered by `openai-community/gpt2`
        Type a message below to start chatting!
        *Note: GPT-2 is an older model (2019), so it may be whimsical or repetitive.*
        """
    )
    # The chat interface itself, streaming from chat_response.
    gr.ChatInterface(
        fn=chat_response,
        retry_btn=None,
        undo_btn=None,
        clear_btn="Clear Chat",
        examples=[
            "Hello, how are you?",
            "Tell me a story about a dragon.",
            "What is the capital of France?",
        ],
    )
# 4. Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()