Axolotlily committed on
Commit
4eb45db
·
1 Parent(s): 0a5c9f8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
RPG Room Generator App

Loads the fine-tuned model and tokenizer and defines the sampling
presets used by the minimal Gradio interface below.

Reference: https://huggingface.co/blog/how-to-generate
"""
import gradio as gr
from gradio import inputs  # imported by name for easier doc lookup in PyCharm
import transformers as tr

# Fine-tuned GPT-2 checkpoint for room generation.
MPATH = "./models/mdl_roomgen2"
MODEL = tr.GPT2LMHeadModel.from_pretrained(MPATH)

# ToDo: Will save tokenizer next time so can replace this with a load
SPECIAL_TOKENS = {
    'eos_token': '<|EOS|>',
    'bos_token': '<|endoftext|>',
    'pad_token': '<pad>',
    'sep_token': '<|body|>',
}
TOK = tr.GPT2Tokenizer.from_pretrained("gpt2")
TOK.add_special_tokens(SPECIAL_TOKENS)

# Sampling presets keyed by UI label. temperature and top_p are stored as
# percentages and scaled to [0, 1] at generation time. Key order matters:
# it drives the Radio-button order in the interface.
SAMPLING_OPTIONS = {
    "Reasonable": {"top_k": 25, "temperature": 50, "top_p": 60},
    "Odd": {"top_k": 50, "temperature": 75, "top_p": 90},
    "Insane": {"top_k": 300, "temperature": 100, "top_p": 85},
}
45
+
46
+
47
def generate_room(room_name, room_desc, max_length, sampling_method):
    """
    Generate text for a dungeon room with the pretrained model.

    Args:
        room_name: Title of the room; becomes the prompt prefix.
        room_desc: Optional start of the description to continue from
            (may be the empty string).
        max_length: Maximum total token length passed to ``generate``
            (cast to int — Gradio sliders can deliver floats).
        sampling_method: Key into SAMPLING_OPTIONS
            ("Reasonable" / "Odd" / "Insane").

    Returns:
        Generated room description text, truncated at the last full sentence.
    """
    prefix = " ".join([SPECIAL_TOKENS["bos_token"], room_name, SPECIAL_TOKENS["sep_token"]])
    # Avoid a trailing space in the prompt when no starting description is given.
    prompt = " ".join([prefix, room_desc]) if room_desc else prefix
    # Only want to skip the room-name part of the model output.
    # NOTE(review): the prefix is encoded on its own here; BPE merges could in
    # principle tokenize it differently inside the full prompt — confirm the
    # token counts line up for typical room names.
    to_skip = TOK.encode(prefix, return_tensors="pt")
    ids = TOK.encode(prompt, return_tensors="pt")

    # Sample: temperature/top_p presets are percentages, scale to [0, 1].
    opts = SAMPLING_OPTIONS[sampling_method]
    output = MODEL.generate(
        ids,
        max_length=int(max_length),
        do_sample=True,
        top_k=opts["top_k"],
        temperature=opts["temperature"] / 100.,
        top_p=opts["top_p"] / 100.,
    )
    text = TOK.decode(output[0][to_skip.shape[1]:], clean_up_tokenization_spaces=True)
    # Collapse doubled spaces left by token cleanup. (The original
    # .replace(" ", " ") was a no-op — its first argument lost a space.)
    text = text.replace("  ", " ")
    # Slice off any trailing partial sentence. rfind returns -1 when no
    # period exists; the original `> 0` also wrongly dropped a period at
    # index 0.
    last_period = text.rfind(".")
    if last_period != -1:
        text = text[:last_period + 1]
    return text
83
+
84
+
85
if __name__ == "__main__":
    # The four input widgets, in the order generate_room expects its args.
    ui_inputs = [
        inputs.Textbox(lines=1, label="Room Name"),
        inputs.Textbox(lines=3, label="Start of Room Description (Optional)", default=""),
        inputs.Slider(minimum=50, maximum=250, default=200, label="Length"),
        inputs.Radio(choices=list(SAMPLING_OPTIONS.keys()), default="Odd", label="Craziness"),
    ]
    # Minimal single-function UI around generate_room.
    iface = gr.Interface(
        title="RPG Room Generator",
        fn=generate_room,
        inputs=ui_inputs,
        outputs="text",
        layout="horizontal",
        allow_flagging="never",
        theme="dark",
    )
    # launch() returns (FastAPI app, local URL, share URL).
    app, local_url, share_url = iface.launch()