Transformers
Safetensors
virtual_cell_distil
biology
genomics
bulk-rna-seq
patient-embedding
custom_code
Instructions to use ConvergeBio/virtual-cell-distil-bulk with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ConvergeBio/virtual-cell-distil-bulk with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ConvergeBio/virtual-cell-distil-bulk", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import sys | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| import torch | |
| from datasets import DatasetDict, load_dataset | |
| from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score | |
| from transformers import EarlyStoppingCallback, Trainer, TrainingArguments | |
| from transformers.trainer_utils import EvalPrediction | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from modeling_virtual_cell_distil import ( | |
| VirtualCellDistilConfig, | |
| VirtualCellDistilForSequenceClassification, | |
| ) | |
class BulkCollator:
    """Collate bulk-expression examples into a batch of tensors for ``Trainer``.

    Each feature must provide ``"bulk_expression"``, a numeric sequence (list,
    array or tensor). All sequences in a batch must have the same length
    because they are stacked into a single ``float32`` tensor.

    ``"labels"`` is optional. When the first feature carries it, every feature
    must, and the integer labels are returned as a ``long`` tensor. When it is
    absent (e.g. unlabeled data passed to ``Trainer.predict``), the key is
    simply omitted from the batch instead of raising ``KeyError``.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        # Trainer calls the model as ``model(**batch)``, so the expression
        # matrix is exposed under the keyword the model's forward() takes.
        # ``as_tensor`` avoids a redundant copy (and PyTorch's copy-construct
        # warning) when rows are already tensors or numpy arrays.
        batch: Dict[str, torch.Tensor] = {
            "input_ids": torch.stack([
                torch.as_tensor(f["bulk_expression"], dtype=torch.float32)
                for f in features
            ]),
        }
        if features and "labels" in features[0]:
            batch["labels"] = torch.tensor(
                [f["labels"] for f in features], dtype=torch.long
            )
        return batch
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Score classifier predictions for ``Trainer`` evaluation.

    The predicted class is the arg-max over axis 1 of the logits. If the
    model returned a tuple of outputs, only its first element is used as the
    logits.

    Returns:
        A dict with accuracy plus macro-averaged F1, precision and recall.
        Classes with no predicted or true samples contribute 0 instead of
        raising a warning (``zero_division=0``).
    """
    raw_output = eval_pred.predictions
    scores = raw_output[0] if isinstance(raw_output, tuple) else raw_output
    y_true = eval_pred.label_ids
    y_pred = np.argmax(scores, axis=1)

    macro_opts = dict(average="macro", zero_division=0)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, **macro_opts),
        "precision": precision_score(y_true, y_pred, **macro_opts),
        "recall": recall_score(y_true, y_pred, **macro_opts),
    }
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list lets
            callers such as tests or notebooks drive the script directly.

    Returns:
        The parsed options.

    Raises:
        SystemExit: if ``--dataset_path`` is missing or an option is invalid
            (standard ``argparse`` behaviour).
    """
    p = argparse.ArgumentParser(
        description="Fine-tune VirtualCellDistilForSequenceClassification "
                    "on bulk expression data.",
    )
    p.add_argument("--dataset_path", required=True,
                   help="HF dataset ID or local path with train (and optionally validation) splits")
    p.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-distil-bulk",
                   help="Pretrained checkpoint (hub ID or local directory) to start from")
    p.add_argument("--hf_token", default=None,
                   help="Hugging Face access token, forwarded when loading the dataset and model")
    p.add_argument("--output_dir", default="./vc_distil_output",
                   help="Directory for checkpoints and the final saved model")
    p.add_argument("--num_classes", type=int, default=None,
                   help="Number of output classes; when omitted the pretrained config's value is kept")
    p.add_argument("--freeze_encoder", action="store_true",
                   help="Freeze the pretrained encoder and train the classification head only")
    p.add_argument("--num_train_epochs", type=int, default=15)
    p.add_argument("--per_device_train_batch_size", type=int, default=32)
    p.add_argument("--per_device_eval_batch_size", type=int, default=32)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=0.05)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--lr_scheduler_type", default="cosine",
                   help="Scheduler name passed to TrainingArguments (e.g. linear, cosine)")
    p.add_argument("--patience", type=int, default=5,
                   help="Early-stopping patience in evaluation rounds (only used with a validation split)")
    p.add_argument("--num_workers", type=int, default=4,
                   help="DataLoader worker processes; also used as num_proc when loading a hub dataset")
    p.add_argument("--prefetch_factor", type=int, default=2,
                   help="Batches prefetched per DataLoader worker (ignored when --num_workers is 0)")
    p.add_argument("--wandb_project", default=None,
                   help="Enable Weights & Biases logging under this project name")
    p.add_argument("--run_name", default=None,
                   help="Run name passed to TrainingArguments")
    return p.parse_args(argv)
def _load_splits(args: argparse.Namespace):
    """Load the ``"train"`` split and, when present, the ``"validation"`` split.

    Args:
        args: Parsed CLI options; reads ``dataset_path``, ``num_workers`` and
            ``hf_token``.

    Returns:
        ``(train_ds, val_ds)``; ``val_ds`` is ``None`` when the dataset has no
        ``"validation"`` split.

    Raises:
        KeyError: if the dataset has no ``"train"`` split.
    """
    if os.path.isdir(args.dataset_path):
        ds = DatasetDict.load_from_disk(args.dataset_path)
    else:
        # Forward the token unchanged (empty string normalised to None).
        # The previous ``token=args.hf_token or True`` fallback makes
        # huggingface_hub raise when no token is stored locally, so public
        # datasets failed for logged-out users. ``None`` still uses a cached
        # login when one exists.
        ds = load_dataset(
            args.dataset_path,
            num_proc=args.num_workers or None,  # 0 workers -> single process
            token=args.hf_token or None,
        )
    return ds["train"], ds.get("validation")


def _load_model(args: argparse.Namespace):
    """Build the sequence classifier from ``--model_name_or_path``.

    Optionally overrides the number of classes and freezes ``model.encoder``
    so only the remaining parameters are trained.
    """
    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellDistilConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        config.num_labels = args.num_classes
        # transformers keys id2label by *int* class index (PretrainedConfig
        # converts JSON string keys back to int on load). String keys here made
        # ``config.id2label[pred]`` fail until a save/reload round trip.
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    # A resized classification head will not match the checkpoint's shapes;
    # let transformers re-initialise those weights instead of erroring.
    model = VirtualCellDistilForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
        ignore_mismatched_sizes=True,
        **hf_kwargs,
    )
    if args.freeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False
    return model


def _build_training_args(args: argparse.Namespace, has_val: bool):
    """Translate CLI options into ``TrainingArguments``.

    Per-epoch evaluation, best-checkpoint reloading and eval-loss model
    selection are enabled only when a validation split exists.
    """
    return TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,  # lower eval_loss is better
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # PyTorch's DataLoader only accepts prefetch_factor with worker processes.
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # The collator reads "bulk_expression", which is not a forward() argument.
        remove_unused_columns=False,
    )


def main():
    """Fine-tune the classifier on ``--dataset_path`` and save it to ``--output_dir``.

    With a validation split, training evaluates every epoch, stops early after
    ``--patience`` epochs without improvement and keeps the best checkpoint.
    Without one, it trains for the full epoch budget without evaluation.
    """
    args = parse_args()
    train_ds, val_ds = _load_splits(args)
    model = _load_model(args)

    if args.wandb_project:
        # Picked up by the W&B integration when report_to="wandb".
        os.environ["WANDB_PROJECT"] = args.wandb_project

    has_val = val_ds is not None
    callbacks = [EarlyStoppingCallback(early_stopping_patience=args.patience)] if has_val else []
    trainer = Trainer(
        model=model,
        args=_build_training_args(args, has_val),
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=BulkCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
if __name__ == "__main__":
    # Launch training only when executed as a script, never on import.
    main()