import torch
import torch.nn as nn
import torch.nn.init as init


INPUT_CELLS = 81   # 9 x 9 board flattened to 81 cells
NUM_CLASSES = 10   # per-cell output classes
HIDDEN_DIM = 128
ATTN_HEADS = 4


class StandardAttention2D(nn.Module):
    """
    Standard O(N^2) multi-head attention for 2D grids.
    Zero-initialized output projection so the block starts as the identity.
    """
    def __init__(self, dim, heads=ATTN_HEADS):
        super().__init__()
        self.scale = dim ** -0.5  # note: scales by the full embedding dim, not per-head dim
        self.heads = heads
        self.head_dim = dim // heads

        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GroupNorm(8, dim)
        )

        # Zero-init the output projection so the attention branch contributes
        # nothing at the start of training and the residual path dominates.
        init.zeros_(self.to_out[0].weight)
        init.zeros_(self.to_out[0].bias)

    def forward(self, x):
        b, c, h, w = x.shape
        n = h * w

        # Project to q/k/v with a 1x1 conv, then flatten the spatial grid.
        qkv = self.to_qkv(x).view(b, 3 * c, n)
        q, k, v = qkv.chunk(3, dim=1)

        # (b, c, n) -> (b, heads, n, head_dim)
        q = q.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        k = k.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        v = v.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)

        dots = (q @ k.transpose(-2, -1)) * self.scale
        attn = dots.softmax(dim=-1)

        # (b, heads, n, head_dim) -> (b, heads, head_dim, n), then fold heads back
        # into the channel dim. Permuting the last two dims (rather than swapping
        # heads and positions) keeps channels aligned with the to_qkv ordering.
        out = (attn @ v).permute(0, 1, 3, 2).reshape(b, c, h, w)
        return self.to_out(out) + x
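
# Usage sketch for StandardAttention2D (hypothetical values; assumes a
# HIDDEN_DIM-channel 9x9 feature map, matching how UniversalPotato uses it below):
#     attn = StandardAttention2D(HIDDEN_DIM)
#     out = attn(torch.randn(2, HIDDEN_DIM, 9, 9))  # -> (2, HIDDEN_DIM, 9, 9), residual included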


class UniversalPotato(nn.Module):
    """
    EXACT match to the Colab-trained HybridPotato architecture.
    No positional embeddings. Blindfold-compatible.
    """
    def __init__(self):
        super().__init__()

        self.embed_clues = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        self.embed_board = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)

        # Fuse clue, board, and memory embeddings (3 * HIDDEN_DIM channels) into one map.
        self.input_proj = nn.Sequential(
            nn.Conv2d(HIDDEN_DIM * 3, HIDDEN_DIM, kernel_size=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )

        self.core = nn.Sequential(
            # Local 3x3 convolution.
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),

            # Global mixing via full self-attention.
            StandardAttention2D(HIDDEN_DIM),
            nn.SiLU(),

            # Dilated convolutions widen the receptive field without pooling.
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=2, dilation=2),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),

            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=4, dilation=4),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )

        self.head = nn.Conv2d(HIDDEN_DIM, NUM_CLASSES, kernel_size=1)
        self.memory_norm = nn.GroupNorm(8, HIDDEN_DIM)

    def run_core(self, x):
        return self.core(x)

    def forward(self, clues, current_board, memory, blindfold=False):
        b, n = clues.shape  # n == INPUT_CELLS (81)

        # Embed per-cell tokens, then reshape (b, n, HIDDEN_DIM) -> (b, HIDDEN_DIM, 9, 9).
        clues_emb = (
            self.embed_clues(clues)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )

        board_emb = (
            self.embed_board(current_board)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )

        # In blindfold mode the model never sees its own board state.
        if blindfold:
            board_emb = torch.zeros_like(board_emb)

        raw = torch.cat([clues_emb, board_emb, memory], dim=1)
        z = self.input_proj(raw)
        z = self.core(z)

        # Residual memory update (GroupNorm over memory + core output).
        new_memory = self.memory_norm(memory + z)

        # Per-cell logits: (b, NUM_CLASSES, 9, 9) -> (b, INPUT_CELLS, NUM_CLASSES).
        logits = (
            self.head(z)
            .view(b, NUM_CLASSES, INPUT_CELLS)
            .transpose(1, 2)
        )

        return logits, new_memory
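

# Minimal smoke test (a sketch for illustration, not part of the trained pipeline).
# The batch size and the zero-initialized memory below are assumptions; only the
# tensor shapes follow from the model definition above.
if __name__ == "__main__":
    model = UniversalPotato()
    batch = 2

    clues = torch.randint(0, NUM_CLASSES, (batch, INPUT_CELLS))
    board = torch.randint(0, NUM_CLASSES, (batch, INPUT_CELLS))
    memory = torch.zeros(batch, HIDDEN_DIM, 9, 9)

    logits, memory = model(clues, board, memory)
    assert logits.shape == (batch, INPUT_CELLS, NUM_CLASSES)
    assert memory.shape == (batch, HIDDEN_DIM, 9, 9)

    # Blindfold mode zeroes the board embedding, so only clues and memory are used.
    blind_logits, _ = model(clues, board, memory, blindfold=True)
    assert blind_logits.shape == (batch, INPUT_CELLS, NUM_CLASSES)
    print("smoke test passed")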