Spaces:
Runtime error
Runtime error
File size: 2,933 Bytes
4eb45db f2fc5f5 4eb45db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
"""
RPG Room Generator App
Sets the sampling parameters and provides a minimal Gradio interface to the user.
The sampling strategies used (top-k, top-p, temperature) are explained at
https://huggingface.co/blog/how-to-generate
"""
import os
import re

import gradio as gr
from gradio import inputs # allows easier doc lookup in Pycharm
import transformers as tr
# from_pretrained accepts a hub repo id ("namespace/name") or a local directory,
# not a web URL path of the form "<user>/<repo>/tree/main/<dir>".
# Assumes the app runs from the kmacdermid/RpgRoomGenerator repo, with the model
# checked out next to this file under models/mdl_roomgen2 (the directory the
# URL-style value pointed at). TODO confirm.
# The path is resolved against this file so the app does not depend on the
# current working directory.
MPATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models", "mdl_roomgen2")
MODEL = tr.GPT2LMHeadModel.from_pretrained(MPATH)

# TODO: save the tokenizer alongside the model so this can be replaced with a load.
# Special tokens added on top of the stock GPT-2 vocabulary. They presumably have
# to match the ones used when fine-tuning MODEL (verify against training code).
SPECIAL_TOKENS = {
    'eos_token': '<|EOS|>',
    'bos_token': '<|endoftext|>',
    'pad_token': '<pad>',
    'sep_token': '<|body|>',
}
TOK = tr.GPT2Tokenizer.from_pretrained("gpt2")
TOK.add_special_tokens(SPECIAL_TOKENS)

# Sampling presets offered as the "Craziness" choices in the UI.
# temperature and top_p are stored as percentages; generate_room divides them
# by 100 before passing them to MODEL.generate.
SAMPLING_OPTIONS = {
    "Reasonable": {
        "top_k": 25,
        "temperature": 50,
        "top_p": 60,
    },
    "Odd": {
        "top_k": 50,
        "temperature": 75,
        "top_p": 90,
    },
    "Insane": {
        "top_k": 300,
        "temperature": 100,
        "top_p": 85,
    },
}
def generate_room(room_name, room_desc, max_length, sampling_method):
    """
    Uses pretrained model to generate text for a dungeon room.

    The prompt is "<bos> room_name <sep> room_desc". The model continues it,
    and the "<bos> room_name <sep>" prefix is dropped from the result, so the
    returned text starts with the user-supplied description (if any).

    Args:
        room_name: Name of the room. None is treated as "".
        room_desc: Optional start of the room description. None is treated as "".
        max_length: Token budget passed to MODEL.generate; it includes the
            prompt tokens. Coerced to int because a UI slider may deliver a
            float.
        sampling_method: Key of SAMPLING_OPTIONS selecting the sampling preset.

    Returns: Room description text. When it contains a period past index 0, it
        is trimmed to end at the last period.
    Raises: KeyError if sampling_method is not a key of SAMPLING_OPTIONS.
    """
    room_name = room_name or ""
    room_desc = room_desc or ""

    # Only want to skip the room name part; the description start is kept.
    prefix = " ".join([SPECIAL_TOKENS["bos_token"], room_name, SPECIAL_TOKENS["sep_token"]])
    prompt = " ".join([prefix, room_desc])
    to_skip = TOK.encode(prefix, return_tensors="pt")
    ids = TOK.encode(prompt, return_tensors="pt")

    # Sample. temperature and top_p are stored as percentages.
    options = SAMPLING_OPTIONS[sampling_method]
    output = MODEL.generate(
        ids,
        max_length=int(max_length),
        do_sample=True,
        top_k=options["top_k"],
        temperature=options["temperature"] / 100.,
        top_p=options["top_p"] / 100.,
        # Stop on the custom "<|EOS|>" token added to TOK. The saved model
        # config may still carry GPT-2's default eos id (TODO confirm).
        eos_token_id=TOK.eos_token_id,
        pad_token_id=TOK.pad_token_id,
    )

    # skip_special_tokens keeps "<|EOS|>" / "<pad>" out of the user-facing text.
    output = TOK.decode(
        output[0][to_skip.shape[1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Collapse runs of spaces, e.g. those left where special tokens were removed.
    output = re.sub(r" {2,}", " ", output)

    # Slice off the last partial sentence (max_length usually cuts mid-sentence).
    last_period = output.rfind(".")
    if last_period > 0:
        output = output[:last_period + 1]
    return output
if __name__ == "__main__":
    # UI controls, listed in the same order as generate_room's parameters:
    # room_name, room_desc, max_length, sampling_method.
    room_inputs = [
        inputs.Textbox(lines=1, label="Room Name"),
        inputs.Textbox(lines=3, label="Start of Room Description (Optional)", default=""),
        inputs.Slider(minimum=50, maximum=250, default=200, label="Length"),
        inputs.Radio(choices=list(SAMPLING_OPTIONS), default="Odd", label="Craziness"),
    ]

    iface = gr.Interface(
        fn=generate_room,
        inputs=room_inputs,
        outputs="text",
        title="RPG Room Generator",
        layout="horizontal",
        theme="dark",
        allow_flagging="never",
    )
    app, local_url, share_url = iface.launch()
|