"""Fine-tune a VirtualCellDistil model for sequence classification on bulk expression data.

Expects a Hugging Face dataset (hub ID or local directory) with a ``train`` split
and an optional ``validation`` split. Each example must provide a
``bulk_expression`` vector and an integer ``labels`` field (see ``BulkCollator``).
"""

import argparse
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import EvalPrediction

# Make the sibling modeling module importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from modeling_virtual_cell_distil import (  # noqa: E402
    VirtualCellDistilConfig,
    VirtualCellDistilForSequenceClassification,
)

# Metadata file written by ``DatasetDict.save_to_disk``. Its presence distinguishes
# a saved DatasetDict from a plain local directory of raw data files.
_DATASET_DICT_JSON = "dataset_dict.json"


@dataclass
class BulkCollator:
    """Collate examples into a batch of model inputs.

    Each feature dict must contain ``bulk_expression`` (a numeric sequence) and
    ``labels`` (an int). Expression vectors are converted to float32 and stacked
    under the ``input_ids`` key; labels become an int64 tensor under ``labels``.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        return {
            # torch.stack requires every bulk_expression in the batch to have the same length.
            "input_ids": torch.stack([
                torch.tensor(f["bulk_expression"], dtype=torch.float32) for f in features
            ]),
            "labels": torch.tensor([f["labels"] for f in features], dtype=torch.long),
        }


def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute accuracy and macro-averaged F1 / precision / recall.

    Predictions are the argmax over axis 1 of the logits. If the model returned a
    tuple of outputs, the first element is taken as the logits. Classes with no
    predicted/true samples contribute 0 instead of raising a warning.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = eval_pred.label_ids
    preds = np.argmax(logits, axis=1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "precision": precision_score(labels, preds, average="macro", zero_division=0),
        "recall": recall_score(labels, preds, average="macro", zero_division=0),
    }


def parse_args():
    """Parse command-line options for data loading, model setup and training."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset_path", required=True,
                   help="HF dataset ID or local path with train (and optionally validation) splits")
    p.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-distil-bulk")
    p.add_argument("--hf_token", default=None)
    p.add_argument("--output_dir", default="./vc_distil_output")
    p.add_argument("--num_classes", type=int, default=None)
    p.add_argument("--freeze_encoder", action="store_true",
                   help="Freeze the pretrained encoder and train the classification head only")
    p.add_argument("--num_train_epochs", type=int, default=15)
    p.add_argument("--per_device_train_batch_size", type=int, default=32)
    p.add_argument("--per_device_eval_batch_size", type=int, default=32)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=0.05)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--lr_scheduler_type", default="cosine")
    p.add_argument("--patience", type=int, default=5)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--prefetch_factor", type=int, default=2)
    p.add_argument("--wandb_project", default=None)
    p.add_argument("--run_name", default=None)
    return p.parse_args()


def _load_splits(args):
    """Load the dataset and return ``(train_split, validation_split_or_None)``."""
    path = args.dataset_path
    if os.path.isdir(path) and os.path.isfile(os.path.join(path, _DATASET_DICT_JSON)):
        # Directory produced by DatasetDict.save_to_disk.
        ds = DatasetDict.load_from_disk(path)
    else:
        # Hub ID or a local directory of raw data files. Passing token=None lets
        # huggingface_hub use a cached login if one exists; token=True would raise
        # when no token is cached, breaking anonymous access to public datasets.
        ds = load_dataset(path, num_proc=args.num_workers or None, token=args.hf_token)
    return ds["train"], ds.get("validation")


def _build_model(args):
    """Load the config and classification model, optionally resizing the head and freezing the encoder."""
    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellDistilConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        config.num_labels = args.num_classes
        # transformers keys id2label by int (mirroring label2id's int values);
        # string keys would make in-process lookups such as id2label[pred] fail.
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    model = VirtualCellDistilForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
        # Allows the classification head to be re-initialized when num_labels differs
        # from the checkpoint.
        ignore_mismatched_sizes=True,
        **hf_kwargs,
    )

    if args.freeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False
    return model


def main():
    """Entry point: load data and model, train with the HF Trainer, and save the result."""
    args = parse_args()
    train_ds, val_ds = _load_splits(args)
    model = _build_model(args)

    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    # Evaluation, best-model selection and early stopping are only enabled when a
    # validation split exists.
    has_val = val_ds is not None
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # prefetch batches in background for CPU loading speedup; only valid with worker processes
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # Keep the raw "bulk_expression" column so BulkCollator can read it.
        remove_unused_columns=False,
    )

    callbacks = [EarlyStoppingCallback(args.patience)] if has_val else []
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=BulkCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )
    trainer.train()
    trainer.save_model(args.output_dir)


if __name__ == "__main__":
    main()