""" InfoCORE Batch Correction Module Removes batch effects and confounding factors from multi-omics data using Conditional Mutual Information Maximization. Based on: "Removing Biases from Molecular Representations via Information Maximization" (arxiv:2312.00718) Method: Maximize I(Z_content; Z_style | X_batch) via reweighted InfoNCE loss. This preserves biological signal while removing batch-specific variation. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, List class InfoCOREBatchCorrector(nn.Module): """ Batch effect removal via InfoCORE: reweighted contrastive learning that maximizes conditional mutual information between content and style representations given batch identity. The key insight: rather than restricting negatives to same-batch samples (as in CCL), InfoCORE uses ALL samples but reweights them by posterior batch probability. This gives both debiasing AND sufficient negative samples. """ def __init__( self, input_dim: int, content_dim: int = 128, style_dim: int = 32, temperature: float = 0.07, num_batches: int = 10, ): """ Args: input_dim: dimension of input features (e.g., gene expression) content_dim: dimension of content (biological) representation style_dim: dimension of style (batch-specific) representation temperature: softmax temperature for InfoNCE num_batches: expected number of distinct batches """ super().__init__() # Content encoder: captures biological signal self.content_encoder = nn.Sequential( nn.Linear(input_dim, 512), nn.SELU(), nn.AlphaDropout(0.1), nn.Linear(512, 256), nn.SELU(), nn.AlphaDropout(0.1), nn.Linear(256, content_dim), ) # Style encoder: captures batch-specific variation self.style_encoder = nn.Sequential( nn.Linear(input_dim, 128), nn.SELU(), nn.Linear(128, style_dim), ) # Batch classifier (for posterior weight estimation) self.batch_classifier = nn.Linear(content_dim, num_batches) self.temperature = temperature self.num_batches = num_batches def forward( self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """ Returns: z_content: batch-corrected biological representation z_style: batch-specific style representation """ z_content = self.content_encoder(x) z_style = self.style_encoder(x) return z_content, z_style def compute_infocore_loss( self, z_content: torch.Tensor, z_style: torch.Tensor, batch_ids: torch.Tensor, ) -> torch.Tensor: """ InfoCORE loss: maximize I(z_content; z_style | batch) via reweighted InfoNCE. Args: z_content: (N, D_c) content representations z_style: (N, D_s) style representations batch_ids: (N,) batch identity for each sample """ N = z_content.shape[0] device = z_content.device # Normalize representations z_content = F.normalize(z_content, dim=-1) z_style = F.normalize(z_style, dim=-1) # Compute similarity matrix sim = torch.matmul(z_content, z_style.T) / self.temperature # (N, N) # Compute batch posterior probabilities for reweighting batch_logits = self.batch_classifier(z_content) batch_probs = F.softmax(batch_logits, dim=-1) # (N, num_batches) # Reweight negatives by posterior batch probability # For each sample i, weight sample j by how likely j's batch is # given i's content: p(batch_j | z_content_i) batch_onehot = F.one_hot(batch_ids.long(), self.num_batches).float() # Weight matrix: w_{i,j} = p(batch_j | content_i) # = sum_k p(batch=k | content_i) * 1[batch_j = k] weights = torch.matmul(batch_probs, batch_onehot.T) # (N, N) # Apply reweighting: positives get weight=1, negatives are reweighted # Create mask for positives (same batch) pos_mask = (batch_ids.unsqueeze(0) == batch_ids.unsqueeze(1)).float() pos_mask.fill_diagonal_(0.0) # remove self # Reweighted InfoNCE numerator: sum of positive similarities pos_sim = (sim * pos_mask).sum(dim=-1) / (pos_mask.sum(dim=-1) + 1e-8) # Reweighted denominator: sum of ALL reweighted similarities # Exclude self from denominator neg_mask = 1.0 - torch.eye(N, device=device) weighted_sim = sim * weights * neg_mask # InfoNCE loss log_denom = torch.logsumexp( torch.cat([ pos_sim.unsqueeze(-1), weighted_sim ], dim=-1), dim=-1 ) loss = -pos_sim + log_denom return loss.mean() def correct_batch_effects( self, x: torch.Tensor ) -> torch.Tensor: """ Apply batch correction: return only the content (biological) representation for downstream use. """ z_content, _ = self.forward(x) return z_content class BatchHarmonizer(nn.Module): """ Full batch harmonization pipeline wrapping InfoCORE. Can be applied per-modality before feeding into MuLGIT. Also handles multi-omics: each modality gets its own corrector. """ def __init__( self, modality_dims: dict[str, int], content_dim: int = 128, num_batches: int = 10, ): """ Args: modality_dims: {"methylation": 20000, "mrna": 20000, ...} content_dim: output dimension for corrected features num_batches: expected number of batches """ super().__init__() self.correctors = nn.ModuleDict({ name: InfoCOREBatchCorrector(dim, content_dim, num_batches=num_batches) for name, dim in modality_dims.items() }) self.modality_dims = modality_dims def forward( self, modalities: dict[str, torch.Tensor], batch_ids: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """ Correct batch effects in all modalities. Args: modalities: {"mrna": tensor, "methylation": tensor, ...} batch_ids: (N,) batch labels (optional, only for training) """ corrected = {} for name, x in modalities.items(): if name in self.correctors: corrected[name] = self.correctors[name].correct_batch_effects(x) else: corrected[name] = x return corrected def compute_total_loss( self, modalities: dict[str, torch.Tensor], batch_ids: torch.Tensor, ) -> torch.Tensor: """Sum of InfoCORE losses across all modalities.""" total_loss = 0.0 for name in self.correctors: x = modalities[name] z_c, z_s = self.correctors[name](x) total_loss += self.correctors[name].compute_infocore_loss(z_c, z_s, batch_ids) return total_loss