# Source: Hugging Face Hub file page for train.py
# Uploaded by danielle-miller-sayag ("Upload train.py with huggingface_hub"), commit 355ad37 (verified)
import argparse
import os
import sys
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import EvalPrediction

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel  # noqa: E402
@dataclass
class PatientCollator:
    """Turn a list of per-example feature dicts into one batch of tensors.

    Every feature must contain ``input_ids`` and ``attention_mask`` sequences
    (they are stacked, so all examples in a batch need the same shape) plus
    ``labels`` and ``entity_id`` values (expected to be scalars).
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        def stacked(key: str, dtype: torch.dtype) -> torch.Tensor:
            # One tensor per example, stacked along a new leading batch dimension.
            return torch.stack([torch.tensor(example[key], dtype=dtype) for example in features])

        def scalars(key: str) -> torch.Tensor:
            # Per-example values gathered into a single int64 tensor.
            return torch.tensor([example[key] for example in features], dtype=torch.long)

        batch: Dict[str, torch.Tensor] = {}
        # NOTE(review): input_ids are cast to float32, which suggests continuous
        # per-feature values rather than token ids — confirm against the model.
        batch["input_ids"] = stacked("input_ids", torch.float32)
        batch["attention_mask"] = stacked("attention_mask", torch.bool)
        batch["labels"] = scalars("labels")
        batch["entity_id"] = scalars("entity_id")
        return batch
def _patient_predictions(logits: np.ndarray, entity_ids: np.ndarray):
"""Average softmax probabilities across augmented views, one row per patient."""
entity_ids = np.asarray(entity_ids).astype(str)
unique = np.unique(entity_ids)
agg = []
for eid in unique:
views = logits[entity_ids == eid]
exp = np.exp(views - np.max(views, axis=1, keepdims=True))
agg.append(np.mean(exp / exp.sum(axis=1, keepdims=True), axis=0))
return np.array(agg), unique
def _clf_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str) -> Dict[str, float]:
    """Score hard predictions against true labels.

    Returns accuracy plus macro-averaged F1, precision and recall, each key
    prefixed with ``prefix``. ``zero_division=0`` makes undefined per-class
    scores count as 0.
    """
    macro = {"average": "macro", "zero_division": 0}
    scores = {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, **macro),
        "precision": precision_score(y_true, y_pred, **macro),
        "recall": recall_score(y_true, y_pred, **macro),
    }
    return {prefix + name: value for name, value in scores.items()}
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute per-view and per-patient classification metrics.

    ``eval_pred.predictions`` is expected to be ``(N, num_classes + 1)``: class
    logits followed by the view's entity id in the last column (as appended by
    ``PatientTrainer.prediction_step``). Per-patient predictions average the
    softmax over a patient's views; a patient's label is taken from its first view.
    """
    stacked = eval_pred.predictions
    view_logits, entity_col = stacked[:, :-1], stacked[:, -1]
    entity_ids = entity_col.astype(int)
    view_labels = eval_pred.label_ids

    results = _clf_metrics(view_labels, view_logits.argmax(axis=1), "per_view/")

    patient_probs, patients = _patient_predictions(view_logits, entity_ids)
    # ``patients`` holds string ids; convert back to int to locate each first view.
    first_view = [np.flatnonzero(entity_ids == int(patient))[0] for patient in patients]
    patient_labels = np.array([view_labels[i] for i in first_view])

    results.update(_clf_metrics(patient_labels, patient_probs.argmax(axis=1), "patient/"))
    return results
class PatientTrainer(Trainer):
    """``Trainer`` that carries each example's ``entity_id`` through evaluation.

    ``prediction_step`` appends ``entity_id`` as an extra trailing column of the
    logits so metric code can group augmented views by patient.
    """

    # Entity ids travel inside a float32 column; float32 represents every integer
    # exactly only up to 2**24, so larger ids may be rounded and collide.
    _MAX_EXACT_FLOAT32_INT = 2 ** 24

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the loss computed by the model itself (plus outputs if requested)."""
        # NOTE(review): during training ``inputs`` still contains "entity_id" (the
        # collator always emits it), so the model's forward presumably accepts or
        # ignores that kwarg — confirm against modeling_virtual_cell.
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run the base prediction step, then append ``entity_id`` as a last logits column.

        Returns the base ``(loss, logits, labels)`` triple. When logits are
        produced, they gain one float32 column holding the entity id.
        """
        # Remove entity_id so the base prediction step does not receive it.
        entity_id = inputs.pop("entity_id")
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
        if logits is not None:
            if entity_id.numel():
                largest = entity_id.abs().max().item()
                if largest > self._MAX_EXACT_FLOAT32_INT:
                    warnings.warn(
                        f"entity_id magnitude {largest} exceeds 2**24; ids are passed "
                        "through a float32 column and may be rounded, merging distinct "
                        "patients in patient-level metrics. Remap entity ids to a "
                        "smaller integer range.",
                        RuntimeWarning,
                    )
            entity_col = entity_id.to(device=logits.device, dtype=torch.float32).unsqueeze(1)
            logits = torch.cat([logits, entity_col], dim=1)
        return loss, logits, labels
def parse_args():
    """Define the training command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Data and model locations.
    add("--dataset_path", required=True,
        help="HF dataset ID or local path with train (and optionally validation) splits")
    add("--model_name_or_path", default="ConvergeBio/virtual-cell-patient")
    add("--hf_token", default=None)
    add("--output_dir", default="./vc_output")

    # Model setup switches.
    for flag in ("--from_scratch", "--freeze_embedder"):
        add(flag, action="store_true")
    add("--num_classes", type=int, default=None)

    # Optimization hyperparameters.
    for flag, kind, value in (
        ("--num_train_epochs", int, 15),
        ("--per_device_train_batch_size", int, 32),
        ("--per_device_eval_batch_size", int, 32),
        ("--learning_rate", float, 1e-4),
        ("--weight_decay", float, 0.05),
        ("--warmup_ratio", float, 0.1),
    ):
        add(flag, type=kind, default=value)
    add("--lr_scheduler_type", default="cosine")

    # Early stopping and data loading.
    for flag, value in (("--patience", 5), ("--num_workers", 4), ("--prefetch_factor", 2)):
        add(flag, type=int, default=value)

    # Experiment tracking.
    add("--wandb_project", default=None)
    add("--run_name", default=None)
    return parser.parse_args()
def _load_splits(args: argparse.Namespace):
    """Load the dataset and return ``(train_ds, val_ds)``; ``val_ds`` is None without a "validation" split."""
    if os.path.isdir(args.dataset_path):
        ds = DatasetDict.load_from_disk(args.dataset_path)
    else:
        # token=None lets huggingface_hub fall back to a cached login when one
        # exists; token=True would *require* a cached login and fail for
        # anonymous users even on public datasets.
        ds = load_dataset(
            args.dataset_path,
            num_proc=args.num_workers or None,
            token=args.hf_token,
        )
    return ds["train"], ds.get("validation")


def _build_model(args: argparse.Namespace):
    """Load (or freshly initialize) the model, optionally resizing the head and freezing the embedder."""
    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellPatientConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        config.num_classes = args.num_classes
        # PretrainedConfig documents id2label as Dict[int, str] (label2id is the inverse).
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    if args.from_scratch:
        model = VirtualCellPatientModel(config)
    else:
        model = VirtualCellPatientModel.from_pretrained(
            args.model_name_or_path,
            config=config,
            # With an overridden num_classes, checkpoint weights may no longer
            # match the configured shapes; allow loading anyway.
            ignore_mismatched_sizes=args.num_classes is not None,
            **hf_kwargs,
        )

    if args.freeze_embedder:
        for param in model.patient_embedder.parameters():
            param.requires_grad = False
    return model


def _build_training_args(args: argparse.Namespace, has_val: bool) -> TrainingArguments:
    """Translate CLI args into ``TrainingArguments``; evaluation/best-model logic needs a val split."""
    return TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # prefetch_factor only applies with worker processes; pass None otherwise.
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # Keep every dataset column (e.g. entity_id) for PatientCollator instead of
        # letting Trainer drop columns the model's forward signature does not name.
        remove_unused_columns=False,
    )


def main():
    """Parse CLI args, fine-tune with ``PatientTrainer`` and save the model to ``--output_dir``."""
    args = parse_args()
    train_ds, val_ds = _load_splits(args)
    model = _build_model(args)

    # Must be set before the wandb integration initializes.
    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    has_val = val_ds is not None
    training_args = _build_training_args(args, has_val)
    # Early stopping relies on per-epoch evaluation, so it is only enabled with a val split.
    callbacks = [EarlyStoppingCallback(args.patience)] if has_val else []

    trainer = PatientTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=PatientCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()