Commit
·
0b51134
0
Parent(s):
Initial commit with Xet-managed safetensors
Browse files- .gitattributes +1 -0
- .gitignore +108 -0
- README.md +130 -0
- local_test_sudoku.py +124 -0
- model.py +125 -0
- model.safetensors +3 -0
- model_100k.safetensors +3 -0
- pyproject.toml +14 -0
- uv.lock +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================
|
| 2 |
+
# Python
|
| 3 |
+
# =========================
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
*.so
|
| 9 |
+
*.egg-info/
|
| 10 |
+
.eggs/
|
| 11 |
+
.env
|
| 12 |
+
.venv
|
| 13 |
+
venv/
|
| 14 |
+
ENV/
|
| 15 |
+
env/
|
| 16 |
+
sudoku.csv
|
| 17 |
+
|
| 18 |
+
# =========================
|
| 19 |
+
# PyTorch / ML
|
| 20 |
+
# =========================
|
| 21 |
+
*.pt
|
| 22 |
+
*.pth
|
| 23 |
+
*.ckpt
|
| 24 |
+
*.bin
|
| 25 |
+
*.onnx
|
| 26 |
+
*.trt
|
| 27 |
+
*.engine
|
| 28 |
+
|
| 29 |
+
# Allow safetensors (HF preferred)
|
| 30 |
+
!*.safetensors
|
| 31 |
+
|
| 32 |
+
# =========================
|
| 33 |
+
# Training / Runtime
|
| 34 |
+
# =========================
|
| 35 |
+
runs/
|
| 36 |
+
logs/
|
| 37 |
+
lightning_logs/
|
| 38 |
+
wandb/
|
| 39 |
+
mlruns/
|
| 40 |
+
tensorboard/
|
| 41 |
+
tb_logs/
|
| 42 |
+
*.prof
|
| 43 |
+
*.log
|
| 44 |
+
|
| 45 |
+
# =========================
|
| 46 |
+
# Datasets (do NOT upload raw datasets)
|
| 47 |
+
# =========================
|
| 48 |
+
data/
|
| 49 |
+
datasets/
|
| 50 |
+
*.csv
|
| 51 |
+
*.tsv
|
| 52 |
+
*.parquet
|
| 53 |
+
*.arrow
|
| 54 |
+
*.jsonl
|
| 55 |
+
*.hdf5
|
| 56 |
+
*.npz
|
| 57 |
+
|
| 58 |
+
# =========================
|
| 59 |
+
# Caches
|
| 60 |
+
# =========================
|
| 61 |
+
.cache/
|
| 62 |
+
huggingface/
|
| 63 |
+
hf_cache/
|
| 64 |
+
torch_cache/
|
| 65 |
+
transformers_cache/
|
| 66 |
+
|
| 67 |
+
# =========================
|
| 68 |
+
# Jupyter / Colab
|
| 69 |
+
# =========================
|
| 70 |
+
.ipynb_checkpoints/
|
| 71 |
+
*.ipynb
|
| 72 |
+
*.colab
|
| 73 |
+
|
| 74 |
+
# =========================
|
| 75 |
+
# OS / Editor
|
| 76 |
+
# =========================
|
| 77 |
+
.DS_Store
|
| 78 |
+
Thumbs.db
|
| 79 |
+
*.swp
|
| 80 |
+
*.swo
|
| 81 |
+
.idea/
|
| 82 |
+
.vscode/
|
| 83 |
+
.history/
|
| 84 |
+
|
| 85 |
+
# =========================
|
| 86 |
+
# Build / Packaging
|
| 87 |
+
# =========================
|
| 88 |
+
dist/
|
| 89 |
+
build/
|
| 90 |
+
*.tar.gz
|
| 91 |
+
*.zip
|
| 92 |
+
|
| 93 |
+
# =========================
|
| 94 |
+
# Secrets
|
| 95 |
+
# =========================
|
| 96 |
+
*.key
|
| 97 |
+
*.pem
|
| 98 |
+
*.crt
|
| 99 |
+
*.token
|
| 100 |
+
.env.*
|
| 101 |
+
|
| 102 |
+
# =========================
|
| 103 |
+
# Temporary / Scratch
|
| 104 |
+
# =========================
|
| 105 |
+
tmp/
|
| 106 |
+
temp/
|
| 107 |
+
scratch/
|
| 108 |
+
|
README.md
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
tags:
|
| 4 |
+
- sudoku
|
| 5 |
+
- reasoning
|
| 6 |
+
- pytorch
|
| 7 |
+
- rhan
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# PotatoAGI (RHAN-Sudoku)
|
| 11 |
+
|
| 12 |
+
This is the official weight repository for the **Recurrent Hybrid Attention Network (RHAN)** trained on Sudoku.
|
| 13 |
+
|
| 14 |
+
It uses a **Universal Linear Attention** mechanism combined with **Recursive Memory** and was trained using **Adversarial Erasure**.
|
| 15 |
+
|
| 16 |
+
## Stats
|
| 17 |
+
- **Parameters:** ~150k
|
| 18 |
+
- **Architecture:** 12-Loop Recurrent CNN + Linear Attention
|
| 19 |
+
- **Accuracy:** 99% Cell Accuracy / 90%+ Perfect Solve Rate
|
| 20 |
+
- **License:** CC BY-NC 4.0 (Non-Commercial Research Use Only)
|
| 21 |
+
|
| 22 |
+
## Files in this Repository
|
| 23 |
+
|
| 24 |
+
model.py # Model architecture (UniversalPotato)
|
| 25 |
+
model.safetensors # Trained weights
|
| 26 |
+
local_test_sudoku.py # Dataset-based local evaluation
|
| 27 |
+
README.md
|
| 28 |
+
|
| 29 |
+
## Usage
|
| 30 |
+
### 1️⃣ Install dependencies
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
pip install torch safetensors
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Python ≥ 3.10 recommended.
|
| 37 |
+
|
| 38 |
+
### 2️⃣ Load the model and weights
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
from safetensors.torch import load_file
|
| 42 |
+
from model import UniversalPotato, HIDDEN_DIM
|
| 43 |
+
|
| 44 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
+
|
| 46 |
+
model = UniversalPotato().to(device)
|
| 47 |
+
model.load_state_dict(load_file("model.safetensors"), strict=True)
|
| 48 |
+
model.eval()
|
| 49 |
+
|
| 50 |
+
### 3️⃣ Run inference on a single Sudoku puzzle
|
| 51 |
+
|
| 52 |
+
Sudoku grids are represented as a flat tensor of length 81,
|
| 53 |
+
with 0 indicating empty cells.
|
| 54 |
+
|
| 55 |
+
# Example puzzle (0 = empty)
|
| 56 |
+
puzzle = [
|
| 57 |
+
5,3,0,0,7,0,0,0,0,
|
| 58 |
+
6,0,0,1,9,5,0,0,0,
|
| 59 |
+
0,9,8,0,0,0,0,6,0,
|
| 60 |
+
8,0,0,0,6,0,0,0,3,
|
| 61 |
+
4,0,0,8,0,3,0,0,1,
|
| 62 |
+
7,0,0,0,2,0,0,0,6,
|
| 63 |
+
0,6,0,0,0,0,2,8,0,
|
| 64 |
+
0,0,0,4,1,9,0,0,5,
|
| 65 |
+
0,0,0,0,8,0,0,7,9,
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
clues = torch.tensor(puzzle, dtype=torch.long).unsqueeze(0).to(device)
|
| 69 |
+
board = clues.clone()
|
| 70 |
+
memory = torch.zeros(1, HIDDEN_DIM, 9, 9, device=device)
|
| 71 |
+
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
for _ in range(24): # reasoning steps
|
| 74 |
+
logits, memory = model(
|
| 75 |
+
clues=clues,
|
| 76 |
+
current_board=board,
|
| 77 |
+
memory=memory,
|
| 78 |
+
blindfold=False,
|
| 79 |
+
)
|
| 80 |
+
board = logits.argmax(dim=-1)
|
| 81 |
+
|
| 82 |
+
solution = board.view(9, 9).cpu()
|
| 83 |
+
print(solution)
|
| 84 |
+
|
| 85 |
+
### 4️⃣ Dataset-based evaluation
|
| 86 |
+
|
| 87 |
+
To evaluate the model on a real Sudoku dataset:
|
| 88 |
+
|
| 89 |
+
Download sudoku.csv from Kaggle
|
| 90 |
+
👉 https://www.kaggle.com/datasets/rohanrao/sudoku
|
| 91 |
+
|
| 92 |
+
Place it in the repository root
|
| 93 |
+
|
| 94 |
+
Run:
|
| 95 |
+
|
| 96 |
+
python local_test_sudoku.py
|
| 97 |
+
|
| 98 |
+
This script:
|
| 99 |
+
|
| 100 |
+
runs multi-step inference
|
| 101 |
+
|
| 102 |
+
compares predictions against ground truth
|
| 103 |
+
|
| 104 |
+
reports solve success rate
|
| 105 |
+
|
| 106 |
+
## Notes
|
| 107 |
+
|
| 108 |
+
This model does not use Hugging Face Transformers
|
| 109 |
+
|
| 110 |
+
model.py is the authoritative architecture definition
|
| 111 |
+
|
| 112 |
+
Inference requires multiple recurrent steps for best results
|
| 113 |
+
|
| 114 |
+
Designed for reasoning research, not commercial deployment
|
| 115 |
+
|
| 116 |
+
## License
|
| 117 |
+
|
| 118 |
+
This project is released under CC BY-NC 4.0.
|
| 119 |
+
|
| 120 |
+
You may:
|
| 121 |
+
|
| 122 |
+
use
|
| 123 |
+
|
| 124 |
+
modify
|
| 125 |
+
|
| 126 |
+
redistribute
|
| 127 |
+
for non-commercial research purposes only, with attribution.
|
| 128 |
+
|
| 129 |
+
Commercial use is not permitted.
|
| 130 |
+
|
local_test_sudoku.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import torch
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
|
| 6 |
+
from model import UniversalPotato, HIDDEN_DIM
|
| 7 |
+
|
| 8 |
+
# Path to the Kaggle sudoku dataset CSV (see README for the download link).
CSV_PATH = "sudoku.csv"
# Trained model weights in safetensors format.
WEIGHTS_PATH = "model.safetensors"
# Prefer GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Number of recurrent reasoning passes per puzzle.
STEPS = 24
# Only the first MAX_PUZZLES rows of the CSV are evaluated.
MAX_PUZZLES = 50
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def require_sudoku_csv(path: str):
    """Ensure the Kaggle sudoku CSV exists; raise with download steps if not.

    Raises:
        FileNotFoundError: when *path* does not exist on disk.
    """
    if os.path.exists(path):
        return
    raise FileNotFoundError(
        """
sudoku.csv not found.

Please download it manually from:
https://www.kaggle.com/datasets/rohanrao/sudoku

Then place sudoku.csv in the project root.
"""
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_sudoku_csv(path: str, limit: int):
    """Read up to *limit* (puzzle, solution) string pairs from a Kaggle CSV.

    Header names are matched case-insensitively against the known variants
    ("puzzle"/"quiz"/"quizzes" and "solution"/"solutions").

    Raises:
        RuntimeError: when the file has no header row or uses unknown headers.
    """
    with open(path, newline="") as handle:
        reader = csv.DictReader(handle)

        if not reader.fieldnames:
            raise RuntimeError("sudoku.csv has no header row")

        # Map lowercased header -> original header, keeping the first occurrence.
        canonical = {}
        for header in reader.fieldnames:
            canonical.setdefault(header.lower(), header)

        def first_match(candidates):
            # Original-cased header for the first recognized candidate, or None.
            for candidate in candidates:
                if candidate in canonical:
                    return canonical[candidate]
            return None

        puzzle_key = first_match(("puzzle", "quiz", "quizzes"))
        solution_key = first_match(("solution", "solutions"))

        if puzzle_key is None or solution_key is None:
            raise RuntimeError(
                f"Unsupported sudoku.csv format. Headers found: {reader.fieldnames}"
            )

        puzzles, solutions = [], []
        for row_index, row in enumerate(reader):
            if row_index >= limit:
                break
            puzzles.append(row[puzzle_key])
            solutions.append(row[solution_key])

    return puzzles, solutions
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def str_to_tensor(grid_str: str) -> torch.Tensor:
    """Convert an 81-character grid string to a flat long tensor.

    Digits map to themselves. A ``'.'`` character — used by several public
    sudoku datasets to mark empty cells — maps to 0, matching the model's
    empty-cell encoding. Digit-only strings behave exactly as before.
    """
    return torch.tensor(
        [0 if c == "." else int(c) for c in grid_str], dtype=torch.long
    )
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def tensor_to_str(t: torch.Tensor) -> str:
    """Render a flat digit tensor as a contiguous string (e.g. for comparison)."""
    digits = [str(int(value)) for value in t]
    return "".join(digits)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def run_inference(model, clues: torch.Tensor, steps: int):
    """Iteratively refine a board prediction for *steps* recurrent passes.

    *clues* is a flat length-81 long tensor (0 = empty). Each pass feeds the
    previous argmax board and the recurrent memory back into the model.
    Returns the final predicted board as a flat length-81 CPU tensor.
    """
    batched_clues = clues.unsqueeze(0).to(DEVICE)
    board = batched_clues.clone()
    # Recurrent state starts at zero for a fresh puzzle.
    memory = torch.zeros(1, HIDDEN_DIM, 9, 9, device=DEVICE)

    with torch.no_grad():
        for _ in range(steps):
            logits, memory = model(
                clues=batched_clues,
                current_board=board,
                memory=memory,
                blindfold=False,
            )
            board = logits.argmax(dim=-1)

    return board.squeeze(0).cpu()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def main():
    """Evaluate the model on the first MAX_PUZZLES entries of sudoku.csv."""
    require_sudoku_csv(CSV_PATH)
    puzzles, solutions = load_sudoku_csv(CSV_PATH, MAX_PUZZLES)

    # strict=True guarantees the weights match the architecture exactly.
    model = UniversalPotato().to(DEVICE)
    model.load_state_dict(load_file(WEIGHTS_PATH), strict=True)
    model.eval()

    solved = 0
    for index, (quiz, target) in enumerate(zip(puzzles, solutions), 1):
        prediction = run_inference(model, str_to_tensor(quiz), STEPS)
        success = tensor_to_str(prediction) == target
        solved += int(success)

        print(f"\nPuzzle {index}")
        print("Solved:", success)
        print(prediction.view(9, 9))

    print("\n==============================")
    print(f"Solved {solved}/{len(puzzles)} puzzles")
    print("==============================")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# Script entry point: run the dataset-based evaluation.
if __name__ == "__main__":
    main()
|
| 124 |
+
|
model.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.init as init
|
| 4 |
+
|
| 5 |
+
# --- CONFIGURATION ---
INPUT_CELLS = 81  # 9x9 Sudoku board, flattened
NUM_CLASSES = 10  # digits 0-9, where 0 encodes an empty cell
HIDDEN_DIM = 128  # channel width of all internal feature maps
ATTN_HEADS = 4 # MUST match training script
|
| 10 |
+
|
| 11 |
+
class StandardAttention2D(nn.Module):
    """
    Standard O(N^2) Multi-Head Attention for 2D grids.
    Zero-initialized output projection to start as identity.

    Operates on (b, dim, h, w) feature maps; the residual in forward()
    plus the zero-initialized output projection makes the module an exact
    identity at initialization.
    """
    def __init__(self, dim, heads=ATTN_HEADS):
        super().__init__()
        # NOTE(review): scales by the full `dim`, not `head_dim`, unlike
        # textbook multi-head attention. The published weights were trained
        # with this scaling, so it must not be "corrected" without retraining.
        self.scale = dim ** -0.5
        self.heads = heads
        self.head_dim = dim // heads  # assumes dim is divisible by heads

        # 1x1 conv producing concatenated Q, K, V (no bias).
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GroupNorm(8, dim)
        )

        # Zero-init so attention starts as a no-op
        init.zeros_(self.to_out[0].weight)
        init.zeros_(self.to_out[0].bias)

    def forward(self, x):
        b, c, h, w = x.shape
        n = h * w  # number of spatial tokens

        # (b, 3c, h, w) -> (b, 3c, n), then split into Q, K, V of (b, c, n).
        qkv = self.to_qkv(x).view(b, 3 * c, n)
        q, k, v = qkv.chunk(3, dim=1)

        # Reshape each to (b, heads, n, head_dim) for per-head attention.
        q = q.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        k = k.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        v = v.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)

        # Scaled dot-product attention over all n tokens (the O(N^2) part).
        dots = (q @ k.transpose(-2, -1)) * self.scale
        attn = dots.softmax(dim=-1)

        # NOTE(review): (b, heads, n, head_dim) -> transpose(1, 2) yields
        # (b, n, heads, head_dim); reshaping that straight to (b, c, h, w)
        # mixes the token and head axes instead of concatenating heads along
        # the channel dimension. The trained weights depend on this exact
        # layout — do not change it without retraining.
        out = (attn @ v).transpose(1, 2).reshape(b, c, h, w)
        return self.to_out(out) + x
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class UniversalPotato(nn.Module):
    """
    EXACT match to the Colab-trained HybridPotato architecture.
    No positional embeddings. Blindfold-compatible.

    One forward() call is a single reasoning step: it consumes the fixed
    clues, the current board guess, and a recurrent memory map, and returns
    per-cell logits plus the updated memory. Callers iterate forward()
    multiple times (see README / local_test_sudoku.py).

    NOTE(review): attribute names and nn.Sequential ordering determine the
    state_dict keys of the shipped .safetensors weights — do not rename or
    restructure modules here.
    """
    def __init__(self):
        super().__init__()

        # Separate embeddings for the immutable clues and the evolving board.
        self.embed_clues = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        self.embed_board = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)

        # Fuses [clues, board, memory] (3 * HIDDEN_DIM channels) to HIDDEN_DIM.
        self.input_proj = nn.Sequential(
            nn.Conv2d(HIDDEN_DIM * 3, HIDDEN_DIM, kernel_size=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )

        # One reasoning step: local conv -> global attention -> dilated convs.
        self.core = nn.Sequential(
            # Local
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),

            # Global
            StandardAttention2D(HIDDEN_DIM),
            nn.SiLU(),

            # Mid-range (dilation 2 widens the receptive field)
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=2, dilation=2),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),

            # Processing (dilation 4 spans most of the 9x9 grid)
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=4, dilation=4),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )

        # Per-cell digit classifier (1x1 conv -> NUM_CLASSES channels).
        self.head = nn.Conv2d(HIDDEN_DIM, NUM_CLASSES, kernel_size=1)
        # Normalizes the residual memory update in forward().
        self.memory_norm = nn.GroupNorm(8, HIDDEN_DIM)

    def run_core(self, x):
        # Thin alias over self.core; forward() calls self.core directly.
        return self.core(x)

    def forward(self, clues, current_board, memory, blindfold=False):
        """One reasoning step.

        Args:
            clues: (b, 81) long tensor of given digits, 0 = empty.
            current_board: (b, 81) long tensor, the current solution guess.
            memory: (b, HIDDEN_DIM, 9, 9) recurrent state (zeros on step 1).
            blindfold: when True, the board embedding is zeroed so the model
                must rely on clues and memory alone (adversarial erasure).

        Returns:
            Tuple (logits, new_memory): logits is (b, 81, NUM_CLASSES);
            new_memory has the same shape as `memory`.
        """
        b, n = clues.shape  # n is expected to be 81 (9x9 grid)

        # (b, 81) -> (b, 81, HIDDEN_DIM) -> (b, HIDDEN_DIM, 9, 9)
        clues_emb = (
            self.embed_clues(clues)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )

        board_emb = (
            self.embed_board(current_board)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )

        if blindfold:
            # Erase the board signal; clues and memory must carry the state.
            board_emb = torch.zeros_like(board_emb)

        raw = torch.cat([clues_emb, board_emb, memory], dim=1)
        z = self.input_proj(raw)
        z = self.core(z)

        # Residual memory update, normalized to keep the recurrence stable.
        new_memory = self.memory_norm(memory + z)

        # (b, NUM_CLASSES, 9, 9) -> (b, 81, NUM_CLASSES)
        logits = (
            self.head(z)
            .view(b, NUM_CLASSES, 81)
            .transpose(1, 2)
        )

        return logits, new_memory
|
| 125 |
+
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b0fe7f3855cf4e18b9a13d0cb83a3c741f81be5e504825d731f5c1cfd44eca7
|
| 3 |
+
size 2254560
|
model_100k.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:052d53a28bc9141494e6e1f3a467be4a659ce61c203e1001dfee56179a9aa93a
|
| 3 |
+
size 2254560
|
pyproject.toml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "potato-agi"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "Official weights for PotatoAGI (RHAN)"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"torch",
|
| 9 |
+
"numpy",
|
| 10 |
+
"safetensors",
|
| 11 |
+
"requests",
|
| 12 |
+
"pandas",
|
| 13 |
+
"packaging"
|
| 14 |
+
]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|