Spaces:
Running
Running
Upload 4 files
Browse files- app.py +16 -7
- decoder.pth +3 -0
app.py
CHANGED
|
@@ -394,25 +394,36 @@ class LightweightBrainDecoder(nn.Module):
|
|
| 394 |
brain_decoder = None
|
| 395 |
|
| 396 |
def get_brain_decoder():
|
| 397 |
-
"""Get
|
| 398 |
global brain_decoder
|
| 399 |
if brain_decoder is not None:
|
| 400 |
return brain_decoder
|
| 401 |
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
t0 = time.time()
|
| 404 |
|
| 405 |
from torchvision import datasets, transforms
|
| 406 |
from torch.utils.data import DataLoader
|
| 407 |
|
| 408 |
-
decoder = LightweightBrainDecoder(latent_dim=16, num_steps=4)
|
| 409 |
decoder.train()
|
| 410 |
-
|
| 411 |
transform = transforms.Compose([transforms.ToTensor()])
|
| 412 |
try:
|
| 413 |
train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
|
| 414 |
except Exception:
|
| 415 |
-
# Fallback: use random data if download fails
|
| 416 |
print("[Brain Imaging] FashionMNIST download failed, using synthetic data")
|
| 417 |
train_ds = torch.utils.data.TensorDataset(
|
| 418 |
torch.rand(1000, 1, 28, 28),
|
|
@@ -922,8 +933,6 @@ with gr.Blocks(
|
|
| 922 |
4. Listen to the AI's "heartbeat" — steady for normal, arrhythmic for attacks
|
| 923 |
|
| 924 |
> 💡 **Try a normal prompt first, then a jailbreak prompt** to see the dramatic difference!
|
| 925 |
-
|
| 926 |
-
> ⚠️ The decoder trains on first use (~30 seconds on CPU). Please be patient!
|
| 927 |
""")
|
| 928 |
|
| 929 |
with gr.Row():
|
|
|
|
| 394 |
brain_decoder = None
|
| 395 |
|
| 396 |
def get_brain_decoder():
|
| 397 |
+
"""Get pre-trained brain decoder (loads weights, no training needed)"""
|
| 398 |
global brain_decoder
|
| 399 |
if brain_decoder is not None:
|
| 400 |
return brain_decoder
|
| 401 |
|
| 402 |
+
decoder = LightweightBrainDecoder(latent_dim=16, num_steps=4)
|
| 403 |
+
|
| 404 |
+
# Try to load pre-trained weights (zero latency!)
|
| 405 |
+
import os
|
| 406 |
+
weights_path = os.path.join(os.path.dirname(__file__), "decoder.pth")
|
| 407 |
+
if os.path.exists(weights_path):
|
| 408 |
+
print("[Brain Imaging] Loading pre-trained decoder (instant!)...")
|
| 409 |
+
decoder.load_state_dict(torch.load(weights_path, map_location='cpu', weights_only=True))
|
| 410 |
+
decoder.eval()
|
| 411 |
+
brain_decoder = decoder
|
| 412 |
+
print("[Brain Imaging] Decoder ready (pre-trained weights loaded)")
|
| 413 |
+
return decoder
|
| 414 |
+
|
| 415 |
+
# Fallback: train from scratch if weights not found
|
| 416 |
+
print("[Brain Imaging] Pre-trained weights not found, training from scratch (~30s)...")
|
| 417 |
t0 = time.time()
|
| 418 |
|
| 419 |
from torchvision import datasets, transforms
|
| 420 |
from torch.utils.data import DataLoader
|
| 421 |
|
|
|
|
| 422 |
decoder.train()
|
|
|
|
| 423 |
transform = transforms.Compose([transforms.ToTensor()])
|
| 424 |
try:
|
| 425 |
train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
|
| 426 |
except Exception:
|
|
|
|
| 427 |
print("[Brain Imaging] FashionMNIST download failed, using synthetic data")
|
| 428 |
train_ds = torch.utils.data.TensorDataset(
|
| 429 |
torch.rand(1000, 1, 28, 28),
|
|
|
|
| 933 |
4. Listen to the AI's "heartbeat" — steady for normal, arrhythmic for attacks
|
| 934 |
|
| 935 |
> 💡 **Try a normal prompt first, then a jailbreak prompt** to see the dramatic difference!
|
|
|
|
|
|
|
| 936 |
""")
|
| 937 |
|
| 938 |
with gr.Row():
|
decoder.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d27712da397e15c2921c2e4014438f2fe052cbfd30bbfac846cd5d38db77d527
|
| 3 |
+
size 1701760
|