MuLGIT / mulgit /batch_correction.py
vedatonuryilmaz's picture
Upload mulgit/batch_correction.py with huggingface_hub
31a5e86 verified
"""
InfoCORE Batch Correction Module
Removes batch effects and confounding factors from multi-omics data
using Conditional Mutual Information Maximization.
Based on: "Removing Biases from Molecular Representations via
Information Maximization" (arxiv:2312.00718)
Method:
Maximize I(Z_content; Z_style | X_batch) via reweighted InfoNCE loss.
This preserves biological signal while removing batch-specific variation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List
class InfoCOREBatchCorrector(nn.Module):
    """
    Batch effect removal via InfoCORE: reweighted contrastive learning
    that maximizes conditional mutual information between content and
    style representations given batch identity.

    The key insight: rather than restricting negatives to same-batch
    samples (as in CCL), InfoCORE uses ALL samples but reweights them
    by posterior batch probability. This gives both debiasing AND
    sufficient negative samples.
    """

    def __init__(
        self,
        input_dim: int,
        content_dim: int = 128,
        style_dim: int = 32,
        temperature: float = 0.07,
        num_batches: int = 10,
    ):
        """
        Args:
            input_dim: dimension of input features (e.g., gene expression)
            content_dim: dimension of content (biological) representation
            style_dim: dimension of style (batch-specific) representation
            temperature: softmax temperature for InfoNCE
            num_batches: expected number of distinct batches; batch ids
                passed to ``compute_infocore_loss`` must lie in
                ``[0, num_batches)``
        """
        super().__init__()
        # Content encoder: captures biological signal
        self.content_encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.SELU(),
            nn.AlphaDropout(0.1),
            nn.Linear(512, 256),
            nn.SELU(),
            nn.AlphaDropout(0.1),
            nn.Linear(256, content_dim),
        )
        # Style encoder: captures batch-specific variation
        self.style_encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.SELU(),
            nn.Linear(128, style_dim),
        )
        # Batch classifier (for posterior weight estimation)
        self.batch_classifier = nn.Linear(content_dim, num_batches)
        self.temperature = temperature
        self.num_batches = num_batches

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` with both encoders.

        Returns:
            z_content: batch-corrected biological representation
            z_style: batch-specific style representation
        """
        z_content = self.content_encoder(x)
        z_style = self.style_encoder(x)
        return z_content, z_style

    def compute_infocore_loss(
        self,
        z_content: torch.Tensor,
        z_style: torch.Tensor,
        batch_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        InfoCORE loss: maximize I(z_content; z_style | batch) via
        reweighted InfoNCE.

        For anchor ``i`` the positive score is the mean similarity to the
        other samples of the same batch, and the loss is::

            -pos_i + log( exp(pos_i) + sum_{j != i} w_ij * exp(sim_ij) )

        with ``w_ij = p(batch_j | z_content_i)`` taken from the batch
        classifier. Anchors that have no other sample from their batch
        have no positive and are left out of the average; if no anchor
        has a positive the result is a zero that is still attached to
        the autograd graph.

        Args:
            z_content: (N, D) content representations
            z_style: (N, D) style representations; the last dimension
                must match ``z_content`` because similarities are
                computed as ``z_content @ z_style.T``
            batch_ids: (N,) batch identity for each sample, each in
                ``[0, num_batches)``

        Returns:
            Scalar loss tensor.
        """
        n = z_content.shape[0]
        device = z_content.device

        # Normalize representations so the similarity is a scaled cosine.
        z_content = F.normalize(z_content, dim=-1)
        z_style = F.normalize(z_style, dim=-1)

        # Similarity logits between every content / style pair.
        sim = torch.matmul(z_content, z_style.T) / self.temperature  # (N, N)

        # log w_ij = log p(batch_j | content_i). Using log_softmax keeps the
        # weights in log space, so they can be added to the logits; the
        # one-hot matmul just selects column batch_j for every row i.
        batch_log_probs = F.log_softmax(
            self.batch_classifier(z_content), dim=-1
        )  # (N, num_batches)
        batch_onehot = F.one_hot(
            batch_ids.long(), self.num_batches
        ).to(batch_log_probs.dtype)
        log_weights = torch.matmul(batch_log_probs, batch_onehot.T)  # (N, N)

        # Positives: other samples that share the anchor's batch.
        self_mask = torch.eye(n, dtype=torch.bool, device=device)
        same_batch = batch_ids.unsqueeze(0) == batch_ids.unsqueeze(1)
        pos_mask = (same_batch & ~self_mask).to(sim.dtype)
        pos_count = pos_mask.sum(dim=-1)
        has_pos = (pos_count > 0).to(sim.dtype)
        # clamp_min only guards the division for anchors without positives;
        # those anchors are dropped from the average below.
        pos_sim = (sim * pos_mask).sum(dim=-1) / pos_count.clamp_min(1.0)

        # Reweighted denominator terms w_ij * exp(sim_ij), expressed as
        # logits. The self pair is removed with -inf: a logit of 0 would
        # still contribute exp(0) = 1 to the sum.
        weighted_logits = (sim + log_weights).masked_fill(
            self_mask, float("-inf")
        )

        log_denom = torch.logsumexp(
            torch.cat([pos_sim.unsqueeze(-1), weighted_logits], dim=-1),
            dim=-1,
        )
        per_anchor = (log_denom - pos_sim) * has_pos

        # Mean over anchors that actually have a positive.
        return per_anchor.sum() / has_pos.sum().clamp_min(1.0)

    def correct_batch_effects(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply batch correction: return only the content (biological)
        representation for downstream use.
        """
        z_content, _ = self.forward(x)
        return z_content
class BatchHarmonizer(nn.Module):
    """
    Full batch harmonization pipeline wrapping InfoCORE.

    Can be applied per-modality before feeding into MuLGIT.
    Also handles multi-omics: each modality gets its own corrector.
    """

    def __init__(
        self,
        modality_dims: dict[str, int],
        content_dim: int = 128,
        num_batches: int = 10,
    ):
        """
        Args:
            modality_dims: {"methylation": 20000, "mrna": 20000, ...}
            content_dim: output dimension for corrected features
            num_batches: expected number of batches
        """
        super().__init__()
        # One independent corrector per modality, keyed by modality name.
        self.correctors = nn.ModuleDict({
            name: InfoCOREBatchCorrector(dim, content_dim, num_batches=num_batches)
            for name, dim in modality_dims.items()
        })
        self.modality_dims = modality_dims

    def forward(
        self,
        modalities: dict[str, torch.Tensor],
        batch_ids: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Correct batch effects in all modalities.

        Modalities that have no registered corrector are returned
        unchanged.

        Args:
            modalities: {"mrna": tensor, "methylation": tensor, ...}
            batch_ids: (N,) batch labels. Accepted for interface
                symmetry with ``compute_total_loss`` but not read here.

        Returns:
            Dict with the same keys as ``modalities``.
        """
        corrected = {}
        for name, x in modalities.items():
            if name in self.correctors:
                corrected[name] = self.correctors[name].correct_batch_effects(x)
            else:
                # No corrector for this modality: pass it through as-is.
                corrected[name] = x
        return corrected

    def compute_total_loss(
        self,
        modalities: dict[str, torch.Tensor],
        batch_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sum of InfoCORE losses across all modalities.

        Args:
            modalities: input tensor for every registered corrector
            batch_ids: (N,) batch labels shared by all modalities

        Returns:
            Scalar tensor. When no correctors are registered this is a
            zero tensor on ``batch_ids``'s device rather than a Python
            float, so callers can always treat the result as a tensor.

        Raises:
            KeyError: if ``modalities`` lacks an entry for a registered
                corrector.
        """
        total_loss = None
        for name, corrector in self.correctors.items():
            z_c, z_s = corrector(modalities[name])
            loss = corrector.compute_infocore_loss(z_c, z_s, batch_ids)
            # Out-of-place accumulation keeps each loss tensor intact.
            total_loss = loss if total_loss is None else total_loss + loss
        if total_loss is None:
            return torch.zeros((), device=batch_ids.device)
        return total_loss