import torch
import torch.nn as nn
import torch.nn.init as init

# --- CONFIGURATION ---
INPUT_CELLS = 81   # 9x9 board flattened
NUM_CLASSES = 10   # digits 0-9 (0 presumably = empty cell -- confirm with training data)
HIDDEN_DIM = 128
ATTN_HEADS = 4  # MUST match training script


class StandardAttention2D(nn.Module):
    """
    Standard O(N^2) Multi-Head Attention for 2D grids.
    Zero-initialized output projection to start as identity.

    The module is residual: ``forward`` returns ``attention(x) + x``.
    Because the output Conv2d is zero-initialized (and GroupNorm's beta
    starts at 0), the whole block is an exact identity map at init.
    """

    def __init__(self, dim, heads=ATTN_HEADS):
        super().__init__()
        # NOTE(review): conventional MHA scales by head_dim ** -0.5, not
        # dim ** -0.5. Kept verbatim for bit-compatibility with the
        # trained checkpoint -- confirm against the Colab training script.
        self.scale = dim ** -0.5
        self.heads = heads
        self.head_dim = dim // heads  # assumes dim % heads == 0
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GroupNorm(8, dim)
        )
        # Zero-init so attention starts as a no-op
        init.zeros_(self.to_out[0].weight)
        init.zeros_(self.to_out[0].bias)

    def forward(self, x):
        # x: (b, c, h, w) feature map; attention runs over the n = h*w cells.
        b, c, h, w = x.shape
        n = h * w
        qkv = self.to_qkv(x).view(b, 3 * c, n)
        q, k, v = qkv.chunk(3, dim=1)  # each (b, c, n)
        # -> (b, heads, n, head_dim)
        q = q.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        k = k.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        v = v.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        dots = (q @ k.transpose(-2, -1)) * self.scale  # (b, heads, n, n)
        attn = dots.softmax(dim=-1)
        # NOTE(review): (attn @ v) is (b, heads, n, head_dim); transpose(1, 2)
        # yields (b, n, heads, head_dim), and reshaping THAT straight to
        # (b, c, h, w) interleaves positions with channels. The conventional
        # inverse of the q/k/v layout would be
        # .permute(0, 1, 3, 2).reshape(b, c, h, w). Kept verbatim because the
        # checkpoint was trained with this exact mapping ("EXACT match" per
        # the class docstring below) -- verify against the training script
        # before "fixing".
        out = (attn @ v).transpose(1, 2).reshape(b, c, h, w)
        return self.to_out(out) + x


class UniversalPotato(nn.Module):
    """
    EXACT match to the Colab-trained HybridPotato architecture.
    No positional embeddings. Blindfold-compatible.

    Recurrent solver cell: each call consumes the (fixed) clues, the
    current board state and a spatial memory tensor, and produces
    per-cell digit logits plus an updated memory.
    """

    def __init__(self):
        super().__init__()
        # Separate embeddings for the immutable clues and the evolving board.
        self.embed_clues = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        self.embed_board = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        # Fuse [clues | board | memory] (3 * HIDDEN_DIM channels) down to HIDDEN_DIM.
        self.input_proj = nn.Sequential(
            nn.Conv2d(HIDDEN_DIM * 3, HIDDEN_DIM, kernel_size=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )
        self.core = nn.Sequential(
            # Local
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),
            # Global
            StandardAttention2D(HIDDEN_DIM),
            nn.SiLU(),
            # Mid-range
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=2, dilation=2),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),
            # Processing
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=4, dilation=4),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )
        self.head = nn.Conv2d(HIDDEN_DIM, NUM_CLASSES, kernel_size=1)
        self.memory_norm = nn.GroupNorm(8, HIDDEN_DIM)

    def run_core(self, x):
        """Apply the conv/attention core to an already-projected feature map."""
        return self.core(x)

    def forward(self, clues, current_board, memory, blindfold=False):
        """
        One reasoning step.

        Args:
            clues: (b, 81) long tensor of digits in [0, NUM_CLASSES).
            current_board: (b, 81) long tensor, same encoding as ``clues``.
            memory: (b, HIDDEN_DIM, 9, 9) recurrent state from the previous
                step (zeros on the first step -- TODO confirm with caller).
            blindfold: if True, the board embedding is zeroed so the model
                must rely on clues + memory alone.

        Returns:
            (logits, new_memory) where logits is (b, 81, NUM_CLASSES) and
            new_memory is (b, HIDDEN_DIM, 9, 9).
        """
        b, n = clues.shape
        # (b, 81) -> (b, 81, HIDDEN_DIM) -> (b, HIDDEN_DIM, 9, 9)
        clues_emb = (
            self.embed_clues(clues)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )
        board_emb = (
            self.embed_board(current_board)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )
        if blindfold:
            board_emb = torch.zeros_like(board_emb)
        raw = torch.cat([clues_emb, board_emb, memory], dim=1)
        z = self.input_proj(raw)
        z = self.core(z)
        # Residual memory update, normalized to keep the state bounded
        # across recurrent steps.
        new_memory = self.memory_norm(memory + z)
        # (b, NUM_CLASSES, 9, 9) -> (b, 81, NUM_CLASSES)
        logits = (
            self.head(z)
            .view(b, NUM_CLASSES, 81)
            .transpose(1, 2)
        )
        return logits, new_memory