"""
Save pre-trained decoder weights for HuggingFace zero-latency demo.
Run this locally (RTX 5080) to generate decoder.pth.
Upload decoder.pth to HF Spaces alongside app.py.
IMPORTANT: Architecture must EXACTLY match LightweightBrainDecoder in app.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
# ---------- Decoder (EXACT copy from app.py) ----------
class LightweightBrainDecoder(nn.Module):
    """Tiny convolutional VAE whose decoder doubles as the demo's visualizer.

    IMPORTANT: attribute names and layer shapes must stay identical to the
    copy in app.py — the saved state_dict is loaded there by key.
    """

    def __init__(self, latent_dim=16, num_steps=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_steps = num_steps

        # --- Encoder (used only during training) ---
        self.enc_conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)   # 28 -> 14
        self.enc_bn1 = nn.BatchNorm2d(16)
        self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # 14 -> 7
        self.enc_bn2 = nn.BatchNorm2d(32)
        self.enc_fc = nn.Linear(32 * 7 * 7, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Bias logvar low so the posterior starts tight and early KL is small.
        nn.init.constant_(self.fc_logvar.bias, -5.0)

        # --- Decoder (the part shipped to the demo) ---
        self.dec_fc1 = nn.Linear(latent_dim, 128)
        self.dec_fc2 = nn.Linear(128, 32 * 7 * 7)
        self.dec_deconv1 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)  # 7 -> 14
        self.dec_bn1 = nn.BatchNorm2d(16)
        self.dec_deconv2 = nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1)   # 14 -> 28

    def encode(self, x):
        """Map a (B, 1, 28, 28) batch to posterior parameters (mu, logvar)."""
        feat = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
        feat = F.leaky_relu(self.enc_bn2(self.enc_conv2(feat)), 0.1)
        feat = F.leaky_relu(self.enc_fc(feat.flatten(1)), 0.1)
        return self.fc_mu(feat), self.fc_logvar(feat)

    def decode(self, z):
        """Render z as a 28x28 image, averaging several noisy passes (SNN-like)."""
        accum = torch.zeros(z.size(0), 1, 28, 28, device=z.device)
        for step in range(self.num_steps):
            # Noise amplitude decays linearly toward zero over the steps.
            jitter = torch.randn_like(z) * 0.05 * (1 - step / self.num_steps)
            feat = F.leaky_relu(self.dec_fc1(z + jitter), 0.1)
            feat = F.leaky_relu(self.dec_fc2(feat), 0.1)
            feat = feat.view(-1, 32, 7, 7)
            feat = F.leaky_relu(self.dec_bn1(self.dec_deconv1(feat)), 0.1)
            accum += self.dec_deconv2(feat)
        # Sigmoid of the averaged logits -> pixels in (0, 1).
        return torch.sigmoid(accum / self.num_steps)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        return mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Full VAE pass: returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar
def main():
    """Train the lightweight VAE on Fashion-MNIST, save decoder.pth, sanity-check it."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    model = LightweightBrainDecoder(latent_dim=16, num_steps=4).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # ToTensor keeps pixels in [0, 1], as required by the BCE loss below.
    train_ds = datasets.FashionMNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    loader = DataLoader(train_ds, batch_size=256, shuffle=True,
                        num_workers=2, pin_memory=True)

    epochs = 15
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)
    print(f"\nTraining for {epochs} epochs...")

    start = time.time()
    for epoch in range(epochs):
        model.train()
        running = 0.0
        # Anneal the KL weight in over the first 3 epochs (beta-VAE warm-up).
        beta_kl = min(1.0, epoch / 3.0)
        for batch, _ in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            recon_term = F.binary_cross_entropy(recon, batch, reduction='sum')
            kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            # Floor at zero: never reward pushing KL negative.
            kl_term = torch.clamp(kl_term, min=0)
            loss = recon_term + beta_kl * kl_term
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running += loss.item()
        scheduler.step()
        avg = running / len(train_ds)
        lr = optimizer.param_groups[0]['lr']
        print(f" Epoch {epoch+1:2d}/{epochs} | Loss={avg:.1f} | LR={lr:.5f}")

    elapsed = time.time() - start
    print(f"\nTraining done in {elapsed:.1f}s")

    # Move to CPU before saving so the weights load on any machine.
    model.eval()
    model.cpu()
    save_path = "decoder.pth"
    torch.save(model.state_dict(), save_path)
    import os
    size_kb = os.path.getsize(save_path) / 1024
    print(f"Saved: {save_path} ({size_kb:.0f} KB)")

    # Smoke test: decode a handful of random latents and report pixel stats.
    print("\nQuality check: decoding 5 random latent vectors...")
    with torch.no_grad():
        for i in range(5):
            img = model.decode(torch.randn(1, 16)).squeeze().numpy()
            print(f" z[{i}]: min={img.min():.3f} max={img.max():.3f} mean={img.mean():.3f}")
    print("\nDone! Upload decoder.pth to HuggingFace Spaces.")


if __name__ == "__main__":
    main()