CChahrour commited on
Commit
9a0f27c
·
verified ·
1 Parent(s): 1926288

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.ipynb
3
+ data/
4
+ logs/
5
+ notebooks/
6
+ output/
7
+ run.sh
8
+ wandb/
README.md CHANGED
@@ -1,7 +1,105 @@
 
 
 
 
1
  ---
2
- license: mit
3
- language:
4
- - en
5
- pipeline_tag: image-feature-extraction
6
- library_name: transformers
7
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🧚 MethFormer: A Transformer for DNA Methylation
2
+
3
+ **MethFormer** is a masked regression transformer model trained to learn local and long-range patterns in DNA methylation (5mC and 5hmC) across genomic regions. Pretrained on binned methylation data, it is designed for downstream fine-tuning on tasks such as predicting MLL binding or chromatin state.
4
+
5
  ---
6
+
7
+ ## 🚀 Overview
8
+
9
+ * **Inputs**: Binned methylation values (5mC, 5hmC) over 1024bp windows (32 bins × 2 channels)
10
+ * **Pretraining objective**: Masked methylation imputation (per-bin regression)
11
+ * **Architecture**: Transformer encoder with linear projection head
12
+ * **Downstream tasks**: MLL binding prediction, chromatin state inference, or enhancer classification
13
+
14
+ ---
15
+
16
+ ## 📁 Project Structure
17
+
18
+ ```
19
+ .
20
+ ├── config/ # config
21
+ ├── data/ # Binned methylation datasets (HuggingFace format)
22
+ ├── output/ # Pretrained models, logs, and checkpoints
23
+ ├── scripts/
24
+ │ ├── methformer.py # Model classes, data collator,
25
+ │ ├── pretrain_methformer.py # Main training script
26
+ │ └── finetune_mll.py # (optional) downstream fine-tuning
27
+ ├── requirements.txt
28
+ └── README.md
29
+ ```
30
+
31
+ ---
32
+
33
+ ## 👩‍💻 Pretraining MethFormer
34
+
35
+ ### Step 1: Prepare Dataset
36
+
37
+ Preprocess 5mC and 5hmC data into 1024bp windows, binned into 32 bins × 2 features. Save using Hugging Face's `datasets.DatasetDict` format:
38
+
39
+ ```
40
+ DatasetDict({
41
+ train: Dataset({
42
+ features: ['input_values', 'attention_mask', 'labels']
43
+ }),
44
+ validation: Dataset(...)
45
+ })
46
+ ```
47
+
48
+ ### Step 2: Run Pretraining
49
+
50
+ ```bash
51
+ python scripts/pretrain_methformer.py
52
+ ```
53
+
54
+ Options can be customized inside the script or modified for sweep tuning. This will:
55
+
56
+ * Train the model using masked regression loss
57
+ * Evaluate on a held-out chromosome (e.g., `chr8`)
58
+ * Log metrics to [Weights & Biases](https://wandb.ai)
59
+ * Save the best model checkpoint
60
+
61
+ ---
62
+
63
+ ## 📊 Metrics
64
+
65
+ * `masked_mse`: Mean squared error over unmasked positions
66
+ * `masked_mae`: Mean absolute error
67
+
68
+ ---
69
+
70
+ ## 🧪 Fine-tuning on MLL Binding
71
+
72
+ After pretraining:
73
+
74
+ 1. Replace the regression head with a scalar head for MLL prediction.
75
+ 2. Use a `Trainer` to fine-tune on log1p-transformed MLL-N RPKM values mean over 1kb regions.
76
+
77
+ See `scripts/finetune_mll.py` for an example.
78
+
79
+ ---
80
+
81
+ ## 🔍 Visualizations & Interpretability
82
+
83
+ You can run [Captum](https://captum.ai) or SHAP for:
84
+
85
+ * Per-bin attribution of 5mC/5hmC to MLL binding
86
+ * Visualizing what MethFormer attends to during fine-tuning
87
+
88
+ ---
89
+
90
+ ## 🛠️ Dependencies
91
+
92
+ Key packages:
93
+
94
+ * `transformers`
95
+ * `datasets`
96
+ * `wandb`
97
+ * `torch`
98
+ * `anndata`
99
+ * `scikit-learn`
100
+
101
+ ---
102
+
103
+ ## 🧠 Acknowledgements
104
+
105
+ * Built with inspiration from DNABERT, Grover, and vision transformers
config/pretrain_sweep_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "methformer_pretrain_sweep",
3
+ "method": "bayes",
4
+ "metric": {"name": "eval/masked_mse", "goal": "minimize"},
5
+ "early_terminate": {
6
+ "type": "hyperband",
7
+ "min_iter": 4,
8
+ "eta": 2
9
+ },
10
+ "parameters": {
11
+ "masking_ratio": {"values": [0.1, 0.15, 0.2]},
12
+ "hidden_dim": {"values": [64, 128, 256]},
13
+ "num_hidden_layers": {"values": [6, 8, 12]},
14
+ "num_attention_heads": {"values": [4, 8]},
15
+ "hidden_dropout_prob": {"values": [0.1, 0.2, 0.3]}
16
+ }
17
+ }
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets
2
+ scikit-learn
3
+ torch
4
+ transformers
5
+ wandb
scripts/feature_extract.py ADDED
File without changes
scripts/finetune_mll.py ADDED
File without changes
scripts/methformer.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset
6
+ from transformers import PreTrainedModel
7
+ from transformers.modeling_outputs import ModelOutput
8
+
9
+
10
+ class MethformerDataset(Dataset):
11
+ """
12
+ Dataset that returns masked inputs, original labels, and attention masks.
13
+ """
14
+
15
+ def __init__(
16
+ self, data_tensor, chunk_size=128, mask_value=-1.0, masking_ratio=0.15
17
+ ):
18
+ self.data = data_tensor
19
+ self.n_samples, self.n_regions, self.n_channels = self.data.shape
20
+ self.chunk_size = min(chunk_size, self.n_regions)
21
+ self.mask_value = mask_value
22
+ self.masking_ratio = masking_ratio
23
+
24
+ def __len__(self):
25
+ return self.n_samples * (self.n_regions // self.chunk_size)
26
+
27
+ def __getitem__(self, idx):
28
+ sample_idx = idx % self.n_samples
29
+ chunk_start = random.randint(0, self.n_regions - self.chunk_size)
30
+ chunk = self.data[sample_idx, chunk_start : chunk_start + self.chunk_size, :]
31
+
32
+ x = torch.tensor(chunk, dtype=torch.float32)
33
+ mask = torch.rand(self.chunk_size) < self.masking_ratio
34
+ x_masked = x.clone()
35
+ x_masked[mask] = self.mask_value
36
+
37
+ return {"inputs": x_masked, "labels": x, "attention_mask": ~mask}
38
+
39
+
40
+ class MethformerCollator:
41
+ def __init__(self, masking_ratio=0.15):
42
+ self.masking_ratio = masking_ratio
43
+
44
+ def __call__(self, batch):
45
+ def ensure_tensor(x):
46
+ if isinstance(x, torch.Tensor):
47
+ return x
48
+ return torch.tensor(x, dtype=torch.float32)
49
+
50
+ inputs = [ensure_tensor(item["inputs"]) for item in batch]
51
+ labels = [ensure_tensor(item["labels"]) for item in batch]
52
+ attention_mask = [
53
+ torch.tensor(item["attention_mask"], dtype=torch.bool) for item in batch
54
+ ]
55
+
56
+ inputs_tensor = torch.stack(inputs)
57
+ labels_tensor = torch.stack(labels)
58
+ attention_mask_tensor = torch.stack(attention_mask)
59
+
60
+ return {
61
+ "input_values": inputs_tensor,
62
+ "labels": labels_tensor,
63
+ "attention_mask": attention_mask_tensor,
64
+ }
65
+
66
+
67
+ class Methformer(PreTrainedModel):
68
+ """
69
+ Masked Transformer model for methylation data.
70
+ """
71
+
72
+ def __init__(self, config):
73
+ super().__init__(config)
74
+ self.input_dim = getattr(config, "input_dim", 2)
75
+ hidden_dim = getattr(config, "hidden_dim", 128)
76
+ num_layers = config.num_hidden_layers
77
+ num_heads = config.num_attention_heads
78
+ dropout = config.hidden_dropout_prob
79
+ max_len = getattr(config, "max_position_embeddings", 1024)
80
+
81
+ self.embed = nn.Linear(self.input_dim, hidden_dim)
82
+ self.pos_embed = nn.Parameter(torch.randn(1, max_len, hidden_dim))
83
+
84
+ encoder_layer = nn.TransformerEncoderLayer(
85
+ d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True
86
+ )
87
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
88
+ self.output_head = nn.Linear(hidden_dim, self.input_dim)
89
+
90
+ def forward(self, input_values, attention_mask, labels=None):
91
+ x = self.embed(input_values)
92
+ x = x + self.pos_embed[:, : x.size(1), :].to(x.device)
93
+
94
+ attn_mask = ~attention_mask.bool()
95
+ x = self.encoder(x, src_key_padding_mask=attn_mask)
96
+ output = self.output_head(x)
97
+
98
+ loss = None
99
+ if labels is not None:
100
+ mask = attention_mask.unsqueeze(-1).expand_as(labels)
101
+ loss_fn = nn.MSELoss()
102
+ loss = loss_fn(output[mask], labels[mask])
103
+
104
+ return ModelOutput(loss=loss, last_hidden_state=output)
105
+
106
+
107
+ class MethformerRegressor(PreTrainedModel):
108
+ """
109
+ Regression model that uses Methformer as the encoder.
110
+ """
111
+
112
+ def __init__(self, config):
113
+ super().__init__(config)
114
+ self.encoder = Methformer(config)
115
+ self.regression_head = nn.Linear(config.hidden_dim, 1)
116
+
117
+ def forward(self, input_values, attention_mask, labels=None):
118
+ x = self.encoder(input_values, attention_mask)
119
+ pooled = (x * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(
120
+ 1, keepdim=True
121
+ )
122
+ logits = self.regression_head(pooled)
123
+ loss = None
124
+ if labels is not None:
125
+ loss = F.mse_loss(logits, labels)
126
+ return {"loss": loss, "logits": logits}
scripts/pretrain_methformer.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+
4
+ import torch
5
+ import wandb
6
+ from datasets import load_from_disk
7
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
8
+ from transformers import (
9
+ EarlyStoppingCallback,
10
+ PretrainedConfig,
11
+ Trainer,
12
+ TrainingArguments,
13
+ )
14
+
15
+ from methformer import (
16
+ Methformer,
17
+ MethformerCollator,
18
+ )
19
+
20
+ run_name = f"mf_{datetime.datetime.now().strftime('%Y-%m-%d_%H%M')}"
21
+ print(f"Run name: {run_name}")
22
+
23
+ out_dir = "/home/ubuntu/project/MethFormer/output/methformer_pretrained/"
24
+ os.makedirs(out_dir, exist_ok=True)
25
+
26
+
27
+ device = (
28
+ "cuda"
29
+ if torch.cuda.is_available()
30
+ else "mps"
31
+ if torch.backends.mps.is_available()
32
+ else "cpu"
33
+ )
34
+
35
+ dataset = load_from_disk("/home/ubuntu/project/MethFormer/data/methformer_pretrain_binned")
36
+ train_dataset = dataset["train"].shuffle(seed=42)
37
+ eval_dataset = dataset["validation"]
38
+
39
+ data_collator = MethformerCollator()
40
+
41
+ config = PretrainedConfig(
42
+ input_dim=2,
43
+ hidden_dim=128,
44
+ num_hidden_layers=12,
45
+ num_attention_heads=8,
46
+ hidden_dropout_prob=0.1,
47
+ )
48
+
49
+ model = Methformer(config)
50
+ model.to(device)
51
+
52
+ training_args = TrainingArguments(
53
+ run_name=run_name,
54
+ output_dir=os.path.join(out_dir, "checkpoints"),
55
+ eval_on_start=True,
56
+ per_device_train_batch_size=128,
57
+ per_device_eval_batch_size=256,
58
+ gradient_accumulation_steps=1,
59
+ max_grad_norm=1.0,
60
+ learning_rate=1e-5,
61
+ warmup_ratio=0.05,
62
+ lr_scheduler_type="cosine",
63
+ num_train_epochs=20,
64
+ logging_dir=os.path.join(out_dir, "logs"),
65
+ save_strategy="steps",
66
+ save_total_limit=1,
67
+ eval_strategy="steps",
68
+ logging_steps=1000,
69
+ eval_steps=1000,
70
+ save_steps=5000,
71
+ metric_for_best_model="masked_mse",
72
+ greater_is_better=False,
73
+ report_to="wandb",
74
+ disable_tqdm=False,
75
+ dataloader_num_workers=8,
76
+ remove_unused_columns=False,
77
+ fp16=not torch.backends.mps.is_available(),
78
+ load_best_model_at_end=True,
79
+ seed=42,
80
+ )
81
+
82
+
83
+ def compute_metrics(eval_preds):
84
+ logits, labels = eval_preds
85
+ logits = torch.tensor(logits)
86
+ labels = torch.tensor(labels)
87
+ mask = labels != -1.0
88
+ masked_logits = logits[mask].cpu.numpy()
89
+ masked_labels = labels[mask].cpu.numpy()
90
+ mse = mean_squared_error(masked_labels, masked_logits)
91
+ mae = mean_absolute_error(masked_labels, masked_logits)
92
+ return {
93
+ "masked_mse": mse,
94
+ "masked_mae": mae,
95
+ }
96
+
97
+
98
+ trainer = Trainer(
99
+ model=model,
100
+ args=training_args,
101
+ train_dataset=train_dataset,
102
+ eval_dataset=eval_dataset,
103
+ compute_metrics=compute_metrics,
104
+ data_collator=data_collator,
105
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
106
+ )
107
+
108
+ print("Starting training...")
109
+
110
+ wandb.init(
111
+ group="methformer_pretrain",
112
+ job_type="pretrain_full",
113
+ name=run_name,
114
+ dir=out_dir,
115
+ reinit="finish_previous",
116
+ config=config.to_dict(),
117
+ )
118
+
119
+ trainer.train()
120
+ print("Training complete. Saving model...")
121
+
122
+ save_path = f"{out_dir}/model"
123
+ os.makedirs(save_path, exist_ok=True)
124
+ trainer.save_model(save_path)
125
+ model.config.save_pretrained(save_path)
126
+ print(f"Model saved to {save_path}")
127
+
128
+ wandb.finish()
scripts/pretrain_sweep.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+ import os
4
+
5
+ import torch
6
+ import wandb
7
+ from datasets import load_from_disk
8
+ from transformers import (
9
+ EarlyStoppingCallback,
10
+ PretrainedConfig,
11
+ Trainer,
12
+ TrainingArguments,
13
+ )
14
+
15
+ from methformer import Methformer, MethformerCollator
16
+
17
+
18
+ def compute_metrics(eval_preds):
19
+ logits, labels = eval_preds
20
+ logits = torch.tensor(logits)
21
+ labels = torch.tensor(labels)
22
+
23
+ # Only evaluate masked positions (label == -1.0 was masked during input)
24
+ mask = labels != -1.0
25
+
26
+ masked_mse = torch.mean((logits[mask] - labels[mask]) ** 2).item()
27
+ masked_mae = torch.mean(torch.abs(logits[mask] - labels[mask])).item()
28
+
29
+ return {
30
+ "masked_mse": masked_mse,
31
+ "masked_mae": masked_mae,
32
+ }
33
+
34
+ device = (
35
+ "cuda"
36
+ if torch.cuda.is_available()
37
+ else "mps"
38
+ if torch.backends.mps.is_available()
39
+ else "cpu"
40
+ )
41
+
42
+ dataset = load_from_disk("/home/ubuntu/project/MethFormer/data/methformer_pretrain_binned")
43
+ train_dataset = dataset["train"].shuffle(seed=42)
44
+ eval_dataset = dataset["validation"]
45
+
46
+
47
+ def train():
48
+ wandb.init(
49
+ group="methformer_pretrain_sweep",
50
+ job_type="pretrain_sweep",
51
+ name=f"mf_{datetime.datetime.now().strftime('%Y-%m-%d_%H%M')}",
52
+ dir="/home/ubuntu/project/MethFormer/output/methformer_pretrain_sweep",
53
+ reinit="finish_previous",
54
+ )
55
+ config = wandb.config
56
+
57
+ run_name = f"mf_{datetime.datetime.now().strftime('%Y-%m-%d_%H%M')}"
58
+ out_dir = f"/home/ubuntu/project/MethFormer/output/methformer_pretrain_sweep/{run_name}"
59
+ os.makedirs(out_dir, exist_ok=True)
60
+
61
+ model_config = PretrainedConfig(
62
+ input_dim=2,
63
+ hidden_dim=config.hidden_dim,
64
+ num_hidden_layers=config.num_hidden_layers,
65
+ num_attention_heads=config.num_attention_heads,
66
+ hidden_dropout_prob=config.hidden_dropout_prob,
67
+ )
68
+
69
+ model = Methformer(model_config)
70
+ model.to(device)
71
+
72
+ training_args = TrainingArguments(
73
+ run_name=run_name,
74
+ output_dir=os.path.join(out_dir, "checkpoints"),
75
+ eval_on_start=True,
76
+ per_device_train_batch_size=128,
77
+ per_device_eval_batch_size=256,
78
+ gradient_accumulation_steps=1,
79
+ max_grad_norm=1.0,
80
+ learning_rate=1e-5,
81
+ warmup_ratio=0.05,
82
+ lr_scheduler_type="cosine",
83
+ num_train_epochs=20,
84
+ logging_dir=os.path.join(out_dir, "logs"),
85
+ save_strategy="steps",
86
+ save_total_limit=1,
87
+ eval_strategy="steps",
88
+ logging_steps=500,
89
+ eval_steps=5000,
90
+ save_steps=5000,
91
+ metric_for_best_model="masked_mse",
92
+ greater_is_better=False,
93
+ report_to="wandb",
94
+ disable_tqdm=False,
95
+ dataloader_num_workers=8,
96
+ remove_unused_columns=False,
97
+ fp16=not torch.backends.mps.is_available(),
98
+ load_best_model_at_end=True,
99
+ seed=42,
100
+ )
101
+
102
+ trainer = Trainer(
103
+ model=model,
104
+ args=training_args,
105
+ train_dataset=train_dataset,
106
+ eval_dataset=eval_dataset,
107
+ compute_metrics=compute_metrics,
108
+ data_collator=MethformerCollator(masking_ratio=config.masking_ratio),
109
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
110
+ )
111
+
112
+ trainer.train()
113
+
114
+ # Save the final model
115
+ model.save_pretrained(os.path.join(out_dir, "model"))
116
+ model.config.save_pretrained(os.path.join(out_dir, "model"))
117
+
118
+
119
+ with open("/home/ubuntu/project/MethFormer/config/pretrain_sweep_config.json", "r") as f:
120
+ sweep_config = json.load(f)
121
+
122
+ sweep_id = wandb.sweep(
123
+ sweep=sweep_config,
124
+ project="MethFormer",
125
+ )
126
+
127
+ wandb.agent(sweep_id, train, count=20)
128
+
129
+ # After the sweep
130
+ api = wandb.Api()
131
+
132
+ sweep_path = f"{wandb.run.entity}/{wandb.run.project}/{sweep_id}"
133
+ sweep = api.sweep(sweep_path)
134
+
135
+ # Filter only finished runs with masked_r2
136
+ runs = [
137
+ run for run in sweep.runs if run.state == "finished" and "masked_r2" in run.summary
138
+ ]
139
+
140
+ # Find best run by highest masked_r2
141
+ best_run = max(runs, key=lambda r: r.summary["masked_r2"])
142
+
143
+ # Save best config
144
+ best_config = {k: v for k, v in best_run.config.items() if not k.startswith("_")}
145
+ with open("/home/ubuntu/project/MethFormer/config/best_config.json", "w") as f:
146
+ json.dump(best_config, f, indent=2)
147
+
148
+ print(f"Best run ID: {best_run.id}")
149
+ print(f"Best masked_r2: {best_run.summary['masked_r2']}")