# mariogpt / app.py
# Author: Sudoaptinstallpy3
# Last commit: 8fb7ddc (verified) — "Update app.py"
import os
import tempfile

import gradio as gr
from mario_gpt import MarioLM, SampleOutput
# Load the MarioGPT model once, at import time; generate_mario_level reuses
# this single module-level instance for every request.
# NOTE(review): MarioLM() is called with no arguments, so it presumably loads
# the library's default pretrained checkpoint — confirm against mario_gpt.
mario_lm = MarioLM()
def generate_mario_level(prompt, num_steps=1000, temperature=2.0):
    """Generate a Mario level from a text prompt using the shared MarioGPT model.

    Args:
        prompt: Natural-language description of the level; passed to the
            model as a single-element prompt list.
        num_steps: Number of sampling steps. Coerced to ``int`` because UI
            sliders may deliver numeric values as floats.
        temperature: Sampling temperature forwarded to ``mario_lm.sample``.

    Returns:
        A ``(image, path)`` tuple: the level image (``generated_level.img``)
        and the path of a text file holding the saved level, suitable for a
        ``gr.File`` output and for ``SampleOutput.load``.
    """
    generated_level = mario_lm.sample(
        prompts=[prompt],
        num_steps=int(num_steps),
        temperature=temperature,
        use_tqdm=False,
    )
    # Save into a fresh per-request directory. A fixed path in the working
    # directory is shared by every user of the web app, so concurrent requests
    # would overwrite each other's level files. The file name itself is kept
    # so downloads are still called "generated_level.txt".
    level_dir = tempfile.mkdtemp(prefix="mariogpt_")
    level_path = os.path.join(level_dir, "generated_level.txt")
    generated_level.save(level_path)
    # The image is returned as an in-memory object for gr.Image(type="pil"),
    # so no PNG needs to be written to disk.
    return generated_level.img, level_path
def play_generated_level(level_file):
    """Load a saved level file and call ``play()`` on it.

    Note that ``play()`` executes wherever this server process runs, not in
    the user's browser.
    NOTE(review): confirm this works on a headless host.

    Args:
        level_file: The value of the ``gr.File`` component. Depending on the
            Gradio version this is a path string or a temp-file object that
            exposes ``.name``. It is ``None`` when no level has been generated
            yet.

    Returns:
        A status message for the output textbox.
    """
    if level_file is None:
        # "Play Level" clicked before any level was generated; without this
        # guard None would be passed straight to SampleOutput.load.
        return "No level file to play. Generate a level first."
    # Older Gradio versions hand over a tempfile wrapper instead of a path.
    level_path = getattr(level_file, "name", level_file)
    loaded_level = SampleOutput.load(level_path)
    loaded_level.play()
    return "Playing the level..."
# Create the Gradio interface: a generator section (prompt + sampling controls
# -> level image + downloadable level file) and a play section that replays
# the file produced by the generator.
with gr.Blocks() as demo:
    gr.Markdown("## MARIOGPT Level Generator")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Enter level description here...")
            # A Slider with no explicit value starts at its minimum, which
            # would silently run with 400 steps / temperature 0.5 instead of
            # the defaults declared on generate_mario_level (1000 / 2.0).
            num_steps = gr.Slider(400, 2000, value=1000, step=100, label="Number of Steps")
            temperature = gr.Slider(0.5, 3.0, value=2.0, step=0.1, label="Temperature")
            generate_button = gr.Button("Generate Level")
        with gr.Column():
            level_image = gr.Image(type="pil")
            level_file = gr.File(label="Level File")
    generate_button.click(
        generate_mario_level,
        inputs=[prompt, num_steps, temperature],
        outputs=[level_image, level_file],
    )

    gr.Markdown("## Play Generated Level")
    with gr.Row():
        with gr.Column():
            play_button = gr.Button("Play Level")
            play_result = gr.Textbox(label="Output")
    # The generator's File output doubles as the input for the play handler.
    play_button.click(play_generated_level, inputs=[level_file], outputs=[play_result])
# Launch the Gradio interface only when this file is run as a script
# (e.g. `python app.py`), so that importing the module does not start a
# blocking web server.
if __name__ == "__main__":
    demo.launch()