""" Save pre-trained decoder weights for HuggingFace zero-latency demo. Run this locally (RTX 5080) to generate decoder.pth. Upload decoder.pth to HF Spaces alongside app.py. IMPORTANT: Architecture must EXACTLY match LightweightBrainDecoder in app.py """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms import time # ---------- Decoder (EXACT copy from app.py) ---------- class LightweightBrainDecoder(nn.Module): """Must be identical to the class in app.py""" def __init__(self, latent_dim=16, num_steps=4): super().__init__() self.latent_dim = latent_dim self.num_steps = num_steps # Encoder (for training) self.enc_conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1) # 28->14 self.enc_bn1 = nn.BatchNorm2d(16) self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1) # 14->7 self.enc_bn2 = nn.BatchNorm2d(32) self.enc_fc = nn.Linear(32 * 7 * 7, 128) self.fc_mu = nn.Linear(128, latent_dim) self.fc_logvar = nn.Linear(128, latent_dim) nn.init.constant_(self.fc_logvar.bias, -5.0) # Decoder (the brain state visualizer) self.dec_fc1 = nn.Linear(latent_dim, 128) self.dec_fc2 = nn.Linear(128, 32 * 7 * 7) self.dec_deconv1 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1) # 7->14 self.dec_bn1 = nn.BatchNorm2d(16) self.dec_deconv2 = nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1) # 14->28 def encode(self, x): h = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1) h = F.leaky_relu(self.enc_bn2(self.enc_conv2(h)), 0.1) h = h.view(h.size(0), -1) h = F.leaky_relu(self.enc_fc(h), 0.1) return self.fc_mu(h), self.fc_logvar(h) def decode(self, z): # Temporal averaging (SNN-like behavior) output_sum = torch.zeros(z.size(0), 1, 28, 28, device=z.device) for t in range(self.num_steps): noise = torch.randn_like(z) * 0.05 * (1 - t / self.num_steps) z_t = z + noise h = F.leaky_relu(self.dec_fc1(z_t), 0.1) h = F.leaky_relu(self.dec_fc2(h), 0.1) h = h.view(-1, 32, 7, 7) h = F.leaky_relu(self.dec_bn1(self.dec_deconv1(h)), 0.1) h = self.dec_deconv2(h) output_sum += h return torch.sigmoid(output_sum / self.num_steps) def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) return mu + torch.randn_like(std) * std def forward(self, x): mu, logvar = self.encode(x) z = self.reparameterize(mu, logvar) return self.decode(z), mu, logvar def main(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") model = LightweightBrainDecoder(latent_dim=16, num_steps=4).to(device) print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}") # Load Fashion-MNIST transform = transforms.Compose([transforms.ToTensor()]) train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform) loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2, pin_memory=True) optimizer = torch.optim.Adam(model.parameters(), lr=2e-3) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15) epochs = 15 print(f"\nTraining for {epochs} epochs...") t0 = time.time() for epoch in range(epochs): model.train() total_loss = 0 beta_kl = min(1.0, epoch / 3.0) for data, _ in loader: data = data.to(device) optimizer.zero_grad() recon, mu, logvar = model(data) bce = F.binary_cross_entropy(recon, data, reduction='sum') kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) kld = torch.clamp(kld, min=0) loss = bce + beta_kl * kld loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() total_loss += loss.item() scheduler.step() avg = total_loss / len(train_ds) lr = optimizer.param_groups[0]['lr'] print(f" Epoch {epoch+1:2d}/{epochs} | Loss={avg:.1f} | LR={lr:.5f}") elapsed = time.time() - t0 print(f"\nTraining done in {elapsed:.1f}s") # Save weights (CPU for portability) model.eval() model.cpu() save_path = "decoder.pth" torch.save(model.state_dict(), save_path) import os size_kb = os.path.getsize(save_path) / 1024 print(f"Saved: {save_path} ({size_kb:.0f} KB)") # Quick quality check print("\nQuality check: decoding 5 random latent vectors...") with torch.no_grad(): for i in range(5): z = torch.randn(1, 16) img = model.decode(z).squeeze().numpy() print(f" z[{i}]: min={img.min():.3f} max={img.max():.3f} mean={img.mean():.3f}") print("\nDone! Upload decoder.pth to HuggingFace Spaces.") if __name__ == "__main__": main()