# Uploaded by danielle-miller-sayag to the Hugging Face Hub with huggingface_hub
# ("Upload train.py with huggingface_hub", commit 939f41f, verified).
import argparse
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import EvalPrediction
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from modeling_virtual_cell_distil import (
VirtualCellDistilConfig,
VirtualCellDistilForSequenceClassification,
)
@dataclass
class BulkCollator:
    """Turn a list of dataset rows into one training batch.

    Every row must carry a ``"bulk_expression"`` sequence of numbers and an
    integer-like ``"labels"`` value. All expression sequences are stacked, so
    they must share one length.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        # One float32 vector per row, stacked into a (batch, n_features) tensor.
        expression_vectors = [
            torch.tensor(row["bulk_expression"], dtype=torch.float32)
            for row in features
        ]
        stacked_expression = torch.stack(expression_vectors)

        class_ids = torch.tensor(
            [row["labels"] for row in features], dtype=torch.long
        )

        # NOTE(review): the float expression matrix is exposed under the key
        # "input_ids", presumably because the model's forward() takes that
        # argument name -- confirm against modeling_virtual_cell_distil.
        batch = {"input_ids": stacked_expression, "labels": class_ids}
        return batch
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Score classification predictions for the Trainer's evaluation loop.

    If ``eval_pred.predictions`` is a tuple, its first element is taken as the
    logits. The predicted class is the argmax over axis 1.

    Returns:
        Accuracy, plus macro-averaged F1, precision and recall. Undefined
        per-class scores are reported as 0.
    """
    raw_output = eval_pred.predictions
    scores = raw_output[0] if isinstance(raw_output, tuple) else raw_output
    y_true = eval_pred.label_ids
    y_pred = np.argmax(scores, axis=1)

    # Shared options for the three macro-averaged metrics.
    macro_opts = {"average": "macro", "zero_division": 0}
    metrics = {"accuracy": accuracy_score(y_true, y_pred)}
    metrics["f1_macro"] = f1_score(y_true, y_pred, **macro_opts)
    metrics["precision"] = precision_score(y_true, y_pred, **macro_opts)
    metrics["recall"] = recall_score(y_true, y_pred, **macro_opts)
    return metrics
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the training script's command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default, as used by ``main()``) parses ``sys.argv[1:]``. Passing an
            explicit list allows programmatic invocation and testing.

    Returns:
        The parsed options as an ``argparse.Namespace``.

    Raises:
        SystemExit: via ``ArgumentParser.error`` when a required option is
            missing or a value is invalid (e.g. ``--num_classes`` below 1).
    """
    p = argparse.ArgumentParser(
        description="Fine-tune a VirtualCellDistil model for bulk-expression classification."
    )
    p.add_argument("--dataset_path", required=True,
                   help="HF dataset ID or local path with train (and optionally validation) splits")
    p.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-distil-bulk")
    p.add_argument("--hf_token", default=None,
                   help="Hugging Face access token used for model and dataset downloads")
    p.add_argument("--output_dir", default="./vc_distil_output")
    p.add_argument("--num_classes", type=int, default=None,
                   help="Number of output classes; overrides the pretrained config when given")
    p.add_argument("--freeze_encoder", action="store_true",
                   help="Freeze the pretrained encoder and train the classification head only")
    p.add_argument("--num_train_epochs", type=int, default=15)
    p.add_argument("--per_device_train_batch_size", type=int, default=32)
    p.add_argument("--per_device_eval_batch_size", type=int, default=32)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=0.05)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--lr_scheduler_type", default="cosine")
    p.add_argument("--patience", type=int, default=5,
                   help="Early-stopping patience in epochs; only used when a validation split exists")
    p.add_argument("--num_workers", type=int, default=4,
                   help="DataLoader worker processes; also used as num_proc when loading a Hub dataset")
    p.add_argument("--prefetch_factor", type=int, default=2)
    p.add_argument("--wandb_project", default=None)
    p.add_argument("--run_name", default=None)

    args = p.parse_args(argv)
    # Zero or negative class counts would produce an empty/invalid classifier head.
    if args.num_classes is not None and args.num_classes < 1:
        p.error("--num_classes must be a positive integer")
    return args
def _load_splits(args: argparse.Namespace):
    """Return ``(train_ds, val_ds)``; ``val_ds`` is None without a "validation" split."""
    if os.path.isdir(args.dataset_path):
        # A local directory is read as a DatasetDict previously written with save_to_disk().
        ds = DatasetDict.load_from_disk(args.dataset_path)
    else:
        # token=None lets huggingface_hub fall back to a cached login token when one
        # exists. The previous `token=... or True` *required* a cached token and raised
        # for public datasets on machines that never ran `huggingface-cli login`.
        ds = load_dataset(
            args.dataset_path,
            num_proc=args.num_workers or None,
            token=args.hf_token,
        )
    return ds["train"], ds.get("validation")


def _load_model(args: argparse.Namespace):
    """Load the config and classifier, optionally resizing the head and freezing the encoder."""
    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellDistilConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        config.num_labels = args.num_classes
        # Integer ids as id2label keys, the convention PretrainedConfig uses (and
        # restores when reloading config.json). This assumes VirtualCellDistilConfig
        # subclasses PretrainedConfig -- confirm in modeling_virtual_cell_distil.
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    model = VirtualCellDistilForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
        # The classification head may differ in size from the checkpoint's head.
        ignore_mismatched_sizes=True,
        **hf_kwargs,
    )
    if args.freeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False
    return model


def main():
    """Parse CLI options, load data and model, fine-tune, and save the result.

    When the dataset has a "validation" split:
      - evaluation runs every epoch;
      - early stopping monitors ``eval_loss``;
      - the best checkpoint is loaded back before the final save.

    Without that split, training runs for the full schedule with no evaluation.
    """
    args = parse_args()
    train_ds, val_ds = _load_splits(args)
    model = _load_model(args)

    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    has_val = val_ds is not None
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        # Eval and save both run per epoch, as load_best_model_at_end requires.
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # prefetch_factor is only valid with worker processes; it prefetches batches in the background.
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # BulkCollator reads the raw "bulk_expression" column, so keep all columns.
        remove_unused_columns=False,
    )

    callbacks = [EarlyStoppingCallback(early_stopping_patience=args.patience)] if has_val else []
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=BulkCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )
    trainer.train()
    trainer.save_model(args.output_dir)


if __name__ == "__main__":
    main()