|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model hyperparameters live in a JSON file that sits next to this module.
CONFIG_PATH = Path(__file__).parent / "config.json"

# JSON text is UTF-8 by specification; pass the encoding explicitly so that
# decoding does not depend on the platform's locale default.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = json.load(f)

# Hyperparameters read from the config; a missing key raises KeyError here,
# at import time.
latent_dim = config["latent_dim"]
num_subjects = config["num_subjects"]
num_channels = config["num_channels"]
segment_length = config["segment_length"]

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Generator(nn.Module):
    """Label-conditioned generator mapping latent noise to multi-channel sequences.

    The latent vector is concatenated with a learned label embedding, expanded
    by a fully connected stack into a ``(channels, 120)`` feature map,
    upsampled 4x by a transposed convolution and refined by two
    length-preserving dilated convolutions. The final ``Tanh`` bounds the
    output to ``[-1, 1]``.

    Args:
        latent_dim: Size of the latent vector ``z`` and of the label embedding.
        n_classes: Number of rows in the label embedding table.
        channels: Number of channels in the generated signal.
        seq_len: Accepted for interface compatibility but not used; the layer
            sizes are fixed, so the generated length is always
            ``120 * 4 = 480`` samples.
    """

    # Temporal length of the feature map produced by the fully connected stack.
    _BASE_LEN = 120

    def __init__(self, latent_dim=128, n_classes=109, channels=64, seq_len=480):
        super().__init__()
        self.latent_dim = latent_dim
        # Kept so forward() reshapes to the width this instance was built
        # with instead of assuming the default of 64 channels.
        self.channels = channels
        self.label_emb = nn.Embedding(n_classes, latent_dim)

        # Input is [z ; label embedding], hence latent_dim * 2 features.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim * 2, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, channels * self._BASE_LEN),
            nn.ReLU(),
        )

        self.deconv = nn.Sequential(
            # kernel_size == stride with no padding: length 120 -> 480.
            nn.ConvTranspose1d(channels, channels, kernel_size=4, stride=4, padding=0, dilation=1),
            nn.ReLU(),
            # With kernel_size=3, padding == dilation keeps the length unchanged.
            nn.Conv1d(channels, channels, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=4, dilation=4),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate a batch of sequences conditioned on ``labels``.

        Args:
            z: Latent tensor of shape ``(batch, latent_dim)``.
            labels: Integer tensor of shape ``(batch,)`` indexing the label
                embedding table.

        Returns:
            Tensor of shape ``(batch, channels, 480)`` with values in
            ``[-1, 1]``.
        """
        label_emb = self.label_emb(labels)
        x = torch.cat([z, label_emb], dim=1)
        x = self.fc(x)
        x = x.view(x.size(0), self.channels, self._BASE_LEN)
        # Gaussian noise (std 0.05) is injected on every call, in eval mode
        # too, so the output depends on torch's RNG state.
        x = x + 0.05 * torch.randn_like(x)
        x = self.deconv(x)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Trained generator weights are stored alongside this module.
MODEL_PATH = Path(__file__).parent / "best_g.pt"

# Build the network from the configured hyperparameters, then move it to the
# selected device before any weights are loaded.
generator = Generator(latent_dim, num_subjects, num_channels, segment_length)
generator = generator.to(device)

# NOTE(review): depending on the installed torch version's `weights_only`
# default, torch.load may unpickle arbitrary objects - only load checkpoint
# files from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# The file is either a dict that stores the generator weights under
# "G_state", or the state dict itself.
generator.load_state_dict(
    checkpoint["G_state"] if "G_state" in checkpoint else checkpoint
)

# Inference only: put layers such as Dropout into evaluation behaviour.
generator.eval()

print("Generator loaded successfully on", device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_eeg(subject_id: int, num_samples: int = 1, seed: int | None = None) -> np.ndarray:
    """Generate synthetic EEG segments for a given subject ID.

    Args:
        subject_id: Subject label. Must index into the generator's label
            embedding table, i.e. ``0 <= subject_id < num_embeddings``.
        num_samples: Number of EEG samples to generate.
        seed: Optional random seed for reproducibility. This reseeds torch's
            global RNG via ``torch.manual_seed``, which therefore also affects
            later random draws made elsewhere in the process.

    Returns:
        Generated EEG as a NumPy array of shape
        ``(num_samples, num_channels, length)``, where ``length`` is fixed by
        the generator architecture.

    Raises:
        IndexError: If ``subject_id`` is outside the label embedding table.
    """
    # Reject bad labels before they reach the embedding lookup. IndexError is
    # the type the lookup itself raises on CPU; on CUDA an out-of-range index
    # would instead surface as a device-side assert, which is much harder to
    # recover from.
    n_labels = generator.label_emb.num_embeddings
    if not 0 <= subject_id < n_labels:
        raise IndexError(
            f"subject_id must be in [0, {n_labels}), got {subject_id}"
        )

    if seed is not None:
        torch.manual_seed(seed)

    z = torch.randn(num_samples, latent_dim, device=device)
    labels = torch.full((num_samples,), subject_id, dtype=torch.long, device=device)

    # No gradients are needed for sampling; move the result to host memory
    # before converting to NumPy.
    with torch.no_grad():
        fake_eeg = generator(z, labels).cpu().numpy()

    return fake_eeg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: draw a small, seeded batch for one subject and report its
    # size, shape and value range.
    subject_id = 42
    samples = generate_eeg(subject_id=subject_id, num_samples=5, seed=123)

    lowest, highest = samples.min(), samples.max()

    print(f"Generated {samples.shape[0]} EEG samples for subject {subject_id}")
    print("EEG shape:", samples.shape)
    print("Value range:", lowest, "to", highest)
|
|
|