| """ |
| InfoCORE Batch Correction Module |
| |
| Removes batch effects and confounding factors from multi-omics data |
| using Conditional Mutual Information Maximization. |
| |
| Based on: "Removing Biases from Molecular Representations via |
| Information Maximization" (arxiv:2312.00718) |
| |
| Method: |
| Maximize I(Z_content; Z_style | X_batch) via reweighted InfoNCE loss. |
| This preserves biological signal while removing batch-specific variation. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, List |
|
|
|
|
class InfoCOREBatchCorrector(nn.Module):
    """
    Batch effect removal via InfoCORE: reweighted contrastive learning
    that maximizes conditional mutual information between content and
    style representations given batch identity.

    The key insight: rather than restricting negatives to same-batch
    samples (as in CCL), InfoCORE uses ALL samples but reweights them
    by posterior batch probability. This gives both debiasing AND
    sufficient negative samples.

    The posterior comes from ``batch_classifier``. Its output is detached
    inside :meth:`compute_infocore_loss`, so the contrastive loss never
    trains the classifier. Train it separately, e.g. by adding
    :meth:`compute_batch_classifier_loss` to the objective. An untrained
    classifier gives essentially arbitrary reweighting.
    """

    def __init__(
        self,
        input_dim: int,
        content_dim: int = 128,
        style_dim: int = 32,
        temperature: float = 0.07,
        num_batches: int = 10,
    ):
        """
        Args:
            input_dim: dimension of input features (e.g., gene expression)
            content_dim: dimension of content (biological) representation
            style_dim: dimension of style (batch-specific) representation
            temperature: softmax temperature for InfoNCE
            num_batches: expected number of distinct batches; batch ids
                passed to the loss methods must lie in [0, num_batches)
        """
        super().__init__()

        # Content (biological) encoder: SELU + AlphaDropout keeps
        # self-normalizing activations.
        self.content_encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.SELU(),
            nn.AlphaDropout(0.1),
            nn.Linear(512, 256),
            nn.SELU(),
            nn.AlphaDropout(0.1),
            nn.Linear(256, content_dim),
        )

        # Smaller style (batch-specific) encoder.
        self.style_encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.SELU(),
            nn.Linear(128, style_dim),
        )

        # Predicts batch identity from the L2-normalized content embedding;
        # its softmax provides the posterior used to reweight negatives.
        self.batch_classifier = nn.Linear(content_dim, num_batches)

        # The InfoNCE similarity is a dot product between content and style
        # embeddings, so both must live in the same space. When the dims
        # already match, nn.Identity adds no parameters (state_dict
        # unchanged); otherwise a linear head maps style -> content space.
        if style_dim == content_dim:
            self.style_projector = nn.Identity()
        else:
            self.style_projector = nn.Linear(style_dim, content_dim)

        self.temperature = temperature
        self.num_batches = num_batches

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` with both encoders.

        Returns:
            z_content: batch-corrected biological representation, (N, content_dim)
            z_style: batch-specific style representation, (N, style_dim)
        """
        z_content = self.content_encoder(x)
        z_style = self.style_encoder(x)
        return z_content, z_style

    def compute_infocore_loss(
        self,
        z_content: torch.Tensor,
        z_style: torch.Tensor,
        batch_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        InfoCORE loss: maximize I(z_content; z_style | batch) via
        reweighted InfoNCE.

        For anchor ``i`` with positive set ``P(i) = {j != i : b_j == b_i}``:

            s_i+   = mean_{j in P(i)} sim_ij
            loss_i = -s_i+ + log(exp(s_i+) + sum_{j != i} w_ij * exp(sim_ij))

        where ``sim_ij = <c_i, s_j> / temperature`` on L2-normalized
        embeddings. ``w_ij = p(batch_j | content_i)`` comes from
        ``batch_classifier`` and is treated as a constant (detached).
        Anchors with an empty ``P(i)`` have no positive and are left out
        of the mean.

        Args:
            z_content: (N, D_c) content representations
            z_style: (N, D_s) style representations (as returned by forward)
            batch_ids: (N,) integer batch identity for each sample

        Returns:
            Scalar loss tensor. It is exactly zero, but still attached to
            the graph, when no sample shares a batch with another.
        """
        n = z_content.shape[0]
        device = z_content.device

        z_content = F.normalize(z_content, dim=-1)
        z_style = F.normalize(self.style_projector(z_style), dim=-1)

        # (N, N) logits: row i compares content_i with every style_j.
        sim = torch.matmul(z_content, z_style.T) / self.temperature

        batch_ids = batch_ids.long()
        batch_probs = F.softmax(self.batch_classifier(z_content), dim=-1)
        batch_onehot = F.one_hot(batch_ids, self.num_batches).to(batch_probs.dtype)
        # weights[i, j] = p(batch of sample j | content_i). Detached so the
        # contrastive objective cannot shrink its own denominator by
        # steering the classifier.
        weights = torch.matmul(batch_probs, batch_onehot.T).detach()

        eye = torch.eye(n, dtype=torch.bool, device=device)
        same_batch = batch_ids.unsqueeze(0) == batch_ids.unsqueeze(1)
        pos_mask = (same_batch & ~eye).to(sim.dtype)
        pos_count = pos_mask.sum(dim=-1)
        has_pos = pos_count > 0
        # Mean positive logit; clamp avoids 0/0 for anchors without positives
        # (those rows are dropped below anyway).
        pos_sim = (sim * pos_mask).sum(dim=-1) / pos_count.clamp_min(1.0)

        # Reweighting exp(sim_ij) by w_ij == adding log(w_ij) to the logit.
        # The diagonal gets -inf so it contributes exactly nothing to the
        # logsumexp (a zeroed logit would still contribute exp(0) = 1).
        neg_logits = sim + torch.log(weights.clamp_min(1e-12))
        neg_logits = neg_logits.masked_fill(eye, float("-inf"))

        log_denom = torch.logsumexp(
            torch.cat([pos_sim.unsqueeze(-1), neg_logits], dim=-1),
            dim=-1,
        )
        per_sample = log_denom - pos_sim

        if not bool(has_pos.any()):
            # No anchor has a positive: return a zero that keeps the graph
            # connected so callers can still call .backward().
            return per_sample.sum() * 0.0
        return per_sample[has_pos].mean()

    def compute_batch_classifier_loss(
        self,
        z_content: torch.Tensor,
        batch_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Cross-entropy training signal for ``batch_classifier``.

        The content embedding is detached, so this loss updates only the
        classifier, not the encoders. Inputs are L2-normalized exactly as
        in :meth:`compute_infocore_loss`, so the classifier is trained on
        the same features it is queried with.

        Args:
            z_content: (N, D_c) content representations
            batch_ids: (N,) integer batch identity for each sample

        Returns:
            Scalar mean cross-entropy loss.
        """
        features = F.normalize(z_content.detach(), dim=-1)
        logits = self.batch_classifier(features)
        return F.cross_entropy(logits, batch_ids.long())

    def correct_batch_effects(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply batch correction: return only the content (biological)
        representation for downstream use.

        Only the content encoder is run; the style branch would be
        discarded anyway.
        """
        return self.content_encoder(x)
|
|
|
|
class BatchHarmonizer(nn.Module):
    """
    Full batch harmonization pipeline wrapping InfoCORE.

    Can be applied per-modality before feeding into MuLGIT.
    Also handles multi-omics: each modality gets its own corrector.
    """

    def __init__(
        self,
        modality_dims: dict[str, int],
        content_dim: int = 128,
        num_batches: int = 10,
    ):
        """
        Args:
            modality_dims: {"methylation": 20000, "mrna": 20000, ...}
            content_dim: output dimension for corrected features
            num_batches: expected number of batches
        """
        super().__init__()
        self.correctors = nn.ModuleDict({
            name: InfoCOREBatchCorrector(dim, content_dim, num_batches=num_batches)
            for name, dim in modality_dims.items()
        })
        # Copy so later mutation of the caller's dict cannot desync this
        # record from the correctors built above.
        self.modality_dims = dict(modality_dims)

    def forward(
        self,
        modalities: dict[str, torch.Tensor],
        batch_ids: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Correct batch effects in all modalities.

        Modalities without a corrector are passed through unchanged.

        Args:
            modalities: {"mrna": tensor, "methylation": tensor, ...}
            batch_ids: (N,) batch labels; currently unused at inference and
                kept for API compatibility
        """
        return {
            name: (
                self.correctors[name].correct_batch_effects(x)
                if name in self.correctors
                else x
            )
            for name, x in modalities.items()
        }

    def compute_total_loss(
        self,
        modalities: dict[str, torch.Tensor],
        batch_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sum of InfoCORE losses across all modalities.

        Args:
            modalities: must contain an entry for every modality that has a
                corrector; extra entries are ignored
            batch_ids: (N,) batch labels shared by all modalities

        Returns:
            Scalar tensor. It is a zero tensor on ``batch_ids.device`` when
            there are no correctors.

        Raises:
            KeyError: if a modality with a corrector is absent from
                ``modalities``.
        """
        missing = [name for name in self.correctors if name not in modalities]
        if missing:
            raise KeyError(f"compute_total_loss: missing modalities {missing}")

        losses = []
        for name, corrector in self.correctors.items():
            z_c, z_s = corrector(modalities[name])
            losses.append(corrector.compute_infocore_loss(z_c, z_s, batch_ids))

        if not losses:
            # Keep the annotated return type (a tensor) even when empty.
            return torch.zeros((), device=batch_ids.device)
        return torch.stack(losses).sum()
|
|