File size: 5,798 Bytes
49b8c43
de878f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49b8c43
de878f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49b8c43
de878f4
 
49b8c43
de878f4
 
 
49b8c43
de878f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175c150
de878f4
 
175c150
de878f4
 
 
 
175c150
de878f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49b8c43
 
de878f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
import torch
import uuid
import spaces
from supermariogpt.dataset import MarioDataset
from supermariogpt.prompter import Prompter
from supermariogpt.lm import MarioLM
from supermariogpt.utils import view_level, convert_level_to_png

from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles

import os
import uvicorn
from pathlib import Path
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the model once at import time; generate() reuses this instance.
try:
    mario_lm = MarioLM()
    # Use the GPU when torch can see one, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mario_lm = mario_lm.to(device)
    logger.info("Model loaded successfully on %s", device)
except Exception as e:
    # logger.exception also records the traceback. Re-raise so startup fails
    # loudly instead of serving requests without a model.
    logger.exception("Failed to load model: %s", e)
    raise

# Tile images used by convert_level_to_png to render the PNG preview.
TILE_DIR = "data/tiles"

# Ensure static directory exists; generated demo-*.html pages are written here.
Path("static").mkdir(exist_ok=True)

# Let Gradio serve files from static/ (referenced via /gradio_api/file=static/...).
gr.set_static_paths(paths=[Path("static").absolute()])

app = FastAPI()

def _escape_js_template_literal(text):
    """Return *text* escaped for embedding inside a JS template literal in an inline <script>.

    Backslashes, backticks and ``${`` would otherwise terminate or reinterpret
    the literal. ``</`` could close the surrounding ``<script>`` element.
    Text containing none of these sequences is returned unchanged.
    """
    # Backslashes must be escaped first so later-inserted escapes are not doubled.
    return (
        text.replace("\\", "\\\\")
        .replace("`", "\\`")
        .replace("${", "\\${")
        .replace("</", "<\\/")
    )


def make_html_file(generated_level):
    """Write a playable HTML page for a generated level into ``static/``.

    The page loads the CheerpJ loader and registers the level text as the
    virtual file ``/str/mylevel.txt``. It then runs ``static/mario.jar``
    through Gradio's file route. Presumably the jar reads that virtual file;
    TODO confirm against the jar.

    Args:
        generated_level: Output of ``mario_lm.sample``. It is turned into text
            rows with ``view_level`` and the model's tokenizer.

    Returns:
        str: Bare file name (``demo-<uuid4>.html``) of the page written under
        ``static/``.

    Raises:
        Exception: Any rendering or I/O error is logged with its traceback and
        re-raised.
    """
    try:
        # One text row per line of the level.
        level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
        escaped_level = _escape_js_template_literal(level_text)
        # uuid4 gives a random, collision-resistant name for each request.
        html_filename = f"demo-{uuid.uuid4()}.html"

        html_content = f'''<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <title>supermariogpt</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>

<body>
</body>
<script>
    cheerpjInit().then(function () {{
        cheerpjAddStringFile("/str/mylevel.txt", `{escaped_level}`);
    }});
    cheerpjCreateDisplay(512, 500);
    cheerpjRunJar("/app/gradio_api/file=static/mario.jar");
</script>
</html>'''

        (Path("static") / html_filename).write_text(html_content, encoding='utf-8')
        return html_filename
    except Exception as e:
        logger.exception("Error creating HTML file: %s", e)
        raise

@spaces.GPU
def generate(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt="", progress=gr.Progress(track_tqdm=True)):
    """Generate a Mario level and return its preview image and a playable embed.

    Args:
        pipes, enemies, blocks, elevation: Words used to compose a prompt of
            the form "<pipes> pipes, <enemies> enemies, <blocks> blocks,
            <elevation> elevation" when no free-text prompt is given.
        temperature: Sampling temperature, clamped to [0.1, 2.0].
        level_size: Number of sampling steps, clamped to [100, 2799].
        prompt: Optional free-text prompt. When empty, whitespace-only or None,
            the composed prompt is used instead.
        progress: Unused in the body. Its ``track_tqdm=True`` default lets
            Gradio mirror the tqdm bar from ``mario_lm.sample``.

    Returns:
        list: ``[img, gradio_html]``, where ``img`` is the first result of
        ``convert_level_to_png`` and ``gradio_html`` is an iframe pointing at
        the generated HTML page.

    Raises:
        gr.Error: Wrapping (and chained from) any failure during generation.
    """
    try:
        # Clamp to the same ranges the UI sliders expose, since the function
        # can also be called through the API with arbitrary values.
        temperature = max(0.1, min(2.0, float(temperature)))
        level_size = max(100, min(2799, int(level_size)))

        prompt = (prompt or "").strip()
        if not prompt:
            prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

        logger.info("Using prompt: %s", prompt)
        logger.info("Using temperature: %s", temperature)
        logger.info("Using level size: %s", level_size)

        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=level_size,
            temperature=temperature,
            use_tqdm=True
        )

        filename = make_html_file(generated_level)
        img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]

        gradio_html = f'''<div>
            <iframe width=512 height=512 style="margin: 0 auto" src="/gradio_api/file=static/{filename}"></iframe>
            <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
        </div>'''

        return [img, gradio_html]
    except Exception as e:
        logger.exception("Error generating level: %s", e)
        raise gr.Error(f"Failed to generate level: {str(e)}") from e

# Build the Gradio UI. Components appear in the order they are created inside
# each `with` context, so the statements below are order-sensitive.
with gr.Blocks().queue() as demo:
    gr.Markdown('''# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
    ''')
    # Two ways to specify a prompt. A non-empty typed prompt takes precedence,
    # because generate() only composes one from the radios when `prompt` is "".
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], value="some", label="How many blocks?")
                elevation = gr.Radio(["low", "high"], value="low", label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
        
    # Slider ranges match the clamps applied inside generate().
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(value=2.0, minimum=0.1, maximum=2.0, step=0.1, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")
    
    btn = gr.Button("Generate level")
    # Output area: the playable iframe (HTML) next to the rendered PNG.
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()    
        level_image = gr.Image()
    # `inputs` follow generate()'s positional parameter order, and `outputs`
    # follow its return order: [img, gradio_html].
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    # Examples pass only the four radio values, so generate() falls back to its
    # defaults: temperature=2.0, level_size=1399, prompt="".
    # NOTE(review): with cache_examples=True, Gradio runs generate() on these
    # examples ahead of time (per Gradio docs), so startup pays for four
    # model generations.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )

# Mount static files and Gradio app.
# /static is registered before the Gradio app is mounted at the catch-all "/".
# NOTE(review): Starlette matches routes in registration order, so this
# ordering presumably keeps "/" from shadowing "/static". Keep it.
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/")

if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)