hafufu-stack commited on
Commit
bf50eca
·
verified ·
1 Parent(s): cbc28b8

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +16 -7
  2. decoder.pth +3 -0
app.py CHANGED
@@ -394,25 +394,36 @@ class LightweightBrainDecoder(nn.Module):
394
  brain_decoder = None
395
 
396
  def get_brain_decoder():
397
- """Get or train the brain decoder on first use"""
398
  global brain_decoder
399
  if brain_decoder is not None:
400
  return brain_decoder
401
 
402
- print("[Brain Imaging] Training lightweight decoder on Fashion-MNIST (CPU, ~30s)...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  t0 = time.time()
404
 
405
  from torchvision import datasets, transforms
406
  from torch.utils.data import DataLoader
407
 
408
- decoder = LightweightBrainDecoder(latent_dim=16, num_steps=4)
409
  decoder.train()
410
-
411
  transform = transforms.Compose([transforms.ToTensor()])
412
  try:
413
  train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
414
  except Exception:
415
- # Fallback: use random data if download fails
416
  print("[Brain Imaging] FashionMNIST download failed, using synthetic data")
417
  train_ds = torch.utils.data.TensorDataset(
418
  torch.rand(1000, 1, 28, 28),
@@ -922,8 +933,6 @@ with gr.Blocks(
922
  4. Listen to the AI's "heartbeat" — steady for normal, arrhythmic for attacks
923
 
924
  > 💡 **Try a normal prompt first, then a jailbreak prompt** to see the dramatic difference!
925
-
926
- > ⚠️ The decoder trains on first use (~30 seconds on CPU). Please be patient!
927
  """)
928
 
929
  with gr.Row():
 
394
  brain_decoder = None
395
 
396
  def get_brain_decoder():
397
+ """Get pre-trained brain decoder (loads weights, no training needed)"""
398
  global brain_decoder
399
  if brain_decoder is not None:
400
  return brain_decoder
401
 
402
+ decoder = LightweightBrainDecoder(latent_dim=16, num_steps=4)
403
+
404
+ # Try to load pre-trained weights (zero latency!)
405
+ import os
406
+ weights_path = os.path.join(os.path.dirname(__file__), "decoder.pth")
407
+ if os.path.exists(weights_path):
408
+ print("[Brain Imaging] Loading pre-trained decoder (instant!)...")
409
+ decoder.load_state_dict(torch.load(weights_path, map_location='cpu', weights_only=True))
410
+ decoder.eval()
411
+ brain_decoder = decoder
412
+ print("[Brain Imaging] Decoder ready (pre-trained weights loaded)")
413
+ return decoder
414
+
415
+ # Fallback: train from scratch if weights not found
416
+ print("[Brain Imaging] Pre-trained weights not found, training from scratch (~30s)...")
417
  t0 = time.time()
418
 
419
  from torchvision import datasets, transforms
420
  from torch.utils.data import DataLoader
421
 
 
422
  decoder.train()
 
423
  transform = transforms.Compose([transforms.ToTensor()])
424
  try:
425
  train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
426
  except Exception:
 
427
  print("[Brain Imaging] FashionMNIST download failed, using synthetic data")
428
  train_ds = torch.utils.data.TensorDataset(
429
  torch.rand(1000, 1, 28, 28),
 
933
  4. Listen to the AI's "heartbeat" — steady for normal, arrhythmic for attacks
934
 
935
  > 💡 **Try a normal prompt first, then a jailbreak prompt** to see the dramatic difference!
 
 
936
  """)
937
 
938
  with gr.Row():
decoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d27712da397e15c2921c2e4014438f2fe052cbfd30bbfac846cd5d38db77d527
3
+ size 1701760