import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

# ----------------------------------------------------------
# Load config
# ----------------------------------------------------------
CONFIG_PATH = Path(__file__).parent / "config.json"
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = json.load(f)

latent_dim = config["latent_dim"]
num_subjects = config["num_subjects"]
num_channels = config["num_channels"]
segment_length = config["segment_length"]

device = "cuda" if torch.cuda.is_available() else "cpu"


# ----------------------------------------------------------
# Define the same Generator as in your training script
# ----------------------------------------------------------
class Generator(nn.Module):
    """Conditional generator mapping (noise vector, subject label) to a multichannel sequence.

    The noise vector is concatenated with a learned label embedding, projected by
    an MLP to a (channels, seq_len // 4) feature map, upsampled 4x in length by a
    transposed convolution, and refined by two length-preserving dilated
    convolutions. The final Tanh bounds every output value to [-1, 1].

    Args:
        latent_dim: Size of the noise vector ``z`` and of the label embedding.
        n_classes: Number of distinct labels accepted by the embedding table.
        channels: Number of output channels.
        seq_len: Output sequence length; must be divisible by 4 (the upsampling factor).

    Raises:
        ValueError: If ``seq_len`` is not divisible by the upsampling factor.
    """

    # ConvTranspose1d(kernel_size=4, stride=4, padding=0) multiplies the length by 4.
    _UPSAMPLE = 4

    def __init__(self, latent_dim=128, n_classes=109, channels=64, seq_len=480):
        super().__init__()
        if seq_len % self._UPSAMPLE != 0:
            raise ValueError(
                f"seq_len must be divisible by {self._UPSAMPLE}, got {seq_len}"
            )
        self.latent_dim = latent_dim
        self.channels = channels
        # Length of the feature map before upsampling (120 for the default seq_len=480).
        # Plain ints are not part of state_dict, so checkpoint keys/shapes are unchanged.
        self.base_len = seq_len // self._UPSAMPLE
        self.label_emb = nn.Embedding(n_classes, latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim * 2, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, channels * self.base_len),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose1d(channels, channels, kernel_size=4, stride=4, padding=0, dilation=1),
            nn.ReLU(),
            # kernel 3 with padding == dilation keeps the sequence length unchanged
            nn.Conv1d(channels, channels, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=4, dilation=4),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate a batch of sequences.

        Args:
            z: Noise tensor of shape (batch, latent_dim).
            labels: Integer label tensor of shape (batch,).

        Returns:
            Tensor of shape (batch, channels, seq_len) with values in [-1, 1].
        """
        label_emb = self.label_emb(labels)
        x = torch.cat([z, label_emb], dim=1)
        x = self.fc(x)
        # Reshape the flat MLP output into a (batch, channels, base_len) feature map.
        x = x.view(x.size(0), self.channels, self.base_len)
        # Noise injection; not gated on self.training, so it is also active after eval().
        x = x + 0.05 * torch.randn_like(x)
        x = self.deconv(x)
        return x  # shape: (batch, channels, seq_len)


# ----------------------------------------------------------
# Load model weights
# ----------------------------------------------------------
MODEL_PATH = Path(__file__).parent / "best_g.pt"

generator = Generator(latent_dim, num_subjects, num_channels, segment_length).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)
if "G_state" in checkpoint:
    # full checkpoint from training
    generator.load_state_dict(checkpoint["G_state"])
else:
    # only weights saved
    generator.load_state_dict(checkpoint)
generator.eval()

print("Generator loaded successfully on", device)


# ----------------------------------------------------------
# EEG generation function
# ----------------------------------------------------------
def generate_eeg(subject_id: int, num_samples: int = 1, seed: int | None = None) -> np.ndarray:
    """
    Generate synthetic EEG segments for a given subject ID.

    Args:
        subject_id (int): Subject label, in the range [0, num_subjects)
        num_samples (int): Number of EEG samples to generate (must be >= 0)
        seed (int, optional): Random seed for reproducibility. Applied to torch's
            global RNG, so it also fixes the generator's internal noise injection
            (and affects any later torch random calls in the process).

    Returns:
        np.ndarray: Generated EEG of shape (num_samples, num_channels, segment_length)

    Raises:
        ValueError: If subject_id is out of range or num_samples is negative.
    """
    # Validate up front: an out-of-range label would otherwise fail inside
    # nn.Embedding, which on CUDA is reported as an opaque device-side assert.
    if not 0 <= subject_id < num_subjects:
        raise ValueError(
            f"subject_id must be in [0, {num_subjects}), got {subject_id}"
        )
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    if seed is not None:
        torch.manual_seed(seed)

    z = torch.randn(num_samples, latent_dim, device=device)
    labels = torch.full((num_samples,), subject_id, dtype=torch.long, device=device)

    with torch.no_grad():
        fake_eeg = generator(z, labels).cpu().numpy()

    return fake_eeg


# ----------------------------------------------------------
# Example usage
# ----------------------------------------------------------
if __name__ == "__main__":
    subject_id = 42
    samples = generate_eeg(subject_id=subject_id, num_samples=5, seed=123)
    print(f"Generated {samples.shape[0]} EEG samples for subject {subject_id}")
    print("EEG shape:", samples.shape)
    print("Value range:", samples.min(), "to", samples.max())