Feature Extraction
Transformers
Safetensors
virtual_cell_patient
biology
genomics
single-cell-rna-seq
patient-classification
custom_code
Instructions to use ConvergeBio/virtual-cell-patient with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ConvergeBio/virtual-cell-patient with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="ConvergeBio/virtual-cell-patient", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ConvergeBio/virtual-cell-patient", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import sys | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| import torch | |
| from datasets import DatasetDict, load_dataset | |
| from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score | |
| from transformers import EarlyStoppingCallback, Trainer, TrainingArguments | |
| from transformers.trainer_utils import EvalPrediction | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel | |
class PatientCollator:
    """Turn a list of per-view feature dicts into one batch of tensors.

    Each feature must carry ``input_ids``, ``attention_mask``, ``labels`` and
    ``entity_id``. The two sequence fields are converted per feature and then
    stacked, so every feature in a batch must have the same length:
    ``input_ids`` becomes float32 and ``attention_mask`` becomes bool. The
    scalar ``labels`` and ``entity_id`` fields become int64 vectors.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        def stacked(key, dtype):
            # One tensor per feature, then stack along a new leading batch dim.
            return torch.stack([torch.tensor(item[key], dtype=dtype) for item in features])

        def scalars(key):
            return torch.tensor([item[key] for item in features], dtype=torch.long)

        batch = {}
        batch["input_ids"] = stacked("input_ids", torch.float32)
        batch["attention_mask"] = stacked("attention_mask", torch.bool)
        batch["labels"] = scalars("labels")
        batch["entity_id"] = scalars("entity_id")
        return batch
| def _patient_predictions(logits: np.ndarray, entity_ids: np.ndarray): | |
| """Average softmax probabilities across augmented views, one row per patient.""" | |
| entity_ids = np.asarray(entity_ids).astype(str) | |
| unique = np.unique(entity_ids) | |
| agg = [] | |
| for eid in unique: | |
| views = logits[entity_ids == eid] | |
| exp = np.exp(views - np.max(views, axis=1, keepdims=True)) | |
| agg.append(np.mean(exp / exp.sum(axis=1, keepdims=True), axis=0)) | |
| return np.array(agg), unique | |
def _clf_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str) -> Dict[str, float]:
    """Score predictions with accuracy plus macro-averaged F1, precision and recall.

    Every returned key is ``prefix`` followed by the metric name. The macro
    scores pass ``zero_division=0``, so a class that is never predicted
    contributes 0 without sklearn emitting an undefined-metric warning.
    """
    macro_kwargs = {"average": "macro", "zero_division": 0}
    scores = {"accuracy": accuracy_score(y_true, y_pred)}
    macro_scorers = (
        ("f1_macro", f1_score),
        ("precision", precision_score),
        ("recall", recall_score),
    )
    for name, scorer in macro_scorers:
        scores[name] = scorer(y_true, y_pred, **macro_kwargs)
    return {prefix + name: value for name, value in scores.items()}
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute per-view and per-patient classification metrics.

    ``eval_pred.predictions`` is the array built by
    ``PatientTrainer.prediction_step``: class logits with the entity id
    appended as the last column.

    - Per-view metrics ("per_view/" prefix) score every row independently.
    - Per-patient metrics ("patient/" prefix) average softmax probabilities
      over each patient's views. They compare against the label of that
      patient's first view; views of one patient are assumed to share a label.
    """
    logits_with_entity = eval_pred.predictions  # (N, num_classes + 1)
    logits = logits_with_entity[:, :-1]
    entity_ids = logits_with_entity[:, -1].astype(int)
    labels = eval_pred.label_ids

    metrics = _clf_metrics(labels, np.argmax(logits, axis=1), "per_view/")

    patient_preds, unique_entities = _patient_predictions(logits, entity_ids)
    # Row index of each patient's first view, built in one pass instead of
    # scanning every row once per patient.
    first_row: Dict[int, int] = {}
    for row, eid in enumerate(entity_ids.tolist()):
        first_row.setdefault(eid, row)
    # int() accepts both the str ids returned by _patient_predictions and ints.
    patient_labels = np.array([labels[first_row[int(eid)]] for eid in unique_entities])

    metrics.update(
        _clf_metrics(patient_labels, np.argmax(patient_preds, axis=1), "patient/")
    )
    return metrics
class PatientTrainer(Trainer):
    """Trainer that carries each view's ``entity_id`` through prediction.

    ``PatientCollator`` puts an ``entity_id`` tensor in every batch.
    ``prediction_step`` removes it from the model inputs and appends it as a
    trailing column of the logits. That lets ``compute_metrics`` regroup views
    by patient from ``EvalPrediction.predictions``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the loss computed by the model's own forward pass.

        NOTE(review): during training ``inputs`` still contains ``entity_id``
        (only ``prediction_step`` removes it), so the model's ``forward`` must
        accept that key. Confirm against ``modeling_virtual_cell``.
        """
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run ``Trainer.prediction_step`` and append entity ids to the logits.

        Returns:
            ``(loss, logits, labels)`` as from the base class. When ``logits``
            is not None it gains one trailing column holding the entity id.
            The combined tensor is float64, so integer ids stay exact up to
            2**53. float32 cannot represent every integer above 2**24, which
            could merge distinct patients.
        """
        entity_id = inputs.pop("entity_id")
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
        if logits is not None:
            entity_col = entity_id.to(device=logits.device, dtype=torch.float64).unsqueeze(1)
            logits = torch.cat([logits.to(torch.float64), entity_col], dim=1)
        return loss, logits, labels
def parse_args():
    """Parse the command-line options for patient-classification fine-tuning.

    Returns:
        argparse.Namespace with one attribute per option registered below.
    """
    parser = argparse.ArgumentParser()

    # Data, model source and output location.
    parser.add_argument(
        "--dataset_path",
        required=True,
        help="HF dataset ID or local path with train (and optionally validation) splits",
    )
    parser.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-patient")
    parser.add_argument("--hf_token", default=None)
    parser.add_argument("--output_dir", default="./vc_output")

    # Model-construction switches.
    for switch in ("--from_scratch", "--freeze_embedder"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--num_classes", type=int, default=None)

    # Optimisation hyper-parameters: (flag, type, default).
    numeric_options = (
        ("--num_train_epochs", int, 15),
        ("--per_device_train_batch_size", int, 32),
        ("--per_device_eval_batch_size", int, 32),
        ("--learning_rate", float, 1e-4),
        ("--weight_decay", float, 0.05),
        ("--warmup_ratio", float, 0.1),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)
    parser.add_argument("--lr_scheduler_type", default="cosine")

    # Early stopping and data-loading knobs.
    for flag, default in (("--patience", 5), ("--num_workers", 4), ("--prefetch_factor", 2)):
        parser.add_argument(flag, type=int, default=default)

    # Experiment tracking.
    parser.add_argument("--wandb_project", default=None)
    parser.add_argument("--run_name", default=None)
    return parser.parse_args()
def _load_splits(args):
    """Load ``--dataset_path`` and return ``(train_split, validation_split_or_None)``.

    A local directory is read with ``DatasetDict.load_from_disk``; anything
    else is passed to ``datasets.load_dataset``.
    """
    if os.path.isdir(args.dataset_path):
        ds = DatasetDict.load_from_disk(args.dataset_path)
    else:
        # Forward the token only when one was given. ``token=True`` makes the
        # Hub client raise when no token is cached locally, which breaks public
        # datasets for users who never logged in. ``None`` still picks up a
        # cached token when one exists.
        ds = load_dataset(
            args.dataset_path,
            num_proc=args.num_workers or None,
            token=args.hf_token or None,
        )
    return ds["train"], ds.get("validation")


def _build_model(args):
    """Build the config and model, applying the CLI model switches.

    Applies ``--num_classes``, ``--from_scratch`` and ``--freeze_embedder``.

    Raises:
        ValueError: if ``--num_classes`` is given but is not positive.
    """
    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellPatientConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        if args.num_classes < 1:
            raise ValueError(f"--num_classes must be a positive integer, got {args.num_classes}")
        config.num_classes = args.num_classes
        # transformers converts id2label keys to int when a config is loaded,
        # so use int keys here too; the in-memory config then matches a reload.
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    if args.from_scratch:
        model = VirtualCellPatientModel(config)
    else:
        # A new class count presumably changes only the classification head's
        # shape. ignore_mismatched_sizes makes from_pretrained initialise the
        # mismatched weights fresh instead of raising.
        model = VirtualCellPatientModel.from_pretrained(
            args.model_name_or_path,
            config=config,
            ignore_mismatched_sizes=args.num_classes is not None,
            **hf_kwargs,
        )

    if args.freeze_embedder:
        for param in model.patient_embedder.parameters():
            param.requires_grad = False
    return model


def _build_training_args(args, has_val):
    """Translate CLI options into ``TrainingArguments``.

    With a validation split, the run evaluates and checkpoints every epoch and
    keeps the checkpoint with the lowest ``eval_loss``. Without one it skips
    evaluation but still checkpoints every epoch.
    """
    return TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # prefetch_factor is only valid with worker processes; it prefetches
        # batches in the background to speed up CPU-side loading.
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # Keep entity_id and other custom columns for PatientCollator.
        remove_unused_columns=False,
    )


def main():
    """Fine-tune VirtualCellPatientModel on a patient-classification dataset.

    The dataset must have a ``train`` split; ``validation`` is optional. When a
    validation split exists, training early-stops on ``eval_loss``. The best
    checkpoint is restored before the final model is saved to ``--output_dir``.
    """
    args = parse_args()

    train_ds, val_ds = _load_splits(args)
    has_val = val_ds is not None

    model = _build_model(args)

    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project
    training_args = _build_training_args(args, has_val)

    # EarlyStoppingCallback requires per-epoch evaluation and a
    # metric_for_best_model, which are only configured with a validation split.
    callbacks = [EarlyStoppingCallback(args.patience)] if has_val else []
    trainer = PatientTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=PatientCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )
    trainer.train()
    # With a validation split, load_best_model_at_end has already restored the
    # best checkpoint, so that is what gets saved here.
    trainer.save_model(args.output_dir)


if __name__ == "__main__":
    main()