Sudoaptinstallpy3 commited on
Commit
07a67c8
·
verified ·
1 Parent(s): fb92efc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -19
app.py CHANGED
@@ -1,27 +1,51 @@
1
  import gradio as gr
2
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 
3
 
4
- # Load the MARIOGPT model and tokenizer
5
  model_name = "shyamsn97/Mario-GPT2-700-context-length"
6
- model = GPT2LMHeadModel.from_pretrained(model_name)
7
- tokenizer = GPT2Tokenizer.from_pretrained(model_name)
8
 
9
- # Define the function to generate Mario level
10
- def generate_mario_level(prompt):
11
- inputs = tokenizer(prompt, return_tensors="pt")
12
- outputs = model.generate(inputs["input_ids"], max_length=200)
13
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
14
- return generated_text
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Create the Gradio interface
17
- interface = gr.Interface(
18
- fn=generate_mario_level,
19
- inputs=gr.Textbox(lines=2, placeholder="Enter level description here..."),
20
- outputs="text",
21
- title="MARIOGPT Level Generator",
22
- description="Generate Mario levels using MARIOGPT by entering a level description.",
23
- examples=[["simple level"], ["difficult level with many enemies"], ["water level with lots of coins"]]
24
- )
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Launch the Gradio interface
27
- interface.launch()
 
1
  import gradio as gr
2
+ from mario_gpt import MarioLM, SampleOutput
3
+ import torch
4
+ from PIL import Image
5
 
6
+ # Load the Mario GPT model
7
  model_name = "shyamsn97/Mario-GPT2-700-context-length"
8
+ mario_lm = MarioLM(lm_path=model_name, tokenizer_path=model_name)
 
9
 
10
+ # Define the function to generate Mario levels
11
+ def generate_mario_level(prompt, num_steps=1400, temperature=2.0):
12
+ generated_level = mario_lm.sample(
13
+ prompts=[prompt],
14
+ num_steps=num_steps,
15
+ temperature=temperature,
16
+ use_tqdm=False
17
+ )
18
+ # Save the generated image to display
19
+ generated_level.img.save("generated_level.png")
20
+ generated_level.save("generated_level.txt")
21
+ return generated_level.img, "generated_level.txt"
22
+
23
+ # Define the function to play the generated level
24
+ def play_generated_level(level_file):
25
+ loaded_level = SampleOutput.load(level_file)
26
+ loaded_level.play()
27
+ return "Playing the level..."
28
 
29
  # Create the Gradio interface
30
+ with gr.Blocks() as demo:
31
+ gr.Markdown("## MARIOGPT Level Generator")
32
+ with gr.Row():
33
+ with gr.Column():
34
+ prompt = gr.Textbox(lines=2, placeholder="Enter level description here...")
35
+ num_steps = gr.Slider(400, 2000, step=100, label="Number of Steps")
36
+ temperature = gr.Slider(0.5, 3.0, step=0.1, label="Temperature")
37
+ generate_button = gr.Button("Generate Level")
38
+ with gr.Column():
39
+ level_image = gr.Image(type="pil")
40
+ level_file = gr.File(label="Level File")
41
+ generate_button.click(generate_mario_level, inputs=[prompt, num_steps, temperature], outputs=[level_image, level_file])
42
+
43
+ gr.Markdown("## Play Generated Level")
44
+ with gr.Row():
45
+ with gr.Column():
46
+ play_button = gr.Button("Play Level")
47
+ play_result = gr.Textbox(label="Output")
48
+ play_button.click(play_generated_level, inputs=[level_file], outputs=[play_result])
49
 
50
  # Launch the Gradio interface
51
+ demo.launch()