import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr
# Checkpoint served by this Space. Alternatives tried during development are
# kept in the comments below for reference.
MODEL_ID = "LMSeed/GPT2-Small-Distilled-900M" #"LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
#"openai-community/gpt2"#"LMSeed/GPT2-small-distilled-100M"

# HF pipeline device convention: 0 = first CUDA device, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # BUG FIX: pad_token must be the token *string*; the original assigned
    # tokenizer.eos_token_id (an int), which corrupts the tokenizer config.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
def generate_reply(prompt, max_new_tokens, temperature, top_p):
    """Sample a continuation of ``prompt`` and return only the new text.

    Args:
        prompt: Text the model should continue.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        temperature: Softmax temperature for sampling (coerced to float).
        top_p: Nucleus-sampling probability mass (coerced to float).

    Returns:
        The decoded continuation as a string, with the prompt stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        no_repeat_ngram_size=3,       # suppress verbatim 3-gram repeats
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # BUG FIX: decode only the tokens generated after the prompt. The original
    # decoded the full sequence and sliced with text[len(prompt):], which is
    # fragile — clean_up_tokenization_spaces can change whitespace so the
    # decoded prefix need not match the raw prompt character-for-character.
    generated_tokens = output_ids[0][input_len:]
    return tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
def clean_reply(text):
    """Trim whitespace and truncate the reply at the first role marker.

    The model sometimes hallucinates a new dialogue turn ("Human:",
    "User:", ...); everything from the first such marker onward is
    discarded.
    """
    result = text.strip()
    for marker in ("Human:", "User:", "AI:", "Assistant:"):
        # partition() keeps the whole string when the marker is absent,
        # otherwise only the part before it — same effect as split()[0].
        result, _, _ = result.partition(marker)
    return result.strip()
# def clean_reply(text):
# text = text.strip()
# for prefix in ["Assistant:", "assistant:", "User:", "user:"]:
# if text.startswith(prefix):
# text = text[len(prefix):].strip()
# lines = [l.strip() for l in text.split("\n")]
# lines = [l for l in lines if l]
# if len(lines) == 0:
# return ""
# return lines[0]
# def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
# if chat_history is None:
# chat_history = []
# # Build conversation history
# # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
# history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \
# The assistant should tell a story according to human's requirment.\n"
# for msg in chat_history:
# role = "Human" if msg["role"] == "user" else "AI"
# history_text += f"{role}: {msg['content']}\n"
# history_text += f"Human: {user_message}\nAI:"
# # -------- generate ----------
# raw = generate_reply(
# history_text,
# max_new_tokens,
# temperature,
# top_p
# )
# # Only keep new part
# reply = raw[len(history_text):]
# reply = clean_reply(reply)
# # ------------------------------
# chat_history.append({"role": "user", "content": user_message})
# chat_history.append({"role": "assistant", "content": reply})
# return "", chat_history, chat_history
def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
    """Generate a story for ``user_message`` and append the exchange to
    ``chat_history`` (a list of ``{"role", "content"}`` dicts).

    Returns ``("", history, history)`` so Gradio clears the input box and
    refreshes both the chatbot widget and the session state.
    """
    history = chat_history if chat_history is not None else []
    # Seed the generation with a fixed story opening to keep the small
    # model on task.
    prompt_text = (
        f"User request: {user_message}\n\n"
        "Here is a long, creative story based on the request:\n"
        "Once upon a time,"
    )
    raw = generate_reply(prompt_text, max_new_tokens, temperature, top_p)
    # Re-attach the opening (with a space) so the displayed story reads
    # as a complete sentence.
    final_reply = "Once upon a time, " + clean_reply(raw)
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": final_reply})
    return "", history, history
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Story generation with Stu")
    with gr.Row():
        with gr.Column(scale=3):
            # BUG FIX: chat_with_model appends {"role", "content"} dicts, so
            # the Chatbot must use the openai-style "messages" format; the
            # default tuple format rejects dict histories in modern Gradio.
            chat = gr.Chatbot(elem_id="chatbot", label="Story Output", type="messages")
            msg = gr.Textbox(label="What should the story be about?")
            send = gr.Button("Generate Story")
            max_tokens = gr.Slider(50, 1025, value=300, label="max_new_tokens")
            temp = gr.Slider(0.6, 1.5, value=1.0, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="top_p")
        with gr.Column(scale=1):
            gr.Markdown("Model: " + MODEL_ID)
            gr.Markdown(
                "Note: Do not input too complex prompts, since the model "
                "might get confused. This setup is optimized for storytelling."
            )
    # Per-session conversation history, threaded through each click.
    state = gr.State([])
    send.click(
        fn=chat_with_model,
        inputs=[msg, state, max_tokens, temp, top_p],
        outputs=[msg, chat, state],
    )

demo.launch()