| import os |
| import torch |
| import pandas as pd |
| import numpy as np |
| from torch.utils.data import Dataset, IterableDataset |
| from typing import List, Tuple, Optional, Dict, Union |
| from scipy import stats |
| from .utils import pad_sequences, create_padding_mask |
|
|
|
|
class CollateWrapper:
    """Picklable callable binding a fixed padding value to the collate function.

    DataLoader worker processes pickle their ``collate_fn``; a module-level
    class instance (unlike a lambda or closure) survives that round-trip.
    """

    def __init__(self, padding_value):
        # Padding value forwarded to collate_nb_glm_batch on every call.
        self.padding_value = padding_value

    def __call__(self, batch):
        collated = collate_nb_glm_batch(batch, padding_value=self.padding_value)
        return collated
|
|
|
|
def collate_nb_glm_batch(batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
                         padding_value: float = -1e9) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Collate function for variable-length NB GLM sequences.

    Args:
        batch: List of (set_1, set_2, targets) tuples
        padding_value: Value to use for padding

    Returns:
        Tuple of (set_1_batch, set_2_batch, set_1_mask, set_2_mask, targets_batch)
    """
    # Transpose the batch: one list per field.
    set_1_seqs = [example[0] for example in batch]
    set_2_seqs = [example[1] for example in batch]

    # Targets are fixed-size, so a plain stack suffices.
    targets_batch = torch.stack([example[2] for example in batch])

    # Pad each variable-length set to the batch max length and build the
    # corresponding validity masks from the original (unpadded) lengths.
    return (
        pad_sequences(set_1_seqs, padding_value=padding_value),
        pad_sequences(set_2_seqs, padding_value=padding_value),
        create_padding_mask(set_1_seqs),
        create_padding_mask(set_2_seqs),
        targets_batch,
    )
|
|
|
|
class SyntheticNBGLMDataset(IterableDataset):
    """
    Online synthetic data generator for Negative Binomial GLM parameter estimation.

    Generates training examples on-the-fly with known ground truth parameters:
    - mu: Base mean parameter (log scale)
    - beta: Log fold change between conditions
    - alpha: Dispersion parameter (log scale)

    Each example consists of two sets of samples drawn from:
    - Condition 1: x ~ NB(l * exp(mu), exp(alpha))
    - Condition 2: x ~ NB(l * exp(mu + beta), exp(alpha))

    Counts are transformed to: y = log10(1e4 * x / l + 1)
    """

    # Order of the regression targets in the tensor yielded per example.
    TARGET_COLUMNS = ['mu', 'beta', 'alpha']

    def __init__(self,
                 num_examples_per_epoch: int = 100000,
                 min_samples_per_condition: int = 2,
                 max_samples_per_condition: int = 10,
                 mu_loc: float = -1.0,
                 mu_scale: float = 2.0,
                 alpha_mean: float = -2.0,
                 alpha_std: float = 1.0,
                 beta_prob_de: float = 0.3,
                 beta_std: float = 1.0,
                 library_size_mean: float = 10000,
                 library_size_cv: float = 0.3,
                 seed: Optional[int] = None):
        """
        Initialize synthetic NB GLM dataset.

        Args:
            num_examples_per_epoch: Number of examples to generate per epoch
            min_samples_per_condition: Minimum samples per condition
            max_samples_per_condition: Maximum samples per condition
            mu_loc: Mean of the normal distribution for mu (log scale)
            mu_scale: Std of the normal distribution for mu (log scale)
            alpha_mean: Mean of alpha normal distribution
            alpha_std: Std of alpha normal distribution
            beta_prob_de: Probability of differential expression (non-zero beta)
            beta_std: Standard deviation of beta when DE
            library_size_mean: Mean library size
            library_size_cv: Coefficient of variation for library size
            seed: Random seed for reproducibility
        """
        self.num_examples_per_epoch = num_examples_per_epoch
        self.min_samples = min_samples_per_condition
        self.max_samples = max_samples_per_condition

        # Parameter-prior hyperparameters.
        self.mu_loc = mu_loc
        self.mu_scale = mu_scale
        self.alpha_mean = alpha_mean
        self.alpha_std = alpha_std
        self.beta_prob_de = beta_prob_de
        self.beta_std = beta_std

        # Library-size prior (log-normal, parameterized by mean and CV).
        self.library_size_mean = library_size_mean
        self.library_size_cv = library_size_cv
        self.library_size_std = library_size_mean * library_size_cv

        # Moments used to normalize regression targets to ~unit scale.
        # beta is a spike-and-slab mixture: zero with prob (1 - prob_de),
        # N(0, beta_std^2) with prob prob_de, so its marginal variance is
        # prob_de * beta_std^2.
        self.target_stats = {
            'mu': {'mean': mu_loc, 'std': mu_scale},
            'alpha': {'mean': alpha_mean, 'std': alpha_std},
            'beta': {'mean': 0.0, 'std': (beta_prob_de * beta_std**2)**0.5}
        }

        self.seed = seed
        self.rng = np.random.RandomState(seed)

    def __len__(self):
        """Return the number of examples per epoch for progress tracking."""
        return self.num_examples_per_epoch

    def __iter__(self):
        """Yield this worker's share of freshly generated examples for one epoch."""
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single-process loading: generate the whole epoch here.
            examples_per_worker = self.num_examples_per_epoch
            worker_id = 0
        else:
            # Split the epoch across workers. The first
            # (num_examples % num_workers) workers take one extra example so
            # the remainder is not silently dropped.
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            examples_per_worker = self.num_examples_per_epoch // num_workers
            if worker_id < self.num_examples_per_epoch % num_workers:
                examples_per_worker += 1

        if self.seed is not None:
            # Deterministic but distinct stream per worker.
            self.rng = np.random.RandomState(self.seed + worker_id)
        elif worker_info is not None:
            # BUG FIX: each worker process receives a pickled copy of the
            # parent's RandomState; without reseeding, all workers would emit
            # identical (duplicated) examples. worker_info.seed is unique per
            # worker and per epoch; fold it into the 32-bit seed range.
            self.rng = np.random.RandomState(worker_info.seed % (2 ** 32))

        for _ in range(examples_per_worker):
            yield self._generate_example()

    def _generate_example(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a single training example: (set_1, set_2, normalized targets)."""
        # Draw ground-truth parameters from their priors.
        mu = self._sample_mu()
        alpha = self._sample_alpha(mu)
        beta = self._sample_beta()

        # Independent set sizes per condition, inclusive of max_samples.
        n1 = self.rng.randint(self.min_samples, self.max_samples + 1)
        n2 = self.rng.randint(self.min_samples, self.max_samples + 1)

        # Condition 1 uses the base mean; condition 2 shifts it by beta.
        set_1 = self._generate_set(mu, alpha, n1)
        set_2 = self._generate_set(mu + beta, alpha, n2)

        # Targets are normalized to ~unit scale for regression stability.
        targets_raw = {'mu': mu, 'beta': beta, 'alpha': alpha}
        targets_normalized = self._normalize_targets(targets_raw)
        targets = torch.tensor([targets_normalized['mu'], targets_normalized['beta'], targets_normalized['alpha']], dtype=torch.float32)

        return set_1, set_2, targets

    def _normalize_targets(self, targets: Dict[str, float]) -> Dict[str, float]:
        """Normalize targets to unit normal for better regression performance."""
        normalized = {}
        for param in ['mu', 'beta', 'alpha']:
            mean = self.target_stats[param]['mean']
            std = self.target_stats[param]['std']
            # Guard against division by zero (e.g. beta_prob_de == 0).
            std = max(std, 1e-8)
            normalized[param] = (targets[param] - mean) / std
        return normalized

    def denormalize_targets(self, normalized_targets: Dict[str, float]) -> Dict[str, float]:
        """Denormalize targets back to original scale."""
        denormalized = {}
        for param in ['mu', 'beta', 'alpha']:
            mean = self.target_stats[param]['mean']
            std = self.target_stats[param]['std']
            denormalized[param] = normalized_targets[param] * std + mean
        return denormalized

    def _sample_mu(self) -> float:
        """Sample the base mean parameter (normal on the log scale, so the
        expression level exp(mu) is log-normally distributed)."""
        return self.rng.normal(self.mu_loc, self.mu_scale)

    def _sample_alpha(self, mu: float) -> float:
        """
        Sample dispersion parameter (log scale).

        For now, we use a simple normal distribution independent of mu.
        In the future, this could model the mean-dispersion relationship
        (mu is accepted for that purpose but currently unused).
        """
        return self.rng.normal(self.alpha_mean, self.alpha_std)

    def _sample_beta(self) -> float:
        """Sample log fold change from a spike-and-slab mixture."""
        if self.rng.random() < self.beta_prob_de:
            # Differentially expressed: draw a non-zero log fold change.
            return self.rng.normal(0, self.beta_std)
        else:
            # Not differentially expressed.
            return 0.0

    def _sample_library_sizes(self, n_samples: int) -> np.ndarray:
        """Sample library sizes from a log-normal distribution whose
        (arithmetic) mean and CV match the configured values."""
        # Convert mean/CV to the log-normal's underlying normal parameters.
        log_mean = np.log(self.library_size_mean) - 0.5 * np.log(1 + self.library_size_cv**2)
        log_std = np.sqrt(np.log(1 + self.library_size_cv**2))

        return self.rng.lognormal(log_mean, log_std, size=n_samples)

    def _generate_set(self, mu: float, alpha: float, n_samples: int) -> torch.Tensor:
        """
        Generate a set of transformed counts from an NB distribution.

        Args:
            mu: Log mean parameter
            alpha: Log dispersion parameter
            n_samples: Number of samples to generate

        Returns:
            Tensor of shape (n_samples, 1) with transformed counts
        """
        library_sizes = self._sample_library_sizes(n_samples)

        # Per-sample expected counts: l * exp(mu).
        mean_counts = library_sizes * np.exp(mu)

        # Convert (mean, dispersion) to numpy's (r, p) parameterization:
        #   r = 1 / dispersion,  p = r / (r + mean)
        # which yields Var(x) = mean + dispersion * mean^2.
        r = 1.0 / np.exp(alpha)
        p = r / (r + mean_counts)

        # Vectorized draw (one NB count per library size) replaces the
        # original per-sample Python loop; numpy broadcasts scalar r over
        # the array of p values.
        counts = self.rng.negative_binomial(r, p)

        # CPM-style normalization followed by log10(x + 1).
        transformed = np.log10(1e4 * counts / library_sizes + 1)

        return torch.tensor(transformed, dtype=torch.float32).unsqueeze(-1)
|
|
|
|
class ParameterDistributions:
    """
    Container for parameter distributions learned from empirical data.

    This class loads and stores the distributions needed for realistic
    synthetic data generation.
    """

    def __init__(self, empirical_stats_file: Optional[str] = None):
        """
        Initialize parameter distributions.

        Args:
            empirical_stats_file: Path to empirical statistics file
                If None, uses default distributions
        """
        if empirical_stats_file is None:
            self._set_default_distributions()
        else:
            self._load_empirical_distributions(empirical_stats_file)

    def _load_empirical_distributions(self, filepath: str):
        """Load parameter distributions from empirical data analysis."""
        # Placeholder until empirical fitting is wired in.
        raise NotImplementedError("Empirical distribution loading not yet implemented")

    def _set_default_distributions(self):
        """Set reasonable default distributions for synthetic data."""
        # Base mean (log scale): normal prior.
        self.mu_params = {'loc': -1.0, 'scale': 2.0}

        # Dispersion (log scale): normal prior.
        self.alpha_params = {'mean': -2.0, 'std': 1.0}

        # Log fold change: spike-and-slab mixture.
        self.beta_params = {'prob_de': 0.3, 'std': 1.0}

        # Library size: log-normal, specified via mean and CV.
        self.library_params = {'mean': 10000, 'cv': 0.3}

        # Target normalization moments; beta's marginal variance under the
        # mixture is prob_de * std^2.
        beta_variance = self.beta_params['prob_de'] * self.beta_params['std'] ** 2
        self.target_stats = {
            'mu': {'mean': self.mu_params['loc'], 'std': self.mu_params['scale']},
            'alpha': {'mean': self.alpha_params['mean'], 'std': self.alpha_params['std']},
            'beta': {'mean': 0.0, 'std': beta_variance ** 0.5},
        }
|
|
|
|
def create_dataloaders(batch_size: int = 32,
                       num_workers: int = 4,
                       num_examples_per_epoch: int = 100000,
                       parameter_distributions: Optional[ParameterDistributions] = None,
                       padding_value: float = -1e9,
                       seed: Optional[int] = None,
                       persistent_workers: bool = False) -> torch.utils.data.DataLoader:
    """
    Create dataloader for synthetic NB GLM training.

    Args:
        batch_size: Batch size for training
        num_workers: Number of worker processes for data generation
        num_examples_per_epoch: Examples to generate per epoch
        parameter_distributions: Parameter distributions for generation
        padding_value: Padding value for variable-length sequences
        seed: Random seed for reproducibility
        persistent_workers: Whether to keep workers alive between epochs

    Returns:
        DataLoader for training
    """
    # Fall back to the built-in defaults when no distributions are supplied.
    dists = parameter_distributions if parameter_distributions is not None else ParameterDistributions()

    dataset = SyntheticNBGLMDataset(
        num_examples_per_epoch=num_examples_per_epoch,
        mu_loc=dists.mu_params['loc'],
        mu_scale=dists.mu_params['scale'],
        alpha_mean=dists.alpha_params['mean'],
        alpha_std=dists.alpha_params['std'],
        beta_prob_de=dists.beta_params['prob_de'],
        beta_std=dists.beta_params['std'],
        library_size_mean=dists.library_params['mean'],
        library_size_cv=dists.library_params['cv'],
        seed=seed,
    )

    # Module-level wrapper class keeps the collate_fn picklable for workers;
    # persistent_workers is only valid when workers exist.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=CollateWrapper(padding_value),
        pin_memory=True,
        persistent_workers=persistent_workers and num_workers > 0,
    )