# PotatoAGI / model.py
import torch
import torch.nn as nn
import torch.nn.init as init

# --- CONFIGURATION ---
INPUT_CELLS = 81   # 9x9 Sudoku board, flattened
NUM_CLASSES = 10   # digit classes; 0 presumably marks an empty cell
HIDDEN_DIM = 128   # must be divisible by ATTN_HEADS and by the GroupNorm groups (8)
ATTN_HEADS = 4     # MUST match training script


class StandardAttention2D(nn.Module):
    """
    Standard O(N^2) multi-head attention for 2D grids.
    The output projection is zero-initialized so the block starts as an
    identity mapping (only the residual path is active at init).
    """

    def __init__(self, dim, heads=ATTN_HEADS):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        # NOTE: scales by the full embedding dim rather than the usual
        # per-head dim; kept as-is to match the training script.
        self.scale = dim ** -0.5
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GroupNorm(8, dim)
        )
        # Zero-init so attention starts as a no-op
        init.zeros_(self.to_out[0].weight)
        init.zeros_(self.to_out[0].bias)

    def forward(self, x):
        b, c, h, w = x.shape
        n = h * w
        qkv = self.to_qkv(x).view(b, 3 * c, n)
        q, k, v = qkv.chunk(3, dim=1)
        # Each of q, k, v: (b, heads, n, head_dim)
        q = q.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        k = k.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        v = v.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        dots = (q @ k.transpose(-2, -1)) * self.scale  # (b, heads, n, n)
        attn = dots.softmax(dim=-1)
        # Fold heads back into channels:
        # (b, heads, n, head_dim) -> (b, heads, head_dim, n) -> (b, c, h, w)
        out = (attn @ v).transpose(-2, -1).reshape(b, c, h, w)
        return self.to_out(out) + x
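
# A quick shape/identity check for the attention block (an editorial
# sketch, not part of the original file). Because the output projection
# is zero-initialized, a freshly constructed block returns its input
# exactly:
#
#   blk = StandardAttention2D(HIDDEN_DIM)
#   x = torch.randn(2, HIDDEN_DIM, 9, 9)
#   y = blk(x)
#   assert y.shape == x.shape
#   assert torch.equal(y, x)  # only the residual path is active at init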


class UniversalPotato(nn.Module):
    """
    EXACT match to the Colab-trained HybridPotato architecture.
    No positional embeddings. Blindfold-compatible: the current-board
    embedding can be zeroed out at inference time.
    """

    def __init__(self):
        super().__init__()
        self.embed_clues = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        self.embed_board = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        # Fuse clue, board, and memory features (3 * HIDDEN_DIM channels)
        # down to HIDDEN_DIM.
        self.input_proj = nn.Sequential(
            nn.Conv2d(HIDDEN_DIM * 3, HIDDEN_DIM, kernel_size=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )
        self.core = nn.Sequential(
            # Local: plain 3x3 conv
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),
            # Global: full-board attention
            StandardAttention2D(HIDDEN_DIM),
            nn.SiLU(),
            # Mid-range: 3x3 conv, dilation 2
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=2, dilation=2),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),
            # Processing: 3x3 conv, dilation 4
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=4, dilation=4),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )
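        # Receptive-field note (editorial, not from the training code):
        # the plain 3x3 conv reaches +-1 cell, the dilation-2 conv +-2,
        # and the dilation-4 conv +-4, so the conv stack alone spans
        # +-7 cells; the attention block sees the whole 9x9 board in
        # one hop.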
        self.head = nn.Conv2d(HIDDEN_DIM, NUM_CLASSES, kernel_size=1)
        self.memory_norm = nn.GroupNorm(8, HIDDEN_DIM)

    def run_core(self, x):
        return self.core(x)

    def forward(self, clues, current_board, memory, blindfold=False):
        b, n = clues.shape  # n == INPUT_CELLS == 81
        # (b, 81) token ids -> (b, HIDDEN_DIM, 9, 9) feature maps
        clues_emb = (
            self.embed_clues(clues)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )
        board_emb = (
            self.embed_board(current_board)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )
        if blindfold:
            # Hide the current board: the model must rely on clues and
            # its recurrent memory alone.
            board_emb = torch.zeros_like(board_emb)
        raw = torch.cat([clues_emb, board_emb, memory], dim=1)
        z = self.input_proj(raw)
        z = self.core(z)
        # Residual memory update, normalized to keep the recurrence stable.
        new_memory = self.memory_norm(memory + z)
        logits = (
            self.head(z)
            .view(b, NUM_CLASSES, 81)
            .transpose(1, 2)
        )
        return logits, new_memory
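

# Minimal usage sketch (an editorial example, not part of the original
# file): one forward step on random inputs, with the recurrent memory
# initialized to zeros and threaded through the call.
if __name__ == "__main__":
    model = UniversalPotato().eval()
    clues = torch.randint(0, NUM_CLASSES, (1, INPUT_CELLS))
    board = torch.randint(0, NUM_CLASSES, (1, INPUT_CELLS))
    memory = torch.zeros(1, HIDDEN_DIM, 9, 9)
    with torch.no_grad():
        logits, memory = model(clues, board, memory)
    print(logits.shape)   # torch.Size([1, 81, 10])
    print(memory.shape)   # torch.Size([1, 128, 9, 9])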