"""Gradio demo: story generation with a distilled GPT-2 model.

Loads a causal LM from the Hugging Face Hub, wraps it in a simple
prompt-template chat loop, and serves it through a Gradio Blocks UI.
"""

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr

MODEL_ID = "LMSeed/GPT2-Small-Distilled-900M"
# Other checkpoints tried during development:
#   "LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
#   "openai-community/gpt2"
#   "LMSeed/GPT2-small-distilled-100M"

# pipeline() device convention: GPU index or -1 for CPU.
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # BUG FIX: the original assigned `tokenizer.eos_token_id` (an int) to
    # `pad_token`, which expects the token *string*. Use `eos_token`.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")


def generate_reply(prompt, max_new_tokens, temperature, top_p):
    """Sample a continuation of ``prompt`` and return only the new text.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling threshold (coerced to float).

    Returns:
        The decoded continuation, excluding the prompt itself.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # FIX: decode only the newly generated token suffix. The original decoded
    # the whole sequence and sliced with `text[len(prompt):]`, which breaks
    # whenever decoding does not reproduce the prompt byte-for-byte (e.g.
    # tokenization-space cleanup).
    generated_tokens = output_ids[0][input_len:]
    return tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


def clean_reply(text):
    """Truncate generated text at the first role marker and strip whitespace.

    Small chat-tuned models often hallucinate a new conversational turn
    ("Human: ..."); everything from the first such marker onward is dropped.
    """
    text = text.strip()
    stop_words = ["Human:", "User:", "AI:", "Assistant:"]
    for word in stop_words:
        if word in text:
            text = text.split(word)[0]
    return text.strip()


def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
    """Generate one story for ``user_message`` and append it to the history.

    Args:
        user_message: The user's story request.
        chat_history: List of ``{"role", "content"}`` dicts, or None.
        max_new_tokens / temperature / top_p: Forwarded to ``generate_reply``.

    Returns:
        ("", updated_history, updated_history) — clears the textbox and
        updates both the Chatbot component and the session state.
    """
    if chat_history is None:
        chat_history = []

    # Seed the model with a fixed story opening so the continuation is
    # anchored; the opening is re-prepended to the cleaned reply below.
    prompt_text = (
        f"User request: {user_message}\n\n"
        "Here is a long, creative story based on the request:\n"
        "Once upon a time,"
    )

    reply = generate_reply(prompt_text, max_new_tokens, temperature, top_p)
    final_reply = "Once upon a time, " + clean_reply(reply)

    chat_history.append({"role": "user", "content": user_message})
    chat_history.append({"role": "assistant", "content": final_reply})
    return "", chat_history, chat_history


with gr.Blocks() as demo:
    gr.Markdown("# Story generation with Stu")
    with gr.Row():
        with gr.Column(scale=3):
            chat = gr.Chatbot(elem_id="chatbot", label="Story Output")
            msg = gr.Textbox(label="What should the story be about?")
            send = gr.Button("Generate Story")
            max_tokens = gr.Slider(50, 1025, value=300, label="max_new_tokens")
            temp = gr.Slider(0.6, 1.5, value=1.0, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="top_p")
        with gr.Column(scale=1):
            gr.Markdown("Model: " + MODEL_ID)
            # FIX: the original string used a backslash-newline continuation
            # that leaked raw source indentation into the rendered text.
            gr.Markdown(
                "Note: Do not input too complex prompts, since the model "
                "might get confused. This setup is optimized for storytelling."
            )

    state = gr.State([])
    send.click(
        fn=chat_with_model,
        inputs=[msg, state, max_tokens, temp, top_p],
        outputs=[msg, chat, state],
    )


if __name__ == "__main__":
    demo.launch()