Spaces:
Running
Running
| """ | |
| Save pre-trained decoder weights for HuggingFace zero-latency demo. | |
| Run this locally (RTX 5080) to generate decoder.pth. | |
| Upload decoder.pth to HF Spaces alongside app.py. | |
| IMPORTANT: Architecture must EXACTLY match LightweightBrainDecoder in app.py | |
| """ | |
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
| # ---------- Decoder (EXACT copy from app.py) ---------- | |
class LightweightBrainDecoder(nn.Module):
    """Must be identical to the class in app.py.

    Small convolutional VAE for 28x28 single-channel images (Fashion-MNIST).
    The decoder averages ``num_steps`` decodes of a slightly perturbed latent
    to approximate SNN-like temporal behavior.

    NOTE(review): do not rename any submodule attributes here — the saved
    state_dict keys must load into the copy of this class in app.py.
    """

    def __init__(self, latent_dim=16, num_steps=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_steps = num_steps
        # Encoder (for training)
        self.enc_conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)  # 28->14
        self.enc_bn1 = nn.BatchNorm2d(16)
        self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # 14->7
        self.enc_bn2 = nn.BatchNorm2d(32)
        self.enc_fc = nn.Linear(32 * 7 * 7, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Bias logvar low at init (std ~= exp(-2.5) ~ 0.08) so early posterior
        # samples stay close to mu.
        nn.init.constant_(self.fc_logvar.bias, -5.0)
        # Decoder (the brain state visualizer)
        self.dec_fc1 = nn.Linear(latent_dim, 128)
        self.dec_fc2 = nn.Linear(128, 32 * 7 * 7)
        self.dec_deconv1 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)  # 7->14
        self.dec_bn1 = nn.BatchNorm2d(16)
        self.dec_deconv2 = nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1)  # 14->28

    def encode(self, x):
        """Map a (B, 1, 28, 28) batch to posterior parameters (mu, logvar)."""
        h = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
        h = F.leaky_relu(self.enc_bn2(self.enc_conv2(h)), 0.1)
        h = h.view(h.size(0), -1)
        h = F.leaky_relu(self.enc_fc(h), 0.1)
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z):
        """Decode latent z to a (B, 1, 28, 28) image in [0, 1].

        Runs ``num_steps`` decode passes with additive Gaussian noise on z
        whose scale anneals from 0.05 toward 0 across steps, then averages
        the pre-sigmoid outputs (SNN-like temporal averaging).
        """
        # Temporal averaging (SNN-like behavior)
        output_sum = torch.zeros(z.size(0), 1, 28, 28, device=z.device)
        for t in range(self.num_steps):
            noise = torch.randn_like(z) * 0.05 * (1 - t / self.num_steps)
            z_t = z + noise
            h = F.leaky_relu(self.dec_fc1(z_t), 0.1)
            h = F.leaky_relu(self.dec_fc2(h), 0.1)
            h = h.view(-1, 32, 7, 7)
            h = F.leaky_relu(self.dec_bn1(self.dec_deconv1(h)), 0.1)
            h = self.dec_deconv2(h)
            output_sum += h
        return torch.sigmoid(output_sum / self.num_steps)

    def reparameterize(self, mu, logvar):
        """Standard VAE reparameterization trick: z = mu + sigma * eps."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def forward(self, x):
        """Full VAE pass: returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
def main():
    """Train the VAE on Fashion-MNIST and export CPU weights to decoder.pth.

    The saved state_dict must load into the identical LightweightBrainDecoder
    class in app.py, so the architecture above must not drift. Side effects:
    downloads Fashion-MNIST into ./data and writes decoder.pth to the CWD.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    model = LightweightBrainDecoder(latent_dim=16, num_steps=4).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Load Fashion-MNIST
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)
    epochs = 15

    print(f"\nTraining for {epochs} epochs...")
    t0 = time.time()
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # KL warm-up: beta ramps 0 -> 1 over the first 3 epochs to reduce the
        # risk of early posterior collapse.
        beta_kl = min(1.0, epoch / 3.0)
        for data, _ in loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            # Sum-reduced ELBO terms: reconstruction BCE + beta-weighted KL.
            bce = F.binary_cross_entropy(recon, data, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kld = torch.clamp(kld, min=0)  # guard against tiny negative fp error
            loss = bce + beta_kl * kld
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        avg = total_loss / len(train_ds)  # per-sample loss (terms are sum-reduced)
        lr = optimizer.param_groups[0]['lr']
        print(f" Epoch {epoch+1:2d}/{epochs} | Loss={avg:.1f} | LR={lr:.5f}")

    elapsed = time.time() - t0
    print(f"\nTraining done in {elapsed:.1f}s")

    # Save weights on CPU for portability (HF Spaces runs on CPU).
    model.eval()
    model.cpu()
    save_path = "decoder.pth"
    torch.save(model.state_dict(), save_path)
    size_kb = os.path.getsize(save_path) / 1024
    print(f"Saved: {save_path} ({size_kb:.0f} KB)")

    # Quick quality check: sample latents from the prior and decode them.
    print("\nQuality check: decoding 5 random latent vectors...")
    with torch.no_grad():
        for i in range(5):
            z = torch.randn(1, 16)
            img = model.decode(z).squeeze().numpy()
            print(f" z[{i}]: min={img.min():.3f} max={img.max():.3f} mean={img.mean():.3f}")
    print("\nDone! Upload decoder.pth to HuggingFace Spaces.")
| if __name__ == "__main__": | |
| main() | |