import torch
import torch.nn as nn
import torch.nn.init as init

# --- CONFIGURATION ---
INPUT_CELLS = 81
NUM_CLASSES = 10
HIDDEN_DIM = 128
ATTN_HEADS = 4   # MUST match training script

class StandardAttention2D(nn.Module):
    """
    Standard O(N^2) Multi-Head Attention for 2D grids.
    Zero-initialized output projection to start as identity.
    """
    def __init__(self, dim, heads=ATTN_HEADS):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        # Scale by the per-head dimension, as in standard multi-head attention.
        self.scale = self.head_dim ** -0.5

        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GroupNorm(8, dim)
        )

        # Zero-init so attention starts as a no-op
        init.zeros_(self.to_out[0].weight)
        init.zeros_(self.to_out[0].bias)

    def forward(self, x):
        b, c, h, w = x.shape
        n = h * w

        # Project to Q, K, V with one 1x1 conv, then flatten the spatial grid.
        qkv = self.to_qkv(x).view(b, 3 * c, n)
        q, k, v = qkv.chunk(3, dim=1)

        # Split channels into heads: (b, heads, n, head_dim)
        q = q.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        k = k.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)
        v = v.view(b, self.heads, self.head_dim, n).permute(0, 1, 3, 2)

        # Full O(n^2) attention over all grid cells.
        dots = (q @ k.transpose(-2, -1)) * self.scale
        attn = dots.softmax(dim=-1)

        # (b, heads, n, head_dim) -> (b, heads, head_dim, n) -> (b, c, h, w)
        out = (attn @ v).transpose(-2, -1).reshape(b, c, h, w)
        return self.to_out(out) + x
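

def _attention_identity_check():
    """
    Minimal sanity-check sketch (illustrative; not part of the trained
    architecture): because `to_out` is zero-initialized, a freshly
    constructed StandardAttention2D should act as the identity map.
    The batch size and 9x9 grid below are assumptions for demonstration.
    """
    attn = StandardAttention2D(HIDDEN_DIM)
    x = torch.randn(2, HIDDEN_DIM, 9, 9)
    assert torch.allclose(attn(x), x, atol=1e-6)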


class UniversalPotato(nn.Module):
    """
    EXACT match to the Colab-trained HybridPotato architecture.
    No positional embeddings. Blindfold-compatible.
    """
    def __init__(self):
        super().__init__()

        self.embed_clues = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)
        self.embed_board = nn.Embedding(NUM_CLASSES, HIDDEN_DIM)

        self.input_proj = nn.Sequential(
            nn.Conv2d(HIDDEN_DIM * 3, HIDDEN_DIM, kernel_size=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )

        self.core = nn.Sequential(
            # Local
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=1),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),

            # Global
            StandardAttention2D(HIDDEN_DIM),
            nn.SiLU(),

            # Mid-range
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=2, dilation=2),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU(),

            # Long-range (dilation 4 spans the grid's full 9-cell extent)
            nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, padding=4, dilation=4),
            nn.GroupNorm(8, HIDDEN_DIM),
            nn.SiLU()
        )

        self.head = nn.Conv2d(HIDDEN_DIM, NUM_CLASSES, kernel_size=1)
        self.memory_norm = nn.GroupNorm(8, HIDDEN_DIM)

    def run_core(self, x):
        # Convenience wrapper around the shared conv/attention core.
        return self.core(x)

    def forward(self, clues, current_board, memory, blindfold=False):
        b, n = clues.shape
        assert n == INPUT_CELLS, "expected flattened 9x9 boards (81 cells)"

        clues_emb = (
            self.embed_clues(clues)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )

        board_emb = (
            self.embed_board(current_board)
            .transpose(1, 2)
            .view(b, HIDDEN_DIM, 9, 9)
        )

        if blindfold:
            board_emb = torch.zeros_like(board_emb)

        raw = torch.cat([clues_emb, board_emb, memory], dim=1)
        z = self.input_proj(raw)
        z = self.core(z)

        new_memory = self.memory_norm(memory + z)

        logits = (
            self.head(z)
            .view(b, NUM_CLASSES, INPUT_CELLS)
            .transpose(1, 2)
        )

        return logits, new_memory
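

if __name__ == "__main__":
    # Illustrative smoke test only; the batch size, random clue tensor, and
    # three-step refinement loop below are assumptions for demonstration,
    # not values taken from the training script.
    _attention_identity_check()

    model = UniversalPotato()
    clues = torch.randint(0, NUM_CLASSES, (2, INPUT_CELLS))
    board = torch.zeros(2, INPUT_CELLS, dtype=torch.long)
    memory = torch.zeros(2, HIDDEN_DIM, 9, 9)

    # Iterative refinement: feed the argmax board and memory back in.
    for _ in range(3):
        logits, memory = model(clues, board, memory)
        board = logits.argmax(dim=-1)

    # Blindfold mode: the board embedding is zeroed inside forward(), so
    # the model must lean on the recurrent memory alone.
    logits, memory = model(clues, board, memory, blindfold=True)

    print(logits.shape)   # torch.Size([2, 81, 10])
    print(memory.shape)   # torch.Size([2, 128, 9, 9])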