Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,830 Bytes
850b0e4 c8c7b71 766dc1d 850b0e4 ee2e0b7 c8c7b71 ee2e0b7 151dc74 850b0e4 c8c7b71 850b0e4 151dc74 ee2e0b7 c8c7b71 151dc74 db4b7db d1edc82 c8c7b71 3eefe14 c8c7b71 c801583 c8c7b71 d1edc82 c801583 d1edc82 db4b7db c801583 d1edc82 db4b7db d1edc82 db4b7db d1edc82 db4b7db c801583 d1edc82 db4b7db d1edc82 c801583 d1edc82 c8c7b71 d1edc82 1ec2ec6 766dc1d c8c7b71 1ec2ec6 7cc62e7 850b0e4 1ec2ec6 7cc62e7 850b0e4 c8c7b71 850b0e4 c8c7b71 ee2e0b7 151dc74 ee2e0b7 c8c7b71 850b0e4 2c782ec c8c7b71 ee2e0b7 c8c7b71 ee2e0b7 c8c7b71 9f829c5 5235656 226c5b4 850b0e4 c8c7b71 040c712 c8c7b71 d59d1e6 1ec2ec6 9f829c5 1ec2ec6 c8c7b71 141b1fb d59d1e6 6a39a28 6a94190 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
import torch
import uuid
import spaces
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
from pathlib import Path
# Instantiate the MarioGPT language model and move it to the GPU.
# NOTE(review): torch.device('cuda') assumes a CUDA device is present —
# on the ZeroGPU Space this is provided by the @spaces.GPU decorator below.
mario_lm = MarioLM()
device = torch.device('cuda')
mario_lm = mario_lm.to(device)
# Directory of tile sprites used by convert_level_to_png.
TILE_DIR = "data/tiles"
# Allow Gradio to serve files from ./static via the /gradio_api/file= route
# (the generated demo-*.html pages and mario.jar live there).
gr.set_static_paths(paths=[Path("static").absolute()])
# FastAPI app that the Gradio demo is mounted into at the bottom of the file.
app = FastAPI()
def make_html_file(generated_level):
    """Render *generated_level* into a standalone CheerpJ HTML page.

    The level is decoded to its text representation and embedded in the page
    as a JS template literal; at runtime the page injects it into CheerpJ's
    virtual filesystem at "/str/mylevel.txt", the path hardcoded in the
    game's PlayLevel.java, then runs the Mario JAR from "/app/".

    Parameters:
        generated_level: token tensor produced by ``mario_lm.sample``.

    Returns:
        The file name (not path) of the HTML page written into ``static/``.
    """
    # Join the decoded level rows with real newlines (same result as the
    # original triple-quoted-literal join, just written plainly).
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    # uuid4 instead of uuid1: uuid1 embeds the host MAC address and a
    # timestamp, which leaks host information and is predictable. uuid4 is
    # the right tool for an unguessable unique filename.
    unique_id = uuid.uuid4()
    html_filename = f"demo-{unique_id}.html"
    # This is the final, simplified solution that respects all source code discoveries.
    # 1. The JAR is run from its standard "/app/" location.
    # 2. The level data is placed at the hardcoded "/str/mylevel.txt" path.
    html_content = f'''<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Mario Game</title>
<script src="https://cjrtnc.leaningtech.com/3.1/cj3loader.js"></script>
</head>
<body>
<p id="loading-status">Loading game, please wait...</p>
</body>
<script>
async function runGame() {{
const statusElement = document.getElementById("loading-status");
try {{
// Step 1: Initialize the CheerpJ runtime.
statusElement.textContent = "Initializing Java runtime...";
await cheerpjInit();
// Step 2: Add the level file to the virtual filesystem.
// This is the hardcoded path from the PlayLevel.java source code.
statusElement.textContent = "Loading level...";
cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
console.log("Runtime ready and level loaded.");
// Step 3: Run the game.
statusElement.textContent = "Starting game...";
document.body.innerHTML = '';
cheerpjCreateDisplay(512, 500);
// Let CheerpJ fetch the JAR from its standard "/app/" classpath.
// The Java code will ignore arguments and load the level from /str/mylevel.txt on its own.
cheerpjRunJar("/app/gradio_api/file=static/mario.jar");
}} catch (error) {{
console.error("Failed to load the Mario game:", error);
statusElement.innerHTML = "<h1>Error</h1><p>Could not load the game. Please check the browser console for details.</p>";
}}
}}
runGame();
</script>
</html>'''
    with open(Path("static") / html_filename, 'w', encoding='utf-8') as f:
        f.write(html_content)
    return html_filename
@spaces.GPU
def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size = 1399, prompt = ""):
    """Sample a level from MarioGPT and return [preview image, playable HTML].

    Parameters:
        pipes/enemies/blocks/elevation: radio-button choices composed into a
            prompt when *prompt* is empty.
        temperature: sampling temperature (higher = more diverse output).
        level_size: number of sampling steps (level length in tokens).
        prompt: free-form prompt; overrides the composed one when non-empty.

    Returns:
        A two-element list: the rendered level PNG and an HTML snippet
        embedding the playable CheerpJ page inside an iframe.
    """
    # A typed prompt takes precedence; otherwise compose one from the radios.
    if prompt == "":
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    print(f"Using prompt: {prompt}")
    print(f"Using temperature: {temperature}")
    prompts = [prompt]
    generated_level = mario_lm.sample(
        prompts=prompts,
        num_steps=level_size,
        temperature=float(temperature),
        use_tqdm=True
    )
    filename = make_html_file(generated_level)
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    # BUG FIX: the iframe src previously contained a literal placeholder
    # instead of the generated page, so the playable demo never loaded.
    # Interpolate the actual per-generation filename returned above.
    gradio_html = f'''<div>
        <iframe width=512 height=512 style="margin: 0 auto" src="/gradio_api/file=static/{filename}"></iframe>
        <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
    </div>'''
    return [img, gradio_html]
# Build the Gradio UI and serve it through FastAPI/uvicorn.
#
# BUG FIX: the original called demo.launch() before the FastAPI mounting
# code. launch() starts its own blocking server, so app.mount(...),
# gr.mount_gradio_app(...) and uvicorn.run(...) were unreachable dead code
# and /static was never mounted. The single intended server is uvicorn;
# the redundant launch() is removed.
with gr.Blocks().queue() as demo:
    gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
    [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
    ''')
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
                elevation = gr.Radio(["low", "high"], label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(value=2.0, minimum=0.1, maximum=2.0, step=0.1, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")
    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )
# Serve the generated HTML pages and mario.jar, mount the Gradio app at the
# root, and run everything under one uvicorn server.
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/")
uvicorn.run(app, host="0.0.0.0", port=7860)