Georgios-Ak commited on
Commit
9efc160
·
verified ·
1 Parent(s): 0487664

Upload 3 files

Browse files
Files changed (3) hide show
  1. best_g.pt +3 -0
  2. config.json +40 -0
  3. inference.py +108 -0
best_g.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f95490b3f85f96d425e4cbdbdad12a21f2a287235b98e36fa9e8667df4eeb3b
3
+ size 237307698
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "EEG Data Synthesis with WGAN-GP",
3
+ "architecture": "conditional_wgan_gp",
4
+ "latent_dim": 128,
5
+ "num_subjects": 109,
6
+ "num_channels": 64,
7
+ "segment_length": 480,
8
+ "generator_fc_layers": [256, 2048, 7680],
9
+ "deconv_layers": [
10
+ {"type": "ConvTranspose1d", "kernel_size": 4, "stride": 4, "dilation": 1},
11
+ {"type": "Conv1d", "kernel_size": 3, "padding": 2, "dilation": 2},
12
+ {"type": "Conv1d", "kernel_size": 3, "padding": 4, "dilation": 4}
13
+ ],
14
+ "activation": "tanh",
15
+ "optimizer": {
16
+ "type": "Adam",
17
+ "beta1": 0.0,
18
+ "beta2": 0.9,
19
+ "lr_generator": 1e-4,
20
+ "lr_discriminator": 5e-5
21
+ },
22
+ "training": {
23
+ "epochs": 300,
24
+ "batch_size": 42,
25
+ "gradient_penalty_lambda": 5,
26
+ "drift_regularization": 0.001,
27
+ "n_critic": 1,
28
+ "mixed_precision": true
29
+ },
30
+ "dataset": {
31
+ "name": "PhysioNet EEG Motor Movement/Imagery",
32
+ "num_subjects": 109,
33
+ "sampling_rate": 160,
34
+ "channels": 64,
35
+ "segment_length": 480,
36
+ "normalization": "[-1, 1]",
37
+ "tasks": ["left_fist", "right_fist", "both_fists", "both_feet", "eyes_open", "eyes_closed"]
38
+ },
39
+ "description": "Trained conditional EEG generator (WGAN-GP) using 109 subjects from the PhysioNet EEG Motor Movement/Imagery dataset."
40
+ }
inference.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import json
5
+ from pathlib import Path
6
+
7
+ # ----------------------------------------------------------
8
+ # Load config
9
+ # ----------------------------------------------------------
10
+ CONFIG_PATH = Path(__file__).parent / "config.json"
11
+ with open(CONFIG_PATH, "r") as f:
12
+ config = json.load(f)
13
+
14
+ latent_dim = config["latent_dim"]
15
+ num_subjects = config["num_subjects"]
16
+ num_channels = config["num_channels"]
17
+ segment_length = config["segment_length"]
18
+
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ # ----------------------------------------------------------
22
+ # Define the same Generator as in your training script
23
+ # ----------------------------------------------------------
24
+ class Generator(nn.Module):
25
+ def __init__(self, latent_dim=128, n_classes=109, channels=64, seq_len=480):
26
+ super().__init__()
27
+ self.latent_dim = latent_dim
28
+ self.label_emb = nn.Embedding(n_classes, latent_dim)
29
+
30
+ self.fc = nn.Sequential(
31
+ nn.Linear(latent_dim * 2, 2048),
32
+ nn.ReLU(),
33
+ nn.Dropout(0.1),
34
+ nn.Linear(2048, channels * 120),
35
+ nn.ReLU()
36
+ )
37
+
38
+ self.deconv = nn.Sequential(
39
+ nn.ConvTranspose1d(channels, channels, kernel_size=4, stride=4, padding=0, dilation=1),
40
+ nn.ReLU(),
41
+ nn.Conv1d(channels, channels, kernel_size=3, padding=2, dilation=2),
42
+ nn.ReLU(),
43
+ nn.Conv1d(channels, channels, kernel_size=3, padding=4, dilation=4),
44
+ nn.Tanh()
45
+ )
46
+
47
+ def forward(self, z, labels):
48
+ label_emb = self.label_emb(labels)
49
+ x = torch.cat([z, label_emb], dim=1)
50
+ x = self.fc(x)
51
+ x = x.view(x.size(0), 64, 120)
52
+ x = x + 0.05 * torch.randn_like(x) # noise injection
53
+ x = self.deconv(x)
54
+ return x # shape: (batch, channels, seq_len)
55
+
56
+
57
+ # ----------------------------------------------------------
58
+ # Load model weights
59
+ # ----------------------------------------------------------
60
+ MODEL_PATH = Path(__file__).parent / "best_g.pt"
61
+ generator = Generator(latent_dim, num_subjects, num_channels, segment_length).to(device)
62
+
63
+ checkpoint = torch.load(MODEL_PATH, map_location=device)
64
+ if "G_state" in checkpoint: # full checkpoint from training
65
+ generator.load_state_dict(checkpoint["G_state"])
66
+ else: # only weights saved
67
+ generator.load_state_dict(checkpoint)
68
+ generator.eval()
69
+
70
+ print("Generator loaded successfully on", device)
71
+
72
+
73
+ # ----------------------------------------------------------
74
+ # EEG generation function
75
+ # ----------------------------------------------------------
76
+ def generate_eeg(subject_id: int, num_samples: int = 1, seed: int | None = None):
77
+ """
78
+ Generate synthetic EEG segments for a given subject ID.
79
+
80
+ Args:
81
+ subject_id (int): Subject label
82
+ num_samples (int): Number of EEG samples to generate
83
+ seed (int, optional): Random seed for reproducibility
84
+
85
+ Returns:
86
+ np.ndarray: Generated EEG of shape (num_samples, num_channels, segment_length)
87
+ """
88
+ if seed is not None:
89
+ torch.manual_seed(seed)
90
+
91
+ z = torch.randn(num_samples, latent_dim, device=device)
92
+ labels = torch.full((num_samples,), subject_id, dtype=torch.long, device=device)
93
+
94
+ with torch.no_grad():
95
+ fake_eeg = generator(z, labels).cpu().numpy()
96
+
97
+ return fake_eeg
98
+
99
+
100
+ # ----------------------------------------------------------
101
+ # Example usage
102
+ # ----------------------------------------------------------
103
+ if __name__ == "__main__":
104
+ subject_id = 42
105
+ samples = generate_eeg(subject_id=subject_id, num_samples=5, seed=123)
106
+ print(f"Generated {samples.shape[0]} EEG samples for subject {subject_id}")
107
+ print("EEG shape:", samples.shape)
108
+ print("Value range:", samples.min(), "to", samples.max())