"""Fine-tune a ``VirtualCellPatientModel`` classifier with the Hugging Face ``Trainer``.

Every example carries an ``entity_id``; several examples may share one id.
During evaluation the id is appended to the logits as an extra column so that
``compute_metrics`` can report both per-view metrics and patient-level metrics
(softmax probabilities averaged over all views that share an id).
"""

import argparse
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import EvalPrediction

# Make the sibling ``modeling_virtual_cell`` module importable when this file
# is executed as a script from another working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel  # noqa: E402

# Every integer with magnitude <= 2**24 is exactly representable in float32.
# Entity ids travel through the prediction tensor as float32, so larger ids
# would be rounded and distinct patients could collide.
_MAX_EXACT_FLOAT32_INT = 2 ** 24


@dataclass
class PatientCollator:
    """Collate a list of per-example feature dicts into one batch of tensors."""

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Build the batch.

        Args:
            features: Example dicts, each providing ``input_ids``,
                ``attention_mask``, ``labels`` and ``entity_id``.

        Returns:
            A dict with ``input_ids`` (float32), ``attention_mask`` (bool),
            ``labels`` (long) and ``entity_id`` (long), stacked along a new
            leading batch dimension.
        """
        return {
            "input_ids": torch.stack(
                [torch.tensor(f["input_ids"], dtype=torch.float32) for f in features]
            ),
            "attention_mask": torch.stack(
                [torch.tensor(f["attention_mask"], dtype=torch.bool) for f in features]
            ),
            "labels": torch.tensor([f["labels"] for f in features], dtype=torch.long),
            "entity_id": torch.tensor([f["entity_id"] for f in features], dtype=torch.long),
        }


def _softmax(logits: np.ndarray) -> np.ndarray:
    """Return the row-wise softmax of a 2-D array, shifted by the row max for stability."""
    exp = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


def _patient_predictions(logits: np.ndarray, entity_ids: np.ndarray):
    """Average softmax probabilities across augmented views, one row per patient.

    Args:
        logits: Array of shape ``(num_views, num_classes)``.
        entity_ids: One id per view; views with equal ids are averaged.

    Returns:
        ``(probs, unique)`` where ``unique`` holds the distinct ids as strings
        in ``np.unique`` (lexicographic) order and ``probs[i]`` is the mean
        softmax distribution of the views belonging to ``unique[i]``.
    """
    entity_ids = np.asarray(entity_ids).astype(str)
    unique, inverse = np.unique(entity_ids, return_inverse=True)

    # Softmax once over all views, then scatter-add rows into their patient bucket.
    probs = _softmax(np.asarray(logits))
    sums = np.zeros((len(unique), probs.shape[1]), dtype=np.float64)
    np.add.at(sums, inverse, probs)
    counts = np.bincount(inverse, minlength=len(unique))
    return sums / counts[:, None], unique


def _clf_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str) -> Dict[str, float]:
    """Return accuracy and macro-averaged F1 / precision / recall, keys prefixed with ``prefix``."""
    return {
        f"{prefix}accuracy": accuracy_score(y_true, y_pred),
        f"{prefix}f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        f"{prefix}precision": precision_score(y_true, y_pred, average="macro", zero_division=0),
        f"{prefix}recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
    }


def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute per-view and patient-level classification metrics.

    ``eval_pred.predictions`` is expected in the layout produced by
    ``PatientTrainer.prediction_step``: the class logits followed by one
    trailing column holding the entity id.

    Returns:
        Metrics under the ``per_view/`` prefix (each row scored on its own)
        and the ``patient/`` prefix (one prediction per distinct entity id).
    """
    logits_with_entity = np.asarray(eval_pred.predictions)  # (N, num_classes + 1)
    logits = logits_with_entity[:, :-1]
    # Ids were stored as floats; round to nearest instead of truncating.
    entity_ids = np.rint(logits_with_entity[:, -1]).astype(np.int64)
    labels = np.asarray(eval_pred.label_ids)

    metrics = _clf_metrics(labels, np.argmax(logits, axis=1), "per_view/")

    patient_probs, unique_entities = _patient_predictions(logits, entity_ids)

    # A patient's label is the label of its first view in evaluation order.
    ids, first_rows = np.unique(entity_ids, return_index=True)
    first_row_of = dict(zip(ids.tolist(), first_rows.tolist()))
    patient_labels = np.array(
        [labels[first_row_of[int(eid)]] for eid in unique_entities]
    )

    metrics.update(
        _clf_metrics(patient_labels, np.argmax(patient_probs, axis=1), "patient/")
    )
    return metrics


class PatientTrainer(Trainer):
    """``Trainer`` that forwards entity ids to ``compute_metrics`` alongside the logits."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on ``inputs`` and return the loss it reports.

        Extra keyword arguments passed by the base ``Trainer`` are accepted
        and ignored.
        """
        # NOTE(review): during training ``inputs`` still contains "entity_id"
        # (only ``prediction_step`` removes it), so the model's forward must
        # accept that keyword -- confirm against modeling_virtual_cell.
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run the base prediction step and append the entity id as a final logits column.

        Raises:
            ValueError: If an entity id is too large to survive the float32
                round-trip exactly, which would merge distinct patients in
                the patient-level metrics.
        """
        entity_id = inputs.pop("entity_id")
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
        if logits is not None:
            if entity_id.numel() and int(entity_id.abs().max()) > _MAX_EXACT_FLOAT32_INT:
                raise ValueError(
                    "entity_id values above 2**24 cannot be carried exactly through "
                    "float32 predictions; remap entity ids to a smaller integer range."
                )
            entity_col = entity_id.float().unsqueeze(1).to(logits.device)
            logits = torch.cat([logits, entity_col], dim=1)
        return loss, logits, labels


def parse_args():
    """Parse the command-line options of the training script."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset_path", required=True,
                   help="HF dataset ID or local path with train (and optionally validation) splits")
    p.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-patient")
    p.add_argument("--hf_token", default=None)
    p.add_argument("--output_dir", default="./vc_output")
    p.add_argument("--from_scratch", action="store_true")
    p.add_argument("--freeze_embedder", action="store_true")
    p.add_argument("--num_classes", type=int, default=None)
    p.add_argument("--num_train_epochs", type=int, default=15)
    p.add_argument("--per_device_train_batch_size", type=int, default=32)
    p.add_argument("--per_device_eval_batch_size", type=int, default=32)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=0.05)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--lr_scheduler_type", default="cosine")
    p.add_argument("--patience", type=int, default=5)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--prefetch_factor", type=int, default=2)
    p.add_argument("--wandb_project", default=None)
    p.add_argument("--run_name", default=None)
    return p.parse_args()


def _load_splits(args):
    """Load the dataset splits from a local directory or from the Hub."""
    if os.path.isdir(args.dataset_path):
        return DatasetDict.load_from_disk(args.dataset_path)
    return load_dataset(
        args.dataset_path,
        num_proc=args.num_workers or None,
        # ``None`` lets the hub client fall back to a cached login if one
        # exists. Forcing ``True`` demands a stored token and so breaks
        # public datasets for users who are not logged in.
        token=args.hf_token or None,
    )


def _build_model(args):
    """Create the model, from scratch or from pretrained weights, per the CLI options."""
    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellPatientConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        config.num_classes = args.num_classes
        config.id2label = {str(i): str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    if args.from_scratch:
        model = VirtualCellPatientModel(config)
    else:
        model = VirtualCellPatientModel.from_pretrained(
            args.model_name_or_path,
            config=config,
            # Only tolerate shape mismatches when the class count was overridden.
            ignore_mismatched_sizes=args.num_classes is not None,
            **hf_kwargs,
        )

    if args.freeze_embedder:
        for param in model.patient_embedder.parameters():
            param.requires_grad = False
    return model


def _build_training_args(args, has_val: bool) -> TrainingArguments:
    """Translate the CLI options into ``TrainingArguments``."""
    return TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # prefetch batches in background for CPU loading speedup
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # Keep "entity_id" in the batches; the collator and prediction_step need it.
        remove_unused_columns=False,
    )


def main():
    """Load the data and model, train, and save the final model to ``--output_dir``.

    Evaluation, early stopping and best-model reloading are enabled only when
    the dataset has a ``validation`` split.
    """
    args = parse_args()

    ds = _load_splits(args)
    train_ds = ds["train"]
    val_ds: Optional[object] = ds.get("validation")
    has_val = val_ds is not None

    model = _build_model(args)

    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    training_args = _build_training_args(args, has_val)
    callbacks = [EarlyStoppingCallback(args.patience)] if has_val else []

    trainer = PatientTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=PatientCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )
    trainer.train()
    trainer.save_model(args.output_dir)


if __name__ == "__main__":
    main()