File size: 7,424 Bytes
49b8c43
 
 
 
 
 
 
 
41ed846
49b8c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41ed846
49b8c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41ed846
49b8c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41ed846
49b8c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#!/usr/bin/env python3
"""
MarioGPT MCP Server
Provides MCP-compatible interface for generating Mario levels via HuggingChat
"""

import asyncio
import base64
import io
import logging
import threading
from typing import Any, Optional, List, Union

from mcp.server import Server
from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource
from pydantic import BaseModel, Field

# Setup logging
# basicConfig's default handler writes to stderr; presumably that matters
# because run_server() speaks MCP over stdio (stdout must stay protocol-only).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize MCP Server
# list_tools / call_tool below register themselves on this instance through
# its decorators.
mcp_server = Server("mariogpt-server")

# Model will be initialized on first use: initialize_model() sets both
# globals; they remain None until then.
mario_lm: Optional[Any] = None  # the loaded MarioLM instance
device: Optional[Any] = None  # torch.device the model was moved to

def initialize_model():
    """Load the MarioGPT model on first use and cache it in module globals.

    Once ``mario_lm`` is set, later calls return immediately. The model is
    placed on CUDA when ``torch.cuda.is_available()`` is true, otherwise on CPU.

    Raises:
        Exception: any error from importing ``torch``/``supermariogpt`` or from
            building/moving the model is logged (with traceback) and re-raised.
            The globals are left untouched in that case, so the next call
            retries instead of reusing a half-initialized model.
    """
    global mario_lm, device
    if mario_lm is not None:
        return

    try:
        import torch
        from supermariogpt.lm import MarioLM

        # Build into locals and publish only after everything succeeded, so a
        # failure in .to() cannot leave a partially set-up model cached.
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = MarioLM().to(target_device)
    except Exception as e:
        logger.exception("Failed to initialize model: %s", e)
        raise

    mario_lm, device = model, target_device
    logger.info("MarioGPT model loaded on %s", target_device)

class GenerateLevelParams(BaseModel):
    """Parameters for generating a Mario level"""
    # NOTE: pydantic uses the model docstring as the JSON-schema "description"
    # exposed through list_tools, so the docstring above is kept verbatim.

    # Free-form level description passed to the model as its prompt (required).
    prompt: str = Field(
        ...,
        description="Text description of the level (e.g., 'many pipes, some enemies, low elevation')",
    )

    # Sampling temperature, validated to the inclusive range [0.1, 2.0];
    # the default sits at the upper bound.
    temperature: float = Field(
        2.0,
        description="Generation temperature (0.1-2.0). Higher = more diverse but lower quality",
        ge=0.1,
        le=2.0,
    )

    # Number of tokens to sample, validated to the inclusive range [100, 2799].
    level_size: int = Field(
        1399,
        description="Size of the level in tokens (100-2799)",
        ge=100,
        le=2799,
    )

@mcp_server.list_tools()
async def list_tools() -> List[Tool]:
    """Advertise the tools this server exposes to MCP clients.

    Returns:
        Two tool descriptors, in order: ``generate_mario_level`` (input schema
        derived from ``GenerateLevelParams``) and ``get_level_suggestions``
        (an object schema with no properties).
    """
    generate_description = (
        "Generate a playable Super Mario level from text description. "
        "Returns both a visual representation and level data. "
        "Example prompts: 'many pipes, some enemies, high elevation', "
        "'no pipes, many enemies, some blocks, low elevation'"
    )
    generate_tool = Tool(
        name="generate_mario_level",
        description=generate_description,
        inputSchema=GenerateLevelParams.model_json_schema(),
    )

    no_arguments_schema = {"type": "object", "properties": {}}
    suggestions_tool = Tool(
        name="get_level_suggestions",
        description="Get example prompts and suggestions for creating interesting Mario levels",
        inputSchema=no_arguments_schema,
    )

    return [generate_tool, suggestions_tool]

# Serializes model loading and sampling across worker threads (they share the
# single global model). It is taken inside the worker rather than around the
# await, so a cancelled request cannot release it while its thread still runs.
_generation_lock = threading.Lock()

# Tile directory handed to convert_level_to_png. It is a relative path, so it
# resolves against the server process's current working directory.
_TILE_DIR = "data/tiles"

# Number of level-text characters echoed next to the PNG image.
_PREVIEW_CHARS = 500

# Static text returned by the get_level_suggestions tool.
_LEVEL_SUGGESTIONS = """
# Mario Level Generation Suggestions

## Pipe Variations:
- "no pipes, many enemies, some blocks, low elevation"
- "many pipes, few enemies, many blocks, high elevation"
- "some pipes, some enemies, some blocks, low elevation"

## Enemy Focused:
- "little pipes, many enemies, little blocks, low elevation"
- "no pipes, many enemies, many blocks, high elevation"

## Platform Challenges:
- "some pipes, few enemies, many blocks, high elevation"
- "no pipes, some enemies, many blocks, high elevation"

## Balanced Levels:
- "some pipes, some enemies, some blocks, low elevation"
- "many pipes, some enemies, some blocks, high elevation"

## Tips:
- Use "high elevation" for more vertical platforming challenges
- Use "low elevation" for more horizontal levels
- "many enemies" creates more combat-focused levels
- "many blocks" creates more platform-jumping challenges
- Temperature 1.0-1.5: More consistent, quality levels
- Temperature 1.5-2.0: More diverse, experimental levels
"""


def _generate_level_blocking(params: GenerateLevelParams) -> tuple[str, Optional[str]]:
    """Load the model if needed, sample one level and render it (blocking).

    Meant to run in a worker thread, off the event loop.

    Returns:
        ``(level_text, img_base64)``: the level rows from ``view_level`` joined
        with newlines, and the base64-encoded PNG, or ``None`` if rendering the
        image failed (that failure is logged, not raised).

    Raises:
        Any error from model initialization, the ``supermariogpt`` import, or
        sampling propagates to the caller.
    """
    with _generation_lock:
        initialize_model()
        from supermariogpt.utils import view_level, convert_level_to_png

        generated_level = mario_lm.sample(
            prompts=[params.prompt],
            num_steps=params.level_size,
            temperature=float(params.temperature),
            use_tqdm=False,
        )
        level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))

        try:
            img = convert_level_to_png(
                generated_level.squeeze(),
                _TILE_DIR,
                mario_lm.tokenizer,
            )[0]
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()
        except Exception as img_error:
            # Best effort: the text representation is still returned.
            logger.warning("Could not generate image: %s", img_error)
            img_base64 = None

    return level_text, img_base64


@mcp_server.call_tool()
async def call_tool(name: str, arguments: Any) -> List[Union[TextContent, ImageContent, EmbeddedResource]]:
    """Dispatch an MCP tool call by name.

    Args:
        name: Tool name; ``generate_mario_level`` or ``get_level_suggestions``.
        arguments: Tool arguments sent by the client; ``None`` is treated as
            an empty mapping.

    Returns:
        For ``generate_mario_level``: a text summary plus a PNG image; the
        summary with the full level text alone if image rendering failed; or a
        single error text if validation or generation failed.
        For ``get_level_suggestions``: one text block of example prompts.

    Raises:
        ValueError: if ``name`` is not a known tool.
    """
    if name == "get_level_suggestions":
        return [TextContent(type="text", text=_LEVEL_SUGGESTIONS)]

    if name != "generate_mario_level":
        raise ValueError(f"Unknown tool: {name}")

    try:
        # Validate first so bad arguments are rejected without loading the model.
        params = GenerateLevelParams(**(arguments or {}))

        logger.info("Generating level with prompt: %s", params.prompt)
        logger.info("Temperature: %s, Size: %s", params.temperature, params.level_size)

        # Model loading and sampling are blocking; run them in a worker thread
        # so the event loop keeps serving the MCP session in the meantime.
        level_text, img_base64 = await asyncio.to_thread(_generate_level_blocking, params)
    except Exception as e:
        logger.exception("Error generating level: %s", e)
        return [
            TextContent(
                type="text",
                text=f"Error generating Mario level: {str(e)}"
            )
        ]

    summary = (
        "Successfully generated Mario level!\n\n"
        f"Prompt: {params.prompt}\n"
        f"Temperature: {params.temperature}\n"
        f"Level size: {params.level_size}\n\n"
    )

    if img_base64 is None:
        # No image to look at, so return the complete level text.
        return [
            TextContent(
                type="text",
                text=f"{summary}Level representation:\n{level_text}"
            )
        ]

    # The image carries the full level; only a preview of the text is echoed.
    if len(level_text) > _PREVIEW_CHARS:
        preview = level_text[:_PREVIEW_CHARS] + "..."
    else:
        preview = level_text

    return [
        TextContent(
            type="text",
            text=f"{summary}Level representation:\n{preview}"
        ),
        ImageContent(
            type="image",
            data=img_base64,
            mimeType="image/png"
        ),
    ]

async def run_server():
    """Serve this MCP server over stdin/stdout.

    Returns once ``mcp_server.run`` returns and the stdio transport context
    has exited.
    """
    from mcp.server.stdio import stdio_server

    logger.info("Starting MarioGPT MCP Server...")

    async with stdio_server() as streams:
        reader, writer = streams
        init_options = mcp_server.create_initialization_options()
        await mcp_server.run(reader, writer, init_options)

if __name__ == "__main__":
    # Script entry point: run the stdio MCP server on a new asyncio event loop.
    asyncio.run(run_server())