# (Hugging Face Spaces page residue — Space status was "Sleeping"; not code.)
| import os | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| import gradio as gr | |
# Model to serve. Alternates tried during development, kept for reference:
#   "LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
#   "openai-community/gpt2"
#   "LMSeed/GPT2-small-distilled-100M"
MODEL_ID = "LMSeed/GPT2-Small-Distilled-900M"

# pipeline-style device index: 0 = first CUDA device, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # BUG FIX: pad_token must be the token *string*; the original assigned
    # tokenizer.eos_token_id (an int), which corrupts padding behavior.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
def generate_reply(prompt, max_new_tokens, temperature, top_p):
    """Sample a continuation of ``prompt`` and return only the new text.

    Args:
        prompt: Text fed verbatim to the tokenizer.
        max_new_tokens: Cap on generated tokens (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling threshold (coerced to float).

    Returns:
        The generated continuation with the prompt stripped off.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # BUG FIX: decode only the newly generated tokens. The original decoded
    # the full sequence and sliced text[len(prompt):], which misaligns
    # whenever tokenize/detokenize does not round-trip the prompt exactly
    # (and it also computed `generated_tokens` without ever using it).
    return tokenizer.decode(
        output_ids[0][input_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
def clean_reply(text):
    """Trim whitespace and truncate at the first chat-role marker, if any."""
    cleaned = text.strip()
    # str.split(marker)[0] already returns the whole string when the marker
    # is absent, so no membership check is needed.
    for marker in ("Human:", "User:", "AI:", "Assistant:"):
        cleaned = cleaned.split(marker, 1)[0]
    return cleaned.strip()
| # def clean_reply(text): | |
| # text = text.strip() | |
| # for prefix in ["Assistant:", "assistant:", "User:", "user:"]: | |
| # if text.startswith(prefix): | |
| # text = text[len(prefix):].strip() | |
| # lines = [l.strip() for l in text.split("\n")] | |
| # lines = [l for l in lines if l] | |
| # if len(lines) == 0: | |
| # return "" | |
| # return lines[0] | |
| # def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9): | |
| # if chat_history is None: | |
| # chat_history = [] | |
| # # Build conversation history | |
| # # history_text = "The following is a friendly conversation between a human and an AI assistant.\n" | |
| # history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \ | |
| # The assistant should tell a story according to human's requirment.\n" | |
| # for msg in chat_history: | |
| # role = "Human" if msg["role"] == "user" else "AI" | |
| # history_text += f"{role}: {msg['content']}\n" | |
| # history_text += f"Human: {user_message}\nAI:" | |
| # # -------- generate ---------- | |
| # raw = generate_reply( | |
| # history_text, | |
| # max_new_tokens, | |
| # temperature, | |
| # top_p | |
| # ) | |
| # # Only keep new part | |
| # reply = raw[len(history_text):] | |
| # reply = clean_reply(reply) | |
| # # ------------------------------ | |
| # chat_history.append({"role": "user", "content": user_message}) | |
| # chat_history.append({"role": "assistant", "content": reply}) | |
| # return "", chat_history, chat_history | |
def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
    """Generate a story for ``user_message`` and append the exchange to history.

    Returns ("", history, history) so Gradio clears the textbox and refreshes
    both the Chatbot widget and the State component.
    """
    history = chat_history if chat_history is not None else []
    # Seed the model with an explicit story opener so even terse requests
    # produce narrative text; the opener is re-attached to the reply below.
    story_prompt = (
        f"User request: {user_message}\n\n"
        "Here is a long, creative story based on the request:\n"
        "Once upon a time,"
    )
    raw = generate_reply(story_prompt, max_new_tokens, temperature, top_p)
    story = "Once upon a time, " + clean_reply(raw)
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": story})
    return "", history, history
with gr.Blocks() as demo:
    gr.Markdown("# Story generation with Stu")
    with gr.Row():
        with gr.Column(scale=3):
            # BUG FIX: chat_with_model emits {"role": ..., "content": ...}
            # dicts, which require the Chatbot's "messages" format; the
            # legacy tuple format rejects them.
            chat = gr.Chatbot(elem_id="chatbot", label="Story Output", type="messages")
            msg = gr.Textbox(label="What should the story be about?")
            send = gr.Button("Generate Story")
            max_tokens = gr.Slider(50, 1025, value=300, label="max_new_tokens")
            temp = gr.Slider(0.6, 1.5, value=1.0, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="top_p")
        with gr.Column(scale=1):
            gr.Markdown("Model: " + MODEL_ID)
            gr.Markdown(
                "Note: Do not input too complex prompts, since the model "
                "might get confused. This setup is optimized for storytelling."
            )
    # Per-session conversation history threaded through chat_with_model.
    state = gr.State([])
    send.click(
        fn=chat_with_model,
        inputs=[msg, state, max_tokens, temp, top_p],
        outputs=[msg, chat, state],
    )

demo.launch()