---
license: mit
tags:
- pytorch
- transformer
- world-model
- game-simulation
- snake-game
language:
- en
pipeline_tag: other
---

# Snakeformer

Snakeformer is a small decoder-only transformer that acts as a neural game engine for an ASCII Snake game. Given the current board and a player action, it generates the next board.

## Model Details

| Property | Value |
|----------|-------|
| Architecture | Decoder-only Transformer (GPT-style) |
| Parameters | ~0.8M |
| Layers | 4 |
| Attention Heads | 8 |
| Embedding Dimension | 128 |
| Context Window | 1024 tokens |
| Vocabulary Size | 16 characters |
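
You can check the ~0.8M figure by hand. The card does not state the layer layout, so this assumes `gpt.py` follows a nanoGPT-style design: biased linear layers and LayerNorms, a 4x MLP, and an output head tied to the token embedding. Under that assumption each block contributes `12*d**2 + 13*d` parameters. `gpt_param_count` is a hypothetical helper written for this check.

```python
def gpt_param_count(vocab_size, block_size, n_embd, n_layer):
    """Estimate parameters of a GPT with an ASSUMED nanoGPT-style layout (tied lm_head)."""
    d = n_embd
    per_block = (
        2 * d                 # ln_1 (weight + bias)
        + d * 3 * d + 3 * d   # attention q/k/v projection
        + d * d + d           # attention output projection
        + 2 * d               # ln_2 (weight + bias)
        + d * 4 * d + 4 * d   # MLP up-projection
        + 4 * d * d + d       # MLP down-projection
    )
    token_emb = vocab_size * d   # shared with the output head under the tying assumption
    pos_emb = block_size * d     # learned absolute position embeddings
    final_ln = 2 * d
    total = n_layer * per_block + token_emb + pos_emb + final_ln
    # nanoGPT-style reporting excludes the position embeddings
    return total, total - pos_emb


total, without_pos = gpt_param_count(vocab_size=16, block_size=1024, n_embd=128, n_layer=4)
print(total, without_pos)  # 926464 795392
```

Under these assumptions the model has about 0.93M parameters in total. Excluding position embeddings, which is the default convention of nanoGPT's `get_num_params`, the count is about 0.80M, matching the table. The head count does not change the total. It only splits the 128-dimensional embedding into 8 heads of 16 dimensions each.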

## How It Works

The game board is represented as a 16x16 ASCII grid (excerpt shown; the full board has 16 rows):

```
................
................
........#.......
........O.......
........H......F
................
```

- `.` Empty space
- `H` Snake head
- `O` Snake body
- `#` Snake tail
- `F` Food
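
When inspecting boards, it helps to convert the grid into coordinates. The sketch below uses only the legend above, with coordinates given as `(row, column)` from the top-left corner. `parse_board` is a hypothetical helper and is not part of the released files.

```python
def parse_board(board: str) -> dict:
    """Locate head, body, tail and food cells on an ASCII board (legend from this card)."""
    cells = {"H": [], "O": [], "#": [], "F": []}
    rows = board.splitlines()
    for r, row in enumerate(rows):
        for c, ch in enumerate(row):
            if ch in cells:
                cells[ch].append((r, c))
            elif ch != ".":
                raise ValueError(f"unexpected character {ch!r} at row {r}, col {c}")
    return {
        "size": (len(rows), max((len(row) for row in rows), default=0)),
        "head": cells["H"],
        "body": cells["O"],
        "tail": cells["#"],
        "food": cells["F"],
    }


# The six-row excerpt shown above
excerpt = """\
................
................
........#.......
........O.......
........H......F
................"""
info = parse_board(excerpt)
print(info["head"], info["food"])  # [(4, 8)] [(4, 15)]
```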

The model receives a prompt like:

```
B:
[current board state]
A:R
T:
```

It then generates the next board state after executing action `R` (move right), ending with a `$` stop token.
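
Tokenization is character-level: `encode` in the usage example maps each character to one ID. That means the token budget of one step can be counted directly. The sketch below does this. `build_prompt` is a hypothetical helper, and it assumes the generated board is followed only by a single `$` stop token.

```python
ROWS, COLS = 16, 16


def build_prompt(board: str, action: str) -> str:
    """Assemble a prompt in the B:/A:/T: format shown above."""
    return f"B:\n{board}\nA:{action}\nT:\n"


# An empty 16x16 board: 16 rows of 16 characters joined by newlines
empty_board = "\n".join("." * COLS for _ in range(ROWS))
prompt = build_prompt(empty_board, "R")

prompt_tokens = len(prompt)          # character-level tokenizer: 1 character = 1 token
board_tokens = len(empty_board) + 1  # generated board plus the (assumed single) "$" stop token
print(prompt_tokens, board_tokens, prompt_tokens + board_tokens)  # 282 272 554
```

A single transition therefore uses roughly 550 tokens. That fits comfortably inside the 1024-token context window and is consistent with the `max_new_tokens=300` budget in the usage example.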

## Files

- `snake_model.pt` - PyTorch model weights
- `meta.pkl` - Vocabulary (`stoi`, `itos`, `vocab_size`) and model configuration
- `gpt.py` - Model architecture (`GPT`, `GPTConfig`), needed to instantiate the model

## Usage

```python
import os
import pickle
import sys

import torch
from huggingface_hub import hf_hub_download

REPO_ID = "mcrimi/snakeformer"

# Download the weights, the vocabulary/config metadata and the model definition
model_path = hf_hub_download(repo_id=REPO_ID, filename="snake_model.pt")
meta_path = hf_hub_download(repo_id=REPO_ID, filename="meta.pkl")
gpt_path = hf_hub_download(repo_id=REPO_ID, filename="gpt.py")

with open(meta_path, "rb") as f:
    meta = pickle.load(f)

# Extract tokenizer mappings (character-level: one character = one token)
stoi = meta["stoi"]  # string to index
itos = meta["itos"]  # index to string


def encode(s):
    """Convert string to list of token IDs."""
    return [stoi[c] for c in s]


def decode(ids):
    """Convert list of token IDs back to string."""
    return "".join(itos[i] for i in ids)


# Make the downloaded gpt.py importable. Inside a clone of the GitHub repository,
# the import `from model.gpt import GPT, GPTConfig` is presumably used instead.
sys.path.insert(0, os.path.dirname(gpt_path))
from gpt import GPT, GPTConfig  # noqa: E402

config = GPTConfig(
    vocab_size=meta["vocab_size"],
    block_size=meta.get("block_size", 1024),
    n_embd=meta.get("n_embd", 128),
    n_head=meta.get("n_head", 8),
    n_layer=meta.get("n_layer", 4),
)

model = GPT(config)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# -----------------------------------------------------------------------------
# Example: Generate next board state
# -----------------------------------------------------------------------------

board = """\
................
................
................
................
................
................
................
........#.......
........O.......
........H......F
................
................
................
................
................
................"""

action = "R"  # Move right (towards the food)

# Build prompt in the expected format
prompt = f"B:\n{board}\nA:{action}\nT:\n"

print("\n=== Input Prompt ===")
print(prompt)

# Encode prompt to token IDs
input_ids = encode(prompt)
print(f"\n=== Encoded ({len(input_ids)} tokens) ===")
print(f"First 20 tokens: {input_ids[:20]}...")

# Convert to a tensor with a batch dimension
input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
print(f"Input tensor shape: {input_tensor.shape}")

# Generate output
print("\n=== Generating... ===")
stop_token_id = stoi.get("$")  # "$" marks the end of the generated board
print(f"Stop token ID: {stop_token_id}")

with torch.no_grad():
    output_ids = model.generate(
        input_tensor,
        max_new_tokens=300,  # Board is 16 rows * 17 chars (incl. newline) = 272 chars, plus some overhead
        stop_token_id=stop_token_id,
    )

# Decode output (the prompt followed by the generated tokens)
output_text = decode(output_ids[0].tolist())

print("\n=== Full Output ===")
print(output_text)

# Extract just the generated part (after "T:\n") and drop the stop token
generated = output_text[len(prompt):].split("$")[0]
print("\n=== Generated Board State ===")
print(generated)
```
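
Playing the game amounts to feeding each predicted board back in as the next `B:` state. The loop below is a minimal, model-agnostic sketch of that idea. It is not the repository's `play.py`, and `rollout`, `step`, and `echo_step` are hypothetical names. In practice, `step` would wrap the prompt construction, the `model.generate` call, and the `$`-splitting from the example above.

```python
from typing import Callable, Iterable, List


def rollout(step: Callable[[str, str], str], board: str, actions: Iterable[str]) -> List[str]:
    """Play a sequence of actions, feeding each predicted board back in as the next state."""
    boards = [board]
    for action in actions:
        # `step` maps (board, action) -> next board. Trailing newlines are stripped
        # so the next prompt keeps the exact B:/A:/T: layout.
        board = step(board, action).rstrip("\n")
        boards.append(board)
    return boards


def echo_step(board: str, action: str) -> str:
    """Hypothetical stand-in for the model: returns the board unchanged plus a newline."""
    return board + "\n"


frames = rollout(echo_step, "....\n.H..", ["R", "R"])
print(len(frames))  # 3
```

Because errors compound over a rollout, this loop is also where the occasional invalid states mentioned under Limitations become visible.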

## Play the Game

Clone the full repository and run:

```bash
git clone https://github.com/mcrimi/snakeformer
cd snakeformer
python play.py
```

## Training

The model was trained on ~500k state transitions generated by:
1. A heuristic bot playing thousands of games
2. Manual gameplay for edge cases
3. DAgger-style corrections for model errors
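
To make the data pipeline concrete, here is one plausible way to serialize a transition for training. It mirrors the prompt format and the `$` stop token used in the usage example. `format_example`, `dagger_round`, and `true_step` are hypothetical, and the exact target layout (for example, whether a newline precedes `$`) is an assumption. `dagger_round` shows the generic DAgger idea applied to a world model: boards reached by the model are relabeled with the true next board from a reference game engine.

```python
def format_example(board: str, action: str, next_board: str) -> str:
    """Serialize one (board, action, next board) transition as a training string (assumed layout)."""
    return f"B:\n{board}\nA:{action}\nT:\n{next_board}$"


def dagger_round(dataset: list, visited: list, true_step) -> list:
    """Add model-visited (board, action) pairs, relabeled by a reference engine `true_step`."""
    for board, action in visited:
        dataset.append(format_example(board, action, true_step(board, action)))
    return dataset


# Toy 2x4 boards, for illustration only
before = "....\n.H.."
after = "....\n..H."
example = format_example(before, "R", after)
print(repr(example))  # 'B:\n....\n.H..\nA:R\nT:\n....\n..H.$'
```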

See the [GitHub repository](https://github.com/mcrimi/snakeformer) for full training details.

## Limitations

- **Deterministic food spawning**: The model was trained with deterministic food placement (based on head position). With random spawning, the next food position cannot be predicted from the current board and action, which makes the mapping non-learnable for an autoregressive model.
- **Occasional hallucinations**: The model may occasionally produce invalid states, especially with long snakes or in edge cases. See the [GitHub repository](https://github.com/mcrimi/snakeformer) for online-training artifacts aimed at further improvements.
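
Because invalid states are possible, it can be worth screening generated boards before feeding them back in. The check below is a sketch built only from the legend in this card, and `check_board` is a hypothetical helper. Its invariants ("exactly one head", "at most one tail", only the five legend characters) are assumptions about valid play. They may be too strict for special states such as game over if the full 16-character vocabulary includes extra board symbols.

```python
VALID_CHARS = set(".HO#F")


def check_board(board: str, rows: int = 16, cols: int = 16) -> list:
    """Return a list of problems found in a generated board (an empty list means it looks valid)."""
    problems = []
    lines = board.splitlines()
    if len(lines) != rows:
        problems.append(f"expected {rows} rows, got {len(lines)}")
    for r, line in enumerate(lines):
        if len(line) != cols:
            problems.append(f"row {r} has {len(line)} columns, expected {cols}")
        bad = set(line) - VALID_CHARS
        if bad:
            problems.append(f"row {r} has unexpected characters {sorted(bad)}")
    if board.count("H") != 1:
        problems.append(f"expected exactly one head, found {board.count('H')}")
    if board.count("#") > 1:
        problems.append(f"expected at most one tail, found {board.count('#')}")
    return problems


# A 16x16 board with only a head in the bottom-left corner
good = "\n".join(["." * 16] * 15 + ["H" + "." * 15])
print(check_board(good))  # []
```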

## Citation

```bibtex
@misc{snakeformer2026,
  author = {Crimi, M.},
  title = {Snakeformer: A Transformer-Based Snake Game Simulator},
  year = {2026},
  url = {https://github.com/mcrimi/snakeformer}
}
```

## License

MIT