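"""
Train a DNA-only classifier on the KEGG or variant-effect datasets.

Wraps DNAClassifierModel in a PyTorch Lightning module, trains either just the
pooler/classifier head or the full DNA backbone, and logs metrics and example
predictions to Weights & Biases.
"""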
import os
import time
import argparse
import torch
import wandb
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DeepSpeedStrategy
from bioreason.models.dna_only import DNAClassifierModel
from bioreason.dataset.utils import truncate_dna
from bioreason.dataset.kegg import dna_collate_fn
from bioreason.dataset.variant_effect import clean_variant_effect_example
from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer
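# Register the custom Evo2 tokenizer before any models or tokenizers are instantiated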
register_evo2_tokenizer()


class DNAClassifierModelTrainer(pl.LightningModule):
"""
PyTorch Lightning module for training the DNA classifier.
"""
def __init__(self, args):
"""
Initialize the DNAClassifierModelTrainer.
Args:
args: Command line arguments
"""
super().__init__()
self.save_hyperparameters(args)
# Load dataset and labels
self.dataset, self.labels = self.load_dataset()
self.label2id = {label: i for i, label in enumerate(self.labels)}
# Load model
self.dna_model = DNAClassifierModel(
dna_model_name=self.hparams.dna_model_name,
cache_dir=self.hparams.cache_dir,
max_length_dna=self.hparams.max_length_dna,
num_classes=len(self.labels),
dna_is_evo2=self.hparams.dna_is_evo2,
dna_embedding_layer=self.hparams.dna_embedding_layer,
train_just_classifier=self.hparams.train_just_classifier,
)
self.dna_tokenizer = self.dna_model.dna_tokenizer
# Set the training mode for the classifier and pooler
self.dna_model.pooler.train()
self.dna_model.classifier.train()
        # Collect the DNA backbone parameters as a list so they can be iterated
        # more than once (to freeze them here and to build optimizer groups later)
        if self.hparams.dna_is_evo2:
            self.dna_model_params = list(self.dna_model.dna_model.model.parameters())
        else:
            self.dna_model_params = list(self.dna_model.dna_model.parameters())
        # Freeze the backbone when training only the classifier head
if self.hparams.train_just_classifier:
for param in self.dna_model_params:
param.requires_grad = False

    def _step(self, prefix, batch_idx, batch):
        """
        Perform a single training/validation/test step.

        Args:
            prefix: String indicating the step type ('train', 'val', or 'test')
            batch_idx: Index of the batch within the current epoch
            batch: Dictionary containing the batch data

        Returns:
            torch.Tensor: The computed loss for this batch
        """
ref_ids = batch["ref_ids"].to(self.device)
alt_ids = batch["alt_ids"].to(self.device)
ref_attention_mask = batch["ref_attention_mask"].to(self.device)
alt_attention_mask = batch["alt_attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
# Forward pass
logits = self.dna_model(ref_ids=ref_ids, alt_ids=alt_ids, ref_attention_mask=ref_attention_mask, alt_attention_mask=alt_attention_mask)
# Calculate loss
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
# Calculate accuracy
preds = torch.argmax(logits, dim=1)
acc = (preds == labels).float().mean()
        # Binary-classification metrics, treating label 1 as positive and label 0 as negative.
        # Note: these are per-batch values; the *_epoch logs below average them over
        # batches rather than pooling counts across the whole epoch.
true_positives = ((preds == 1) & (labels == 1)).float().sum()
false_positives = ((preds == 1) & (labels == 0)).float().sum()
false_negatives = ((preds == 0) & (labels == 1)).float().sum()
# Calculate precision, recall, and F1 score
precision = true_positives / (true_positives + false_positives + 1e-8) # add small epsilon to avoid division by zero
recall = true_positives / (true_positives + false_negatives + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
        # Log each metric twice: per step, and averaged per epoch (synced across devices)
        metrics = {"loss": loss, "acc": acc, "precision": precision, "recall": recall, "f1": f1}
        for name, value in metrics.items():
            self.log(
                f"{prefix}_{name}",
                value,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
            )
            self.log(
                f"{prefix}_{name}_epoch",
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
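        # Log one example prediction to W&B: on every test batch, every 1000th
        # global step during training, and every 100th batch during validation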
if (prefix == "test") or (prefix == "train" and (self.global_step % 1000 == 0)) or (prefix == "val" and (batch_idx % 100 == 0)):
wandb_logger = self.logger.experiment
pred_label = self.labels[preds[0]]
true_label = self.labels[labels[0]]
timestamp = time.time()
step_id = f"gen_{self.global_step}-{timestamp}"
wandb_logger.log(
{
step_id: wandb.Table(
columns=["timestamp", "prefix", "pred_label", "true_label"],
data=[[timestamp, prefix, pred_label, true_label]],
)
}
)
print(f"Example {prefix} {batch_idx} {self.global_step}: Prediction: {pred_label}, Target: {true_label}")
return loss

    def training_step(self, batch, batch_idx):
        """Perform a training step."""
        return self._step(prefix="train", batch_idx=batch_idx, batch=batch)

    def validation_step(self, batch, batch_idx):
        """Perform a validation step."""
        return self._step(prefix="val", batch_idx=batch_idx, batch=batch)

    def test_step(self, batch, batch_idx):
        """Perform a test step."""
        return self._step(prefix="test", batch_idx=batch_idx, batch=batch)

def configure_optimizers(self):
"""Configure optimizers and learning rate schedulers."""
# Only include parameters that require gradients
classifier_params = [
{
"params": self.dna_model.classifier.parameters(),
"lr": self.hparams.learning_rate,
},
{
"params": self.dna_model.pooler.parameters(),
"lr": self.hparams.learning_rate,
}
]
dna_model_params = [
{
"params": self.dna_model_params,
"lr": self.hparams.learning_rate * 0.1,
},
]
if self.hparams.train_just_classifier:
# Only train classifier parameters
optimizer = AdamW(
classifier_params,
weight_decay=self.hparams.weight_decay,
)
else:
# Train both DNA model and classifier with different learning rates
optimizer = AdamW(
classifier_params + dna_model_params,
weight_decay=self.hparams.weight_decay,
)
# Get total steps from trainer's estimated stepping batches
total_steps = self.trainer.estimated_stepping_batches
warmup_steps = int(0.1 * total_steps)
# Create scheduler
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
)
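        # Step the LR scheduler every optimizer step ("interval": "step"), not once per epoch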
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def load_dataset(self):
        """Load the configured dataset and derive the label set from its splits."""
        if self.hparams.dataset_type == "kegg":
            dataset = load_dataset(self.hparams.kegg_data_dir_huggingface)
        elif self.hparams.dataset_type == "variant_effect_coding":
            dataset = load_dataset("wanglab/bioR_tasks", "variant_effect_coding")
            dataset = dataset.map(clean_variant_effect_example)
        elif self.hparams.dataset_type == "variant_effect_non_snv":
            dataset = load_dataset("wanglab/bioR_tasks", "task5_variant_effect_non_snv")
            dataset = dataset.rename_column("mutated_sequence", "variant_sequence")
            dataset = dataset.map(clean_variant_effect_example)
        else:
            raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}")
        # Optionally truncate the DNA sequences on each side
        if self.hparams.truncate_dna_per_side:
            dataset = dataset.map(
                truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
            )
        # Collect the label vocabulary across all splits; sort it so the
        # label -> id mapping is deterministic across runs
        labels = []
        for split, data in dataset.items():
            labels.extend(data["answer"])
        labels = sorted(set(labels))
        print(f"Dataset:\n{dataset}\nLabels:\n{labels}\nNumber of labels: {len(labels)}")
        return dataset, labels

    def train_dataloader(self):
        """Create and return the training DataLoader."""
        if self.hparams.dataset_type not in ("kegg", "variant_effect_coding", "variant_effect_non_snv"):
            raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}")
        # Every dataset type trains on the same split with the same collate function
        train_dataset = self.dataset["train"]
        collate_fn = lambda b: dna_collate_fn(
            b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna
        )
return DataLoader(
train_dataset,
batch_size=self.hparams.batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=self.hparams.num_workers,
persistent_workers=True,
)

    def val_dataloader(self):
        """Create and return the validation DataLoader."""
        if self.hparams.dataset_type == "kegg":
            # Optionally validate on the combined test + val splits
            if self.hparams.merge_val_test_set:
                val_dataset = concatenate_datasets([self.dataset["test"], self.dataset["val"]])
            else:
                val_dataset = self.dataset["val"]
        elif self.hparams.dataset_type in ("variant_effect_coding", "variant_effect_non_snv"):
            # These datasets are validated on their test split
            val_dataset = self.dataset["test"]
        else:
            raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}")
        collate_fn = lambda b: dna_collate_fn(
            b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna
        )
return DataLoader(
val_dataset,
batch_size=self.hparams.batch_size,
shuffle=False,
collate_fn=collate_fn,
num_workers=self.hparams.num_workers,
persistent_workers=True,
)

    def test_dataloader(self):
"""Create and return the test DataLoader."""
return self.val_dataloader()


def main(args):
"""Main function to run the training process."""
# Set random seed and environment variables
pl.seed_everything(args.seed)
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
# Initialize model
model = DNAClassifierModelTrainer(args)
# Setup directories
    run_name = f"{args.wandb_project}-{args.dataset_type}-{args.dna_model_name.split('/')[-1]}"
    # Compute the timestamp once so checkpoint and output dirs always match
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    args.checkpoint_dir = f"{args.checkpoint_dir}/{run_name}-{timestamp}"
    args.output_dir = f"{args.output_dir}/{run_name}-{timestamp}"
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.checkpoint_dir, exist_ok=True)
# Setup callbacks
callbacks = [
ModelCheckpoint(
dirpath=args.checkpoint_dir,
filename=f"{run_name}-" + "{epoch:02d}-{val_loss_epoch:.4f}",
save_top_k=2,
monitor="val_acc_epoch",
mode="max",
save_last=True,
),
LearningRateMonitor(logging_interval="step"),
]
# Setup logger
is_resuming = args.ckpt_path is not None
logger = WandbLogger(
project=args.wandb_project,
entity=args.wandb_entity,
save_dir=args.log_dir,
name=run_name,
resume="allow" if is_resuming else None, # Allow resuming existing run
)
# Initialize trainer
trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        max_steps=args.max_steps,  # -1 (the default) leaves the step count unlimited
accelerator="gpu",
devices=args.num_gpus,
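        # DeepSpeed ZeRO stage 2 shards optimizer state and gradients across GPUs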
strategy=(
"ddp"
if args.strategy == "ddp"
else DeepSpeedStrategy(stage=2, offload_optimizer=False, allgather_bucket_size=5e8, reduce_bucket_size=5e8)
),
precision="bf16-mixed",
callbacks=callbacks,
logger=logger,
deterministic=False,
enable_checkpointing=True,
enable_progress_bar=True,
enable_model_summary=True,
log_every_n_steps=5,
accumulate_grad_batches=args.gradient_accumulation_steps,
gradient_clip_val=1.0,
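        # Run validation three times per training epoch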
val_check_interval=1 / 3,
)
# Train model
trainer.fit(model, ckpt_path=args.ckpt_path)
trainer.test(model, ckpt_path=args.ckpt_path if args.ckpt_path else "best")
# Save final model
final_model_path = os.path.join(args.output_dir, "final_model")
torch.save(model.dna_model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")
if __name__ == "__main__":
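    # NOTE: this pins the run to GPU 0; widen it (e.g. "0,1") before raising --num_gpus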
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
parser = argparse.ArgumentParser(description="Train DNA Classifier")
# Model parameters
parser.add_argument(
"--dna_model_name",
type=str,
default="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species",
)
parser.add_argument("--cache_dir", type=str, default="/model-weights")
parser.add_argument("--max_length_dna", type=int, default=1024)
parser.add_argument("--dna_is_evo2", type=bool, default=False)
parser.add_argument("--dna_embedding_layer", type=str, default=None)
# Training parameters
parser.add_argument("--strategy", type=str, default="ddp")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--max_epochs", type=int, default=5)
parser.add_argument("--max_steps", type=int, default=-1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--num_gpus", type=int, default=1)
parser.add_argument("--train_just_classifier", type=bool, default=True)
parser.add_argument("--dataset_type", type=str, choices=["kegg", "variant_effect_coding", "variant_effect_non_snv"], default="kegg")
parser.add_argument("--kegg_data_dir_huggingface", type=str, default="wanglab/kegg")
parser.add_argument("--truncate_dna_per_side", type=int, default=0)
# Output parameters
parser.add_argument("--output_dir", type=str, default="dna_classifier_output")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--ckpt_path", type=str, default=None)
parser.add_argument("--log_dir", type=str, default="logs")
parser.add_argument("--wandb_project", type=str, default="dna-only-nt-500m")
parser.add_argument("--wandb_entity", type=str, default="adibvafa")
parser.add_argument("--merge_val_test_set", type=bool, default=True)
# Other parameters
parser.add_argument("--seed", type=int, default=23)
args = parser.parse_args()
main(args)
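
# Example invocation (a minimal sketch; the flags shown all use the defaults above):
#   python train_dna_only.py \
#       --dataset_type kegg \
#       --batch_size 8 \
#       --max_epochs 5 \
#       --train_just_classifier True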