Chat_with_Stu / app.py
Jia0603's picture
Update app.py
39fec9a verified
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr
# Model to serve. Earlier candidates kept for reference:
# "LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
# "openai-community/gpt2", "LMSeed/GPT2-small-distilled-100M"
MODEL_ID = "LMSeed/GPT2-Small-Distilled-900M"

# transformers pipeline convention: 0 = first CUDA device, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# GPT-2 ships without a pad token; reuse EOS for padding.
# BUG FIX: the original assigned tokenizer.eos_token_id (an int) to
# pad_token, which expects the token *string* — use tokenizer.eos_token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
def generate_reply(prompt, max_new_tokens, temperature, top_p):
    """Sample a continuation of ``prompt`` and return only the new text.

    Args:
        prompt: Input text fed to the model.
        max_new_tokens: Cap on newly generated tokens (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling threshold (coerced to float).

    Returns:
        The decoded continuation, excluding the prompt itself.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        no_repeat_ngram_size=3,  # discourage verbatim 3-gram loops
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no native pad token
    )
    # BUG FIX: decode only the newly generated ids. The original decoded the
    # full sequence and sliced text[len(prompt):], which drifts whenever
    # round-trip tokenization or clean_up_tokenization_spaces changes the
    # prompt's surface form (it also left `generated_tokens` unused).
    generated_tokens = output_ids[0][input_len:]
    return tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
def clean_reply(text):
    """Trim whitespace and truncate the reply at the first role marker.

    Generation sometimes runs past the assistant's turn and starts a new
    "Human:" / "AI:" line; everything from the first such marker onward
    is discarded.
    """
    result = text.strip()
    for marker in ("Human:", "User:", "AI:", "Assistant:"):
        # partition() is a no-op when the marker is absent.
        head, _sep, _tail = result.partition(marker)
        result = head
    return result.strip()
# def clean_reply(text):
# text = text.strip()
# for prefix in ["Assistant:", "assistant:", "User:", "user:"]:
# if text.startswith(prefix):
# text = text[len(prefix):].strip()
# lines = [l.strip() for l in text.split("\n")]
# lines = [l for l in lines if l]
# if len(lines) == 0:
# return ""
# return lines[0]
# def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
# if chat_history is None:
# chat_history = []
# # Build conversation history
# # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
# history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \
# The assistant should tell a story according to human's requirment.\n"
# for msg in chat_history:
# role = "Human" if msg["role"] == "user" else "AI"
# history_text += f"{role}: {msg['content']}\n"
# history_text += f"Human: {user_message}\nAI:"
# # -------- generate ----------
# raw = generate_reply(
# history_text,
# max_new_tokens,
# temperature,
# top_p
# )
# # Only keep new part
# reply = raw[len(history_text):]
# reply = clean_reply(reply)
# # ------------------------------
# chat_history.append({"role": "user", "content": user_message})
# chat_history.append({"role": "assistant", "content": reply})
# return "", chat_history, chat_history
def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
    """Handle one chat turn: build a story prompt, generate, extend history.

    Returns ("", history, history) so Gradio clears the input textbox and
    refreshes both the chatbot widget and the session state.
    """
    history = chat_history if chat_history is not None else []
    # Seed the model with an explicit story opener; the opener is re-prefixed
    # to the cleaned continuation below.
    prompt_text = (
        f"User request: {user_message}\n\n"
        "Here is a long, creative story based on the request:\n"
        "Once upon a time,"
    )
    raw = generate_reply(prompt_text, max_new_tokens, temperature, top_p)
    story = "Once upon a time, " + clean_reply(raw)
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": story})
    return "", history, history
# ---- Gradio UI wiring --------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Story generation with Stu")
    with gr.Row():
        with gr.Column(scale=3):
            # NOTE(review): chat_with_model appends {"role", "content"} dicts;
            # recent Gradio versions require gr.Chatbot(type="messages") for
            # that format — confirm against the pinned gradio version.
            chat = gr.Chatbot(elem_id="chatbot", label="Story Output")
            msg = gr.Textbox(label="What should the story be about?")
            send = gr.Button("Generate Story")
            # Generation controls, forwarded verbatim to chat_with_model.
            max_tokens = gr.Slider(50, 1025, value=300, label="max_new_tokens")
            temp = gr.Slider(0.6, 1.5, value=1.0, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="top_p")
        with gr.Column(scale=1):
            gr.Markdown("Model: " + MODEL_ID)
            gr.Markdown("Note: Do not input too complex prompts, since the model\
            might get confused. This setup is optimized for storytelling.")
    # Per-session conversation history (list of role/content dicts).
    state = gr.State([])
    # Button click clears the textbox and refreshes both chat display and state.
    send.click(
        fn=chat_with_model,
        inputs=[msg, state, max_tokens, temp, top_p],
        outputs=[msg, chat, state]
    )

demo.launch()