Amanda Griffith commited on
Commit
2544172
·
1 Parent(s): f0aa52c

Added ability to generate random spell

Browse files
Files changed (1) hide show
  1. app.py +34 -14
app.py CHANGED
@@ -1,29 +1,43 @@
1
  import os
2
  import gradio as gr
3
- from transformers import pipeline
4
 
5
  HF_API = os.environ.get("HF_API")
6
 
7
- model = "aegrif/gpt2_spell_gen"
8
- tokenizer = "aegrif/gpt2_spell_gen"
9
- pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer, use_auth_token=HF_API)
10
 
 
 
 
11
 
12
- def predict(text, temperature, top_k, top_p, max_length):
 
13
  input_text = f"<|name|> {text} <|spell|>"
14
- pipeline.model.config.pad_token_id = pipeline.model.config.eos_token_id
15
- pipeline.model.config.temperature = temperature
16
- pipeline.model.config.top_k = top_k
17
- pipeline.model.config.top_p = top_p
18
- predictions = pipeline(input_text, max_length=max_length, num_return_sequences=1)[0]["generated_text"]
19
  spell_start = len(text) + 19
20
  output = text + "\n\n" + predictions[spell_start:]
21
  return output.strip()
22
 
23
 
 
 
 
 
 
 
 
 
 
24
  title = "# Spell generation with GPT-2"
25
  description = "## Generate your own spells"
26
- examples = [["Speak with Objects"], ["Summon Burley"], ["Moon Step"], ["Burden of the Gods"], ["Shape Rock"], ["Bard's Laughter"], ["Mundane Foresight"], ["Word of Cancellation"]]
 
27
 
28
  with gr.Blocks(css="#spell-row {justify-content: flex-start; }") as interface:
29
  gr.Markdown(title)
@@ -48,11 +62,17 @@ with gr.Blocks(css="#spell-row {justify-content: flex-start; }") as interface:
48
  top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Top P")
49
  with gr.Row(variant="compact", elem_id="spell-row"):
50
  output = gr.Textbox(label="Generated Spell", placeholder="Your spell will appear here.")
51
- generate_btn = gr.Button("Generate Spell")
 
 
 
 
 
52
  with gr.Row(variant="compact", elem_id="spell-row"):
53
  gr.Markdown(
54
  "**Max Length**: The maximum length of the generated spell.\n\n**Temperature**: The randomness of the generated spell. Higher values are more random, lower values are more deterministic.\n\n**Top K**: The number of highest probability vocabulary tokens to keep for top-k-filtering.\n\n**Top P**: The cumulative probability for top-p-filtering.")
55
 
56
- generate_btn.click(fn=predict, inputs=[name,temperature,top_k,top_p, max_length], outputs=output)
 
57
 
58
- interface.launch()
 
1
  import os
2
  import gradio as gr
3
+ from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
4
 
5
  HF_API = os.environ.get("HF_API")
6
 
7
+ desc_model = GPT2LMHeadModel.from_pretrained("aegrif/gpt2_spell_gen", use_auth_token=HF_API)
8
+ desc_tokenizer = GPT2Tokenizer.from_pretrained("aegrif/gpt2_spell_gen", use_auth_token=HF_API)
9
+ desc_pipeline = pipeline(task="text-generation", model=desc_model, tokenizer=desc_tokenizer, use_auth_token=HF_API)
10
 
11
+ name_model = GPT2LMHeadModel.from_pretrained("aegrif/spell_name_gen", use_auth_token=HF_API)
12
+ name_tokenizer = GPT2Tokenizer.from_pretrained("aegrif/spell_name_gen", use_auth_token=HF_API)
13
+ name_pipeline = pipeline(task="text-generation", model=name_model, tokenizer=name_tokenizer, use_auth_token=HF_API)
14
 
15
+
16
+ def desc_predict(text, temperature, top_k, top_p, max_length):
17
  input_text = f"<|name|> {text} <|spell|>"
18
+ desc_pipeline.model.config.pad_token_id = desc_pipeline.model.config.eos_token_id
19
+ desc_pipeline.model.config.temperature = temperature
20
+ desc_pipeline.model.config.top_k = top_k
21
+ desc_pipeline.model.config.top_p = top_p
22
+ predictions = desc_pipeline(input_text, max_length=max_length, num_return_sequences=1)[0]["generated_text"]
23
  spell_start = len(text) + 19
24
  output = text + "\n\n" + predictions[spell_start:]
25
  return output.strip()
26
 
27
 
28
+ def name_predict(temperature, top_k, top_p, max_length):
29
+ input_text = "<|name|> "
30
+ name_pipeline.model.config.pad_token_id = name_pipeline.model.config.eos_token_id
31
+ predictions = name_pipeline(input_text, max_length=50, num_return_sequences=1)[0]["generated_text"]
32
+ spell_name = predictions[9:].strip()
33
+ desc_predictions = desc_predict(spell_name, temperature, top_k, top_p, max_length)
34
+ return spell_name, desc_predictions
35
+
36
+
37
  title = "# Spell generation with GPT-2"
38
  description = "## Generate your own spells"
39
+ examples = [["Speak with Objects"], ["Summon Burley"], ["Moon Step"], ["Burden of the Gods"], ["Shape Rock"],
40
+ ["Bard's Laughter"], ["Mundane Foresight"], ["Word of Cancellation"]]
41
 
42
  with gr.Blocks(css="#spell-row {justify-content: flex-start; }") as interface:
43
  gr.Markdown(title)
 
62
  top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Top P")
63
  with gr.Row(variant="compact", elem_id="spell-row"):
64
  output = gr.Textbox(label="Generated Spell", placeholder="Your spell will appear here.")
65
+ with gr.Row(variant="compact", elem_id="spell-row"):
66
+ with gr.Column(scale=1):
67
+ generate_btn = gr.Button("Generate Spell")
68
+ with gr.Column(scale=1):
69
+ random_btn = gr.Button("Random Spell")
70
+
71
  with gr.Row(variant="compact", elem_id="spell-row"):
72
  gr.Markdown(
73
  "**Max Length**: The maximum length of the generated spell.\n\n**Temperature**: The randomness of the generated spell. Higher values are more random, lower values are more deterministic.\n\n**Top K**: The number of highest probability vocabulary tokens to keep for top-k-filtering.\n\n**Top P**: The cumulative probability for top-p-filtering.")
74
 
75
+ generate_btn.click(fn=desc_predict, inputs=[name, temperature, top_k, top_p, max_length], outputs=output)
76
+ random_btn.click(fn=name_predict, inputs=[temperature, top_k, top_p, max_length], outputs=[name, output])
77
 
78
+ interface.launch()