# Methformer/scripts/methformer.py
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

class MethformerDataset(Dataset):
    """
    Dataset that returns masked inputs, original labels, and attention masks.
    """

    def __init__(
        self, data_tensor, chunk_size=128, mask_value=-1.0, masking_ratio=0.15
    ):
        self.data = data_tensor
        self.n_samples, self.n_regions, self.n_channels = self.data.shape
        self.chunk_size = min(chunk_size, self.n_regions)
        self.mask_value = mask_value
        self.masking_ratio = masking_ratio

    def __len__(self):
        # One chunk per sample for each non-overlapping chunk window.
        return self.n_samples * (self.n_regions // self.chunk_size)

    def __getitem__(self, idx):
        sample_idx = idx % self.n_samples
        # Chunks are drawn at random start positions rather than on a fixed grid.
        chunk_start = random.randint(0, self.n_regions - self.chunk_size)
        chunk = self.data[sample_idx, chunk_start : chunk_start + self.chunk_size, :]
        # as_tensor handles both torch tensors and array-like inputs without a copy warning.
        x = torch.as_tensor(chunk, dtype=torch.float32)
        # Mask whole regions (all channels) with probability masking_ratio.
        mask = torch.rand(self.chunk_size) < self.masking_ratio
        x_masked = x.clone()
        x_masked[mask] = self.mask_value
        # attention_mask is True for visible (unmasked) positions.
        return {"inputs": x_masked, "labels": x, "attention_mask": ~mask}

class MethformerCollator:
    def __init__(self, masking_ratio=0.15):
        # Masking is already applied in MethformerDataset; the ratio is kept for reference only.
        self.masking_ratio = masking_ratio

    def __call__(self, batch):
        def ensure_tensor(x, dtype=torch.float32):
            # Cast existing tensors instead of re-wrapping them to avoid copy warnings.
            if isinstance(x, torch.Tensor):
                return x.to(dtype)
            return torch.tensor(x, dtype=dtype)

        inputs = [ensure_tensor(item["inputs"]) for item in batch]
        labels = [ensure_tensor(item["labels"]) for item in batch]
        attention_mask = [
            ensure_tensor(item["attention_mask"], dtype=torch.bool) for item in batch
        ]
        # Stack into batched tensors under the argument names Methformer.forward expects.
        return {
            "input_values": torch.stack(inputs),
            "labels": torch.stack(labels),
            "attention_mask": torch.stack(attention_mask),
        }
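
# Usage sketch (assumed wiring, not taken from this repo): the collator is meant to be
# used as a DataLoader collate_fn (or Trainer data_collator) so batches arrive under the
# key names the models expect ("input_values", "labels", "attention_mask").
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(ds, batch_size=4, collate_fn=MethformerCollator())
#   batch = next(iter(loader))
#   batch["input_values"].shape                  # torch.Size([4, 128, 2])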

class Methformer(PreTrainedModel):
    """
    Masked Transformer model for methylation data.
    """

    def __init__(self, config):
        super().__init__(config)
        self.input_dim = getattr(config, "input_dim", 2)
        hidden_dim = getattr(config, "hidden_dim", 128)
        num_layers = config.num_hidden_layers
        num_heads = config.num_attention_heads
        dropout = config.hidden_dropout_prob
        max_len = getattr(config, "max_position_embeddings", 1024)
        # Project per-region input channels into the hidden dimension.
        self.embed = nn.Linear(self.input_dim, hidden_dim)
        # Learned positional embeddings, truncated to the chunk length at forward time.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, hidden_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Reconstruction head mapping hidden states back to the input channels.
        self.output_head = nn.Linear(hidden_dim, self.input_dim)

    def forward(self, input_values, attention_mask, labels=None):
        x = self.embed(input_values)
        x = x + self.pos_embed[:, : x.size(1), :].to(x.device)
        # TransformerEncoder expects True at positions it should ignore, so invert the mask.
        key_padding_mask = ~attention_mask.bool()
        hidden = self.encoder(x, src_key_padding_mask=key_padding_mask)
        reconstruction = self.output_head(hidden)
        loss = None
        if labels is not None:
            # MSE over the positions flagged True in attention_mask (the visible positions).
            mask = attention_mask.unsqueeze(-1).expand_as(labels).bool()
            loss = F.mse_loss(reconstruction[mask], labels[mask])
        # Expose the encoder hidden states as last_hidden_state so downstream heads
        # (e.g. MethformerRegressor) can pool them; the reconstruction goes out as logits.
        return ModelOutput(loss=loss, last_hidden_state=hidden, logits=reconstruction)
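
# Usage sketch: Methformer reads standard config fields plus the custom input_dim /
# hidden_dim attributes; the values below are illustrative placeholders, not the
# settings this checkpoint was trained with.
#
#   from transformers import PretrainedConfig
#   cfg = PretrainedConfig(
#       input_dim=2, hidden_dim=128, num_hidden_layers=4, num_attention_heads=4,
#       hidden_dropout_prob=0.1, max_position_embeddings=1024,
#   )
#   model = Methformer(cfg)
#   out = model(batch["input_values"], batch["attention_mask"], labels=batch["labels"])
#   out.loss                                     # masked-reconstruction MSE
#   out.last_hidden_state.shape                  # torch.Size([4, 128, 128])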

class MethformerRegressor(PreTrainedModel):
    """
    Regression model that uses Methformer as the encoder.
    """

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Methformer(config)
        self.regression_head = nn.Linear(getattr(config, "hidden_dim", 128), 1)

    def forward(self, input_values, attention_mask, labels=None):
        # Encode and take the hidden states (no reconstruction labels needed here).
        hidden = self.encoder(input_values, attention_mask).last_hidden_state
        # Mean-pool over the visible positions only.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        logits = self.regression_head(pooled)
        loss = None
        if labels is not None:
            loss = F.mse_loss(logits.view(-1), labels.float().view(-1))
        return {"loss": loss, "logits": logits}