File size: 5,283 Bytes
4debe27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Save pre-trained decoder weights for the HuggingFace zero-latency demo.

Run this locally (RTX 5080) to generate decoder.pth, then upload
decoder.pth to HF Spaces alongside app.py.

IMPORTANT: the architecture must EXACTLY match LightweightBrainDecoder
in app.py, or the saved state_dict will not load there.
"""
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---------- Decoder (EXACT copy from app.py) ----------

class LightweightBrainDecoder(nn.Module):
    """Small convolutional VAE whose decoder averages several noisy passes.

    The layer (attribute) names and shapes must stay identical to the copy
    in app.py so the exported state_dict loads there without remapping.
    """

    def __init__(self, latent_dim=16, num_steps=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_steps = num_steps

        # --- Encoder (only needed while training) ---
        self.enc_conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)   # 28 -> 14
        self.enc_bn1 = nn.BatchNorm2d(16)
        self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # 14 -> 7
        self.enc_bn2 = nn.BatchNorm2d(32)
        self.enc_fc = nn.Linear(32 * 7 * 7, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Bias the initial posterior variance small: exp(-5) ~ 0.007.
        nn.init.constant_(self.fc_logvar.bias, -5.0)

        # --- Decoder (the part shipped to the demo) ---
        self.dec_fc1 = nn.Linear(latent_dim, 128)
        self.dec_fc2 = nn.Linear(128, 32 * 7 * 7)
        self.dec_deconv1 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)  # 7 -> 14
        self.dec_bn1 = nn.BatchNorm2d(16)
        self.dec_deconv2 = nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1)   # 14 -> 28

    def encode(self, x):
        """Map a (B, 1, 28, 28) batch to the posterior parameters (mu, logvar)."""
        feat = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
        feat = F.leaky_relu(self.enc_bn2(self.enc_conv2(feat)), 0.1)
        flat = feat.view(feat.size(0), -1)
        hidden = F.leaky_relu(self.enc_fc(flat), 0.1)
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def decode(self, z):
        """Decode z over `num_steps` jittered passes and average (SNN-like)."""
        accum = torch.zeros(z.size(0), 1, 28, 28, device=z.device)
        for step in range(self.num_steps):
            # Noise amplitude decays linearly toward zero across the steps.
            jitter = torch.randn_like(z) * 0.05 * (1 - step / self.num_steps)
            h = F.leaky_relu(self.dec_fc1(z + jitter), 0.1)
            h = F.leaky_relu(self.dec_fc2(h), 0.1)
            h = h.view(-1, 32, 7, 7)
            h = F.leaky_relu(self.dec_bn1(self.dec_deconv1(h)), 0.1)
            accum = accum + self.dec_deconv2(h)
        return torch.sigmoid(accum / self.num_steps)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Full VAE pass; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar


def main():
    """Train the VAE on Fashion-MNIST, save decoder.pth, and sanity-check it.

    Side effects: downloads Fashion-MNIST into ./data on first run and
    writes decoder.pth (a CPU state_dict) to the working directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    model = LightweightBrainDecoder(latent_dim=16, num_steps=4).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Fashion-MNIST with pixels in [0, 1] — required by the BCE loss below.
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)

    epochs = 15
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    # Cosine-anneal the LR over the whole run; T_max is tied to `epochs`
    # instead of repeating the constant so they cannot drift apart.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    print(f"\nTraining for {epochs} epochs...")
    t0 = time.time()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # KL warm-up: ramp beta from 0 to 1 over the first 3 epochs.
        beta_kl = min(1.0, epoch / 3.0)

        for data, _ in loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            # Sum-reduced ELBO terms; normalized per-sample for reporting below.
            bce = F.binary_cross_entropy(recon, data, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kld = torch.clamp(kld, min=0)  # guard against small negative drift
            loss = bce + beta_kl * kld
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        avg = total_loss / len(train_ds)
        lr = optimizer.param_groups[0]['lr']
        print(f"  Epoch {epoch+1:2d}/{epochs} | Loss={avg:.1f} | LR={lr:.5f}")

    elapsed = time.time() - t0
    print(f"\nTraining done in {elapsed:.1f}s")

    # Save weights on CPU so the file loads on any host (e.g. HF Spaces CPU).
    model.eval()
    model.cpu()
    save_path = "decoder.pth"
    torch.save(model.state_dict(), save_path)

    size_kb = os.path.getsize(save_path) / 1024
    print(f"Saved: {save_path} ({size_kb:.0f} KB)")

    # Quick quality check: decode a few random latents and report value ranges.
    print("\nQuality check: decoding 5 random latent vectors...")
    with torch.no_grad():
        for i in range(5):
            z = torch.randn(1, 16)
            img = model.decode(z).squeeze().numpy()
            print(f"  z[{i}]: min={img.min():.3f} max={img.max():.3f} mean={img.mean():.3f}")

    print("\nDone! Upload decoder.pth to HuggingFace Spaces.")


# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()