import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput


class MethformerDataset(Dataset):
    """
    Dataset that returns masked inputs, original labels, and attention masks.

    Each item is a random contiguous chunk of regions from one sample;
    ``attention_mask`` is True at visible (unmasked) positions.
    """

    def __init__(
        self, data_tensor, chunk_size=128, mask_value=-1.0, masking_ratio=0.15
    ):
        self.data = data_tensor
        self.n_samples, self.n_regions, self.n_channels = self.data.shape
        self.chunk_size = min(chunk_size, self.n_regions)
        self.mask_value = mask_value
        self.masking_ratio = masking_ratio

    def __len__(self):
        return self.n_samples * (self.n_regions // self.chunk_size)

    def __getitem__(self, idx):
        sample_idx = idx % self.n_samples
        chunk_start = random.randint(0, self.n_regions - self.chunk_size)
        chunk = self.data[sample_idx, chunk_start : chunk_start + self.chunk_size, :]
        # The slice is already a tensor; clone/detach it rather than calling
        # torch.tensor() on a tensor, which copies and raises a warning.
        x = chunk.clone().detach().to(torch.float32)

        # Mask whole positions (all channels) with probability masking_ratio.
        mask = torch.rand(self.chunk_size) < self.masking_ratio
        x_masked = x.clone()
        x_masked[mask] = self.mask_value

        return {"inputs": x_masked, "labels": x, "attention_mask": ~mask}


class MethformerCollator:
    """Stacks per-item tensors into the batch layout the model expects."""

    def __init__(self, masking_ratio=0.15):
        # Masking already happens in MethformerDataset; the ratio is kept
        # here only for API compatibility and is not used.
        self.masking_ratio = masking_ratio

    def __call__(self, batch):
        def ensure_tensor(x, dtype=torch.float32):
            if isinstance(x, torch.Tensor):
                return x.to(dtype)
            return torch.tensor(x, dtype=dtype)

        inputs = torch.stack([ensure_tensor(item["inputs"]) for item in batch])
        labels = torch.stack([ensure_tensor(item["labels"]) for item in batch])
        attention_mask = torch.stack(
            [ensure_tensor(item["attention_mask"], torch.bool) for item in batch]
        )
        return {
            "input_values": inputs,
            "labels": labels,
            "attention_mask": attention_mask,
        }


class Methformer(PreTrainedModel):
    """
    Masked Transformer model for methylation data.
    """

    def __init__(self, config):
        super().__init__(config)
        self.input_dim = getattr(config, "input_dim", 2)
        hidden_dim = getattr(config, "hidden_dim", 128)
        num_layers = config.num_hidden_layers
        num_heads = config.num_attention_heads
        dropout = config.hidden_dropout_prob
        max_len = getattr(config, "max_position_embeddings", 1024)

        self.embed = nn.Linear(self.input_dim, hidden_dim)
        # Learned positional embeddings, sliced to the input length in forward().
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, hidden_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(hidden_dim, self.input_dim)

    def forward(self, input_values, attention_mask, labels=None):
        x = self.embed(input_values)
        x = x + self.pos_embed[:, : x.size(1), :]
        # src_key_padding_mask expects True at positions to *ignore*, so the
        # visible-position mask is inverted: masked positions cannot serve as keys.
        key_padding_mask = ~attention_mask.bool()
        hidden = self.encoder(x, src_key_padding_mask=key_padding_mask)
        reconstruction = self.output_head(hidden)

        loss = None
        if labels is not None:
            # Masked-modeling objective: score the reconstruction only at the
            # *masked* positions (attention_mask is True where visible).
            masked = key_padding_mask.unsqueeze(-1).expand_as(labels)
            loss = F.mse_loss(reconstruction[masked], labels[masked])

        # Expose the encoder states as last_hidden_state so downstream heads
        # can pool over hidden_dim; the reconstruction goes out as logits.
        return ModelOutput(loss=loss, logits=reconstruction, last_hidden_state=hidden)


class MethformerRegressor(PreTrainedModel):
    """
    Regression model that uses Methformer as the encoder.
""" def __init__(self, config): super().__init__(config) self.encoder = Methformer(config) self.regression_head = nn.Linear(config.hidden_dim, 1) def forward(self, input_values, attention_mask, labels=None): x = self.encoder(input_values, attention_mask) pooled = (x * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum( 1, keepdim=True ) logits = self.regression_head(pooled) loss = None if labels is not None: loss = F.mse_loss(logits, labels) return {"loss": loss, "logits": logits}