"""EEG-WGAN-GP / inference.py

Inference script: loads a trained conditional WGAN-GP generator (best_g.pt)
and synthesizes EEG segments for a given subject ID.

Source: Georgios-Ak's Hugging Face upload ("Upload 3 files", commit 9efc160).
"""
import torch
import torch.nn as nn
import numpy as np
import json
from pathlib import Path
# ----------------------------------------------------------
# Load config
# ----------------------------------------------------------
# Hyperparameters live next to this script in config.json so that the
# generator here is built with exactly the sizes used during training.
CONFIG_PATH = Path(__file__).parent / "config.json"
config = json.loads(CONFIG_PATH.read_text())

latent_dim = config["latent_dim"]
num_subjects = config["num_subjects"]
num_channels = config["num_channels"]
segment_length = config["segment_length"]

# Prefer GPU when available; everything below respects this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------------------------------------
# Define the same Generator as in your training script
# ----------------------------------------------------------
class Generator(nn.Module):
    """Conditional WGAN-GP generator that synthesizes EEG segments.

    A latent vector is concatenated with a learned subject embedding, mapped
    through fully connected layers to a coarse (channels, seq_len // 4)
    feature map, then upsampled x4 by a transposed convolution and refined by
    dilated convolutions.

    Args:
        latent_dim: Size of the latent noise vector ``z``.
        n_classes: Number of subject labels for the conditioning embedding.
        channels: Number of EEG channels in the output.
        seq_len: Output segment length; must be divisible by 4 because the
            transposed convolution upsamples the base feature map exactly x4.

    Raises:
        ValueError: If ``seq_len`` is not divisible by 4.
    """

    def __init__(self, latent_dim=128, n_classes=109, channels=64, seq_len=480):
        super().__init__()
        if seq_len % 4 != 0:
            raise ValueError(f"seq_len must be divisible by 4, got {seq_len}")
        self.latent_dim = latent_dim
        self.channels = channels
        # Length of the pre-upsampling feature map (120 for the default 480).
        self.base_len = seq_len // 4
        self.label_emb = nn.Embedding(n_classes, latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim * 2, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, channels * self.base_len),
            nn.ReLU()
        )
        # Upsample x4, then refine with length-preserving dilated convs
        # (padding == dilation for kernel_size=3 keeps the length unchanged).
        self.deconv = nn.Sequential(
            nn.ConvTranspose1d(channels, channels, kernel_size=4, stride=4, padding=0, dilation=1),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=4, dilation=4),
            nn.Tanh()  # bound outputs to [-1, 1]
        )

    def forward(self, z, labels):
        """Generate EEG from latent vectors ``z`` and integer ``labels``.

        Args:
            z: Latent noise, shape (batch, latent_dim).
            labels: Subject indices, shape (batch,), dtype long.

        Returns:
            Tensor of shape (batch, channels, seq_len).
        """
        label_emb = self.label_emb(labels)
        x = torch.cat([z, label_emb], dim=1)
        x = self.fc(x)
        # Bug fix: the original hard-coded view(x.size(0), 64, 120), which
        # crashed for any channels != 64 and silently ignored seq_len.
        x = x.view(x.size(0), self.channels, self.base_len)
        # Noise injection (also active in eval mode, matching the original;
        # outputs are reproducible only after torch.manual_seed).
        x = x + 0.05 * torch.randn_like(x)
        return self.deconv(x)
# ----------------------------------------------------------
# Load model weights
# ----------------------------------------------------------
MODEL_PATH = Path(__file__).parent / "best_g.pt"

generator = Generator(latent_dim, num_subjects, num_channels, segment_length).to(device)

# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# A full training checkpoint nests the generator under "G_state";
# otherwise the file is assumed to be a bare state_dict.
state_dict = checkpoint["G_state"] if "G_state" in checkpoint else checkpoint
generator.load_state_dict(state_dict)
generator.eval()

print("Generator loaded successfully on", device)
# ----------------------------------------------------------
# EEG generation function
# ----------------------------------------------------------
def generate_eeg(subject_id: int, num_samples: int = 1, seed: int | None = None) -> np.ndarray:
    """
    Generate synthetic EEG segments for a given subject ID.

    Args:
        subject_id (int): Subject label, must be in [0, num_subjects - 1]
        num_samples (int): Number of EEG samples to generate (>= 1)
        seed (int, optional): Random seed for reproducibility

    Returns:
        np.ndarray: Generated EEG of shape (num_samples, num_channels, segment_length)

    Raises:
        ValueError: If subject_id is out of range or num_samples < 1.
    """
    # Fail fast with a clear message instead of an opaque embedding index
    # error (or a CUDA device-side assert when running on GPU).
    if not 0 <= subject_id < num_subjects:
        raise ValueError(f"subject_id must be in [0, {num_subjects - 1}], got {subject_id}")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if seed is not None:
        torch.manual_seed(seed)
    z = torch.randn(num_samples, latent_dim, device=device)
    labels = torch.full((num_samples,), subject_id, dtype=torch.long, device=device)
    # Inference only: no gradients needed, and move results back to CPU/NumPy.
    with torch.no_grad():
        fake_eeg = generator(z, labels).cpu().numpy()
    return fake_eeg
# ----------------------------------------------------------
# Example usage
# ----------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: synthesize a handful of segments for one subject.
    demo_subject = 42
    samples = generate_eeg(subject_id=demo_subject, num_samples=5, seed=123)

    print(f"Generated {samples.shape[0]} EEG samples for subject {demo_subject}")
    print("EEG shape:", samples.shape)
    print("Value range:", samples.min(), "to", samples.max())