Spaces: Running on Zero
| import gradio as gr | |
| import torch | |
| from mario_gpt.dataset import MarioDataset | |
| from mario_gpt.prompter import Prompter | |
| from mario_gpt.lm import MarioLM | |
| from mario_gpt.utils import view_level, convert_level_to_png | |
# Load the MarioGPT language model once at import time; `update` reads this
# module-level instance, so every request reuses the same weights.
mario_lm = MarioLM()

# Prefer CUDA but fall back to CPU: hard-coding 'cuda' makes `.to(device)`
# raise on machines where CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)

# Tile-image directory passed to convert_level_to_png when rendering levels.
# This is a relative path, so it resolves against the current working directory.
TILE_DIR = "data/tiles"
def update(prompt, progress=gr.Progress(track_tqdm=True)):
    """Sample a Mario level for ``prompt`` and return it rendered as an image.

    Args:
        prompt: Text from the prompt textbox; sent to the model as a
            single-element batch of prompts.
        progress: Not read in the body. Gradio injects a progress tracker
            here, and ``track_tqdm=True`` lets it follow tqdm bars, such as
            the one ``mario_lm.sample(use_tqdm=True)`` presumably shows.

    Returns:
        The first element of ``convert_level_to_png``'s result, which the
        UI displays in its image output.
    """
    # NOTE(review): 1399 sampled steps presumably pairs with a one-token seed
    # to yield a whole number of level columns -- confirm against MarioLM.
    level_tokens = mario_lm.sample(
        prompts=[prompt],
        num_steps=1399,
        temperature=2.0,
        use_tqdm=True,
    )
    rendered = convert_level_to_png(
        level_tokens.squeeze(), TILE_DIR, mario_lm.tokenizer
    )
    return rendered[0]
# Minimal UI: a prompt textbox, a button that triggers generation, and an
# image component that shows the rendered level.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    # Clicking the button calls `update` with the textbox text and shows the
    # returned image in `level_image`.
    btn.click(fn=update, inputs=prompt, outputs=level_image)

if __name__ == "__main__":
    # `update` reports progress through gr.Progress, which requires the queue
    # on Gradio 3.x; the queue is enabled by default on Gradio 4+, so calling
    # queue() is safe on both.
    demo.queue().launch()