File size: 4,756 Bytes
fb0bc70
 
 
 
 
39fec9a
2fcc959
fb0bc70
 
 
7d37afc
a0fadff
 
 
7d37afc
762089e
fb0bc70
 
 
85f553e
 
 
a0fadff
 
85f553e
 
 
 
 
 
a0fadff
 
 
85f553e
 
a0fadff
 
85f553e
 
a0fadff
85f553e
 
 
e852e91
85f553e
6feac3c
 
a0fadff
 
 
 
 
 
 
 
 
9222068
a0fadff
 
 
6feac3c
28a2ee2
 
6feac3c
28a2ee2
 
6feac3c
a0fadff
6feac3c
a0fadff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2fb7fb
762089e
 
a0fadff
762089e
a0fadff
 
85f553e
 
 
 
a0fadff
 
762089e
c72e69a
a0fadff
762089e
 
 
fb0bc70
762089e
e852e91
762089e
fb0bc70
 
e852e91
 
 
a0fadff
 
 
762089e
fb0bc70
 
e852e91
 
fb0bc70
 
762089e
 
c72e69a
 
 
a2fb7fb
762089e
fb0bc70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr

# Model checkpoint to serve; alternatives from earlier experiments kept for reference.
MODEL_ID = "LMSeed/GPT2-Small-Distilled-900M" #"LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
#"openai-community/gpt2"#"LMSeed/GPT2-small-distilled-100M"

# Hugging Face pipeline-style device index: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # FIX: pad_token expects the token *string*; the original assigned
    # tokenizer.eos_token_id (an int), which corrupts the pad-token setting.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

if torch.cuda.is_available():
    model = model.to("cuda")

def generate_reply(prompt, max_new_tokens, temperature, top_p):
    """Sample a continuation of `prompt` and return only the newly generated text.

    Args:
        prompt: Full text prompt fed to the model.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling threshold (coerced to float).

    Returns:
        The decoded continuation, excluding the prompt tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    input_len = inputs["input_ids"].shape[1]

    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # FIX: decode only the tokens generated beyond the prompt. The original
    # decoded the whole sequence and sliced with text[len(prompt):], which
    # breaks whenever decode() does not reproduce the prompt byte-for-byte
    # (special tokens stripped, clean_up_tokenization_spaces rewrites
    # whitespace), truncating or duplicating text at the seam.
    generated_tokens = output_ids[0][input_len:]
    return tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

def clean_reply(text):
    """Strip surrounding whitespace and truncate at the first role marker.

    Markers are applied in order; each pass keeps only the text before the
    marker's first occurrence, so the result ends before any role label.
    """
    reply = text.strip()
    for marker in ("Human:", "User:", "AI:", "Assistant:"):
        # split()[0] is a no-op when the marker is absent.
        reply = reply.split(marker)[0]
    return reply.strip()
    
# def clean_reply(text):
    
#     text = text.strip()
    
#     for prefix in ["Assistant:", "assistant:", "User:", "user:"]:
#         if text.startswith(prefix):
#             text = text[len(prefix):].strip()
            
    # lines = [l.strip() for l in text.split("\n")]
    # lines = [l for l in lines if l]

    # if len(lines) == 0:
    #     return ""

    # return lines[0]
    
# def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):

#     if chat_history is None:
#         chat_history = []

#     # Build conversation history
#     # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
#     history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \
#             The assistant should tell a story according to human's requirment.\n"

#     for msg in chat_history:
#         role = "Human" if msg["role"] == "user" else "AI"
#         history_text += f"{role}: {msg['content']}\n"

#     history_text += f"Human: {user_message}\nAI:"

#     # -------- generate ----------
#     raw = generate_reply(
#         history_text,
#         max_new_tokens,
#         temperature,
#         top_p
#     )
#     # Only keep new part
#     reply = raw[len(history_text):]
#     reply = clean_reply(reply)
#     # ------------------------------

#     chat_history.append({"role": "user", "content": user_message})
#     chat_history.append({"role": "assistant", "content": reply})

#     return "", chat_history, chat_history
def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
    """Generate a story for `user_message` and append the exchange to history.

    Returns ("", history, history): the empty string clears the Gradio
    textbox, while the history (message dicts with "role"/"content") feeds
    both the chatbot widget and the session state.
    """
    history = chat_history if chat_history is not None else []

    # Frame the raw request as a storytelling task, seeding the model with
    # the classic opener so it continues from "Once upon a time,".
    prompt_text = f"User request: {user_message}\n\nHere is a long, creative story based on the request:\nOnce upon a time,"

    raw = generate_reply(prompt_text, max_new_tokens, temperature, top_p)

    # Re-attach the opener (it was part of the prompt, so the model's output
    # starts after it), then strip any role markers the model emitted.
    final_reply = "Once upon a time, " + clean_reply(raw)

    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": final_reply})

    return "", history, history

# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:

    gr.Markdown("# Story generation with Stu")

    with gr.Row():
        with gr.Column(scale=3):
            # FIX: chat_with_model appends {"role": ..., "content": ...}
            # dicts, so the Chatbot must use the "messages" format — the
            # default tuple format misrenders/rejects dict entries.
            chat = gr.Chatbot(elem_id="chatbot", label="Story Output", type="messages")
            msg = gr.Textbox(label="What should the story be about?")
            send = gr.Button("Generate Story")
            max_tokens = gr.Slider(50, 1025, value=300, label="max_new_tokens")
            temp = gr.Slider(0.6, 1.5, value=1.0, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="top_p")

        with gr.Column(scale=1):
            gr.Markdown("Model: " + MODEL_ID)
            gr.Markdown("Note: Do not input too complex prompts, since the model\
                    might get confused. This setup is optimized for storytelling.")

    # Per-session conversation state (list of message dicts).
    state = gr.State([])

    # Button click: (message, history, sliders) -> (cleared box, chat, state).
    send.click(
        fn=chat_with_model,
        inputs=[msg, state, max_tokens, temp, top_p],
        outputs=[msg, chat, state],
    )
    # Pressing Enter in the textbox behaves like clicking the button.
    msg.submit(
        fn=chat_with_model,
        inputs=[msg, state, max_tokens, temp, top_p],
        outputs=[msg, chat, state],
    )

# FIX: launch after the Blocks context closes — the conventional placement,
# ensuring the layout is fully finalized before the (blocking) server starts.
demo.launch()