# snn-guardrail / save_decoder_weights.py
# Author: hafufu-stack
# v4.0: Add Canary Pulse (Tab 5) - Real-time Entropy EKG + Self-Healing
# Commit: 4debe27 (verified)
"""
Save pre-trained decoder weights for HuggingFace zero-latency demo.
Run this locally (RTX 5080) to generate decoder.pth.
Upload decoder.pth to HF Spaces alongside app.py.
IMPORTANT: Architecture must EXACTLY match LightweightBrainDecoder in app.py
"""
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# ---------- Decoder (EXACT copy from app.py) ----------
class LightweightBrainDecoder(nn.Module):
    """Must be identical to the class in app.py

    Small convolutional VAE for 28x28 single-channel images.  The decoder
    runs ``num_steps`` noisy passes over the latent and averages the
    pre-sigmoid outputs, mimicking the temporal integration of a spiking
    network.

    NOTE(review): deliberately NOT restyled or refactored — the exported
    state_dict is loaded by the byte-identical twin class in app.py, so
    attribute names and module structure here must stay in sync with it.
    """

    def __init__(self, latent_dim: int = 16, num_steps: int = 4) -> None:
        super().__init__()
        self.latent_dim = latent_dim  # width of the latent code z
        self.num_steps = num_steps  # number of decode passes averaged in decode()
        # Encoder (for training)
        self.enc_conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)  # 28->14
        self.enc_bn1 = nn.BatchNorm2d(16)
        self.enc_conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # 14->7
        self.enc_bn2 = nn.BatchNorm2d(32)
        self.enc_fc = nn.Linear(32 * 7 * 7, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Bias the posterior log-variance to -5 (std ~ exp(-2.5) ~ 0.08) so
        # early training behaves close to a plain autoencoder.
        nn.init.constant_(self.fc_logvar.bias, -5.0)
        # Decoder (the brain state visualizer)
        self.dec_fc1 = nn.Linear(latent_dim, 128)
        self.dec_fc2 = nn.Linear(128, 32 * 7 * 7)
        self.dec_deconv1 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)  # 7->14
        self.dec_bn1 = nn.BatchNorm2d(16)
        self.dec_deconv2 = nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1)  # 14->28

    def encode(self, x: torch.Tensor) -> tuple:
        """Return (mu, logvar) of the approximate posterior for batch ``x``.

        Assumes x is (batch, 1, 28, 28) — the stride-2 convs reduce
        28 -> 14 -> 7, matching the 32*7*7 flatten below.
        """
        h = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
        h = F.leaky_relu(self.enc_bn2(self.enc_conv2(h)), 0.1)
        h = h.view(h.size(0), -1)  # flatten to (batch, 32*7*7)
        h = F.leaky_relu(self.enc_fc(h), 0.1)
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent ``z`` into a (batch, 1, 28, 28) image in [0, 1].

        Performs ``num_steps`` passes with linearly decaying Gaussian noise
        added to z, then applies sigmoid to the MEAN logit (not the mean of
        sigmoids) — the SNN-like temporal averaging referenced above.
        """
        # Temporal averaging (SNN-like behavior)
        output_sum = torch.zeros(z.size(0), 1, 28, 28, device=z.device)
        for t in range(self.num_steps):
            # Noise amplitude decays from 0.05 at t=0 toward 0.05/num_steps.
            noise = torch.randn_like(z) * 0.05 * (1 - t / self.num_steps)
            z_t = z + noise
            h = F.leaky_relu(self.dec_fc1(z_t), 0.1)
            h = F.leaky_relu(self.dec_fc2(h), 0.1)
            h = h.view(-1, 32, 7, 7)
            h = F.leaky_relu(self.dec_bn1(self.dec_deconv1(h)), 0.1)
            h = self.dec_deconv2(h)
            output_sum += h
        return torch.sigmoid(output_sum / self.num_steps)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Standard VAE reparameterization trick: z = mu + eps * sigma."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def forward(self, x: torch.Tensor) -> tuple:
        """Return (reconstruction, mu, logvar) for an input batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
def _build_loader(batch_size: int = 256) -> DataLoader:
    """Download Fashion-MNIST (train split) and return a shuffled DataLoader."""
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
    return DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=2, pin_memory=True)


def _vae_loss(recon: torch.Tensor, data: torch.Tensor,
              mu: torch.Tensor, logvar: torch.Tensor, beta_kl: float) -> torch.Tensor:
    """Summed BCE reconstruction loss plus beta-weighted, clamped KL term."""
    bce = F.binary_cross_entropy(recon, data, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Clamp so the optimizer is never rewarded for a (numerically) negative KL.
    kld = torch.clamp(kld, min=0)
    return bce + beta_kl * kld


def _train(model: nn.Module, device: torch.device, epochs: int) -> None:
    """Train `model` in place with Adam, cosine LR decay, and KL warm-up."""
    loader = _build_loader()
    n_samples = len(loader.dataset)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    # T_max tied to `epochs` (was hard-coded 15) so changing the epoch count
    # keeps the cosine schedule ending at its minimum.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    print(f"\nTraining for {epochs} epochs...")
    t0 = time.time()
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Linear KL warm-up over the first 3 epochs (beta: 0 -> 1).
        beta_kl = min(1.0, epoch / 3.0)
        for data, _ in loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            loss = _vae_loss(recon, data, mu, logvar, beta_kl)
            loss.backward()
            # Gradient clipping keeps the summed-BCE loss from blowing up steps.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        avg = total_loss / n_samples  # per-sample loss for readable logging
        lr = optimizer.param_groups[0]['lr']
        print(f" Epoch {epoch+1:2d}/{epochs} | Loss={avg:.1f} | LR={lr:.5f}")
    elapsed = time.time() - t0
    print(f"\nTraining done in {elapsed:.1f}s")


def _save_weights(model: nn.Module, save_path: str) -> None:
    """Save the model's state_dict as CPU tensors and report the file size."""
    model.eval()
    # CPU tensors for portability: the file must load on GPU-less Spaces hosts.
    model.cpu()
    torch.save(model.state_dict(), save_path)
    size_kb = os.path.getsize(save_path) / 1024
    print(f"Saved: {save_path} ({size_kb:.0f} KB)")


def _quality_check(model: nn.Module, n_vectors: int = 5) -> None:
    """Decode a few random latent vectors and print output statistics."""
    print(f"\nQuality check: decoding {n_vectors} random latent vectors...")
    with torch.no_grad():
        for i in range(n_vectors):
            z = torch.randn(1, model.latent_dim)
            img = model.decode(z).squeeze().numpy()
            print(f" z[{i}]: min={img.min():.3f} max={img.max():.3f} mean={img.mean():.3f}")


def main() -> None:
    """Train the VAE on Fashion-MNIST and export decoder.pth for HF Spaces.

    Pipeline: build the model, train with KL warm-up and cosine LR, save a
    CPU state_dict, then decode a few random latents as a smoke test.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    model = LightweightBrainDecoder(latent_dim=16, num_steps=4).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    epochs = 15
    _train(model, device, epochs)
    _save_weights(model, "decoder.pth")
    _quality_check(model)
    print("\nDone! Upload decoder.pth to HuggingFace Spaces.")
# Script entry point: train and export the weights only when run directly,
# not when imported (e.g. by app.py reusing the class definition).
if __name__ == "__main__":
    main()