# MuLGIT / mulgit/trainer.py
# Uploaded by vedatonuryilmaz: "Upload mulgit/trainer.py with huggingface_hub"
# Commit 4d8ed56 (verified)
"""
HuggingFace Trainer Integration for MuLGIT
Wraps MuLGIT custom models to work with HuggingFace's Trainer API,
enabling:
- Push to Hub
- Trackio monitoring
- Mixed precision training
- Gradient accumulation
- Learning rate scheduling
- Checkpointing
Pattern: Custom model + HF Trainer via a custom forward-compatible wrapper
that accepts batch dicts and returns loss-compatible outputs.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional, Dict, Any, Tuple
import numpy as np
from transformers import (
Trainer,
TrainingArguments,
TrainerCallback,
PreTrainedModel,
PretrainedConfig,
)
import trackio
# ─── Custom HF-Style Config ─────────────────────────────────────────────────
class MuLGITHFConfig(PretrainedConfig):
    """
    Configuration object that exposes MuLGIT hyperparameters to HuggingFace.

    Holds the four ``dim_*`` sizes (one per omics input), the latent size,
    the dropout rate and the number of output classes. Any additional keyword
    arguments are handed to ``PretrainedConfig`` unchanged.
    """

    model_type = "mulgit"

    def __init__(
        self,
        dim_methylation: int = 20000,
        dim_cnv: int = 20000,
        dim_mrna: int = 20000,
        dim_mirna: int = 2000,
        latent_dim: int = 48,
        dropout: float = 0.1058,
        num_classes: int = 1,
        **kwargs,
    ):
        # Let the base class consume the generic HF options first.
        super().__init__(**kwargs)

        # Store every MuLGIT-specific hyperparameter as a config attribute.
        hyperparameters = {
            "dim_methylation": dim_methylation,
            "dim_cnv": dim_cnv,
            "dim_mrna": dim_mrna,
            "dim_mirna": dim_mirna,
            "latent_dim": latent_dim,
            "dropout": dropout,
            "num_classes": num_classes,
        }
        for field_name, field_value in hyperparameters.items():
            setattr(self, field_name, field_value)
# ─── HF-Compatible Model Wrapper ─────────────────────────────────────────────
class MuLGITForSurvival(PreTrainedModel):
    """
    ``PreTrainedModel`` wrapper that lets a MuLGIT network run under the
    HuggingFace Trainer API for survival prediction.

    The forward pass takes the four omics tensors, delegates to the wrapped
    ``MuLGITModel`` (multi-omics inputs → output dict with a ``"risk"``
    entry), and attaches a Cox loss to that dict when survival labels are
    supplied.
    """

    config_class = MuLGITHFConfig

    def __init__(self, config: MuLGITHFConfig):
        """Build the wrapped ``MuLGITModel`` from the fields of ``config``."""
        super().__init__(config)

        # Deferred import: avoids a circular dependency with ``.models``.
        from .models import MuLGITModel

        backbone_fields = (
            "dim_methylation",
            "dim_cnv",
            "dim_mrna",
            "dim_mirna",
            "latent_dim",
            "dropout",
            "num_classes",
        )
        backbone_kwargs = {name: getattr(config, name) for name in backbone_fields}
        self.mulgit = MuLGITModel(**backbone_kwargs)

        self.loss_fn = None  # set after init if needed

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        survival_times: Optional[torch.Tensor] = None,
        event_observed: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the wrapped model and, when possible, add the Cox loss.

        Args:
            methylation, cnv, mrna, mirna: omics inputs, passed by keyword to
                the wrapped model.
            survival_times, event_observed: survival labels. The loss is only
                computed when BOTH are given.
            labels: accepted for Trainer compatibility; not used here.
            **kwargs: any other batch keys; ignored.

        Returns:
            The dict produced by the wrapped model, with an extra ``"loss"``
            entry when both survival labels were provided.
        """
        outputs = self.mulgit(
            methylation=methylation,
            cnv=cnv,
            mrna=mrna,
            mirna=mirna,
        )

        # Without a complete (time, event) pair there is nothing to score.
        if survival_times is None or event_observed is None:
            return outputs

        from .losses import cox_negative_log_likelihood

        outputs["loss"] = cox_negative_log_likelihood(
            outputs["risk"], survival_times, event_observed
        )
        return outputs

    def _init_weights(self, module):
        """Weight-init hook: N(0, 0.02) weights and zero bias for ``nn.Linear``."""
        if not isinstance(module, nn.Linear):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
# ─── Custom Trainer for Multi-Omics ──────────────────────────────────────────
class MuLGITTrainer(Trainer):
    """
    HF ``Trainer`` subclass whose loss step understands MuLGIT batches.

    The only override is ``compute_loss``: the batch dict is unpacked into the
    model call, and the Cox loss is taken from the model output or, failing
    that, computed here from the batch's survival fields.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Return the Cox loss for one batch, optionally with the model outputs.

        Args:
            model: invoked as ``model(**inputs)``; the result must support
                ``.get`` and item assignment.
            inputs: batch dict. ``"survival_times"`` and ``"event_observed"``
                are read from it only when the model output has no loss.
            return_outputs: when true, return ``(loss, outputs)``.
            **kwargs: accepted and ignored.
        """
        outputs = model(**inputs)
        loss = outputs.get("loss", None)

        if loss is None:
            # The model did not score the batch itself, so do it here.
            from .losses import cox_negative_log_likelihood

            risk = outputs["risk"]
            times = inputs["survival_times"]
            events = inputs["event_observed"]
            loss = outputs["loss"] = cox_negative_log_likelihood(risk, times, events)

        if return_outputs:
            return loss, outputs
        return loss
# ─── Trackio Callback for Survival Metrics ───────────────────────────────────
class SurvivalMetricsCallback(TrainerCallback):
    """
    Trainer callback that raises Trackio alerts from training and eval logs.

    ``on_log`` watches the ``"loss"`` entry of the training logs and alerts on
    NaN values and on sudden spikes. ``on_evaluate`` watches
    ``"eval_c_index"`` and alerts on a new best value or a value below 0.5.
    The callback only sends alerts; it does not compute or log metrics itself.

    Instance state:
        eval_dataloader: stored as given; not read by this class.
        best_c_index: highest ``eval_c_index`` seen so far (starts at 0.0).
        loss_history: the non-NaN training losses seen so far, oldest first.
    """

    def __init__(self, eval_dataloader=None):
        self.eval_dataloader = eval_dataloader
        self.best_c_index = 0.0
        self.loss_history = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Inspect a training log entry and alert on NaN or spiking loss."""
        # Alert from the main process only, so distributed runs do not send
        # one copy of every alert per rank.
        if logs is None or not state.is_world_process_zero:
            return
        loss = logs.get("loss", None)
        if loss is None:
            return

        if np.isnan(loss):
            trackio.alert(
                title="NaN Loss",
                text=f"NaN loss at step {state.global_step}. "
                f"Training diverged. Reduce learning rate by 10x and restart.",
                level="ERROR",
            )
            # Keep NaN out of the history: one NaN would turn the rolling
            # mean into NaN and silence spike detection for the next 10 logs.
            return

        # Compare against the mean of the 10 PREVIOUS losses. The new loss is
        # appended only afterwards so it cannot inflate its own baseline.
        if len(self.loss_history) >= 10:
            recent_mean = np.mean(self.loss_history[-10:])
            if loss > recent_mean * 3:
                trackio.alert(
                    title="Loss Spike Detected",
                    text=f"loss={loss:.4f} at step {state.global_step} — "
                    f"3x above recent mean {recent_mean:.4f}. "
                    f"Consider reducing learning rate.",
                    level="WARN",
                )
        self.loss_history.append(loss)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Inspect eval metrics and alert on a new best or sub-0.5 C-index."""
        if metrics is None or not state.is_world_process_zero:
            return
        c_index = metrics.get("eval_c_index", None)
        if c_index is None:
            return

        if c_index > self.best_c_index:
            self.best_c_index = c_index
            trackio.alert(
                title="New Best C-index",
                text=f"C-index={c_index:.4f} at step {state.global_step}. "
                f"New best model.",
                level="INFO",
            )
        if c_index < 0.5:
            trackio.alert(
                title="C-index Below Random",
                text=f"C-index={c_index:.4f} below 0.5. Model is not learning. "
                f"Check data preprocessing or increase model capacity.",
                level="WARN",
            )
# ─── Training Script ─────────────────────────────────────────────────────────
def train_mulgit(
    model: MuLGITForSurvival,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader] = None,
    output_dir: str = "./mulgit_output",
    num_epochs: int = 100,
    learning_rate: float = 5.8e-4,
    batch_size: int = 256,
    weight_decay: float = 0.00598,
    push_to_hub: bool = True,
    hub_model_id: Optional[str] = None,
    trackio_project: str = "mulgit",
    trackio_space_id: Optional[str] = None,
    compute_metrics: Any = None,
    **kwargs,
) -> Trainer:
    """
    Train MuLGIT model with HF Trainer + Trackio monitoring.

    Only the ``.dataset`` of each DataLoader is handed to the Trainer, which
    builds its own loaders; the loaders' own batch size, sampler and
    ``collate_fn`` are therefore not reused.

    Args:
        model: MuLGITForSurvival model instance
        train_dataloader: training data
        eval_dataloader: evaluation data (optional)
        output_dir: where to save checkpoints
        num_epochs: number of training epochs
        learning_rate: from SeNMo paper (5.8e-4)
        batch_size: from SeNMo paper (256); split into a per-device batch of
            ``batch_size // 4`` with 4 gradient-accumulation steps
        weight_decay: from SeNMo paper (0.00598)
        push_to_hub: whether to push model to Hub
        hub_model_id: Hub repo ID
        trackio_project: Trackio project name
        trackio_space_id: Trackio Space for dashboard
        compute_metrics: optional function passed to the Trainer. It must
            return a dict with a ``"c_index"`` key for C-index based
            best-model selection. When omitted, the best checkpoint is
            selected on evaluation loss instead.
        **kwargs: extra ``TrainingArguments`` fields; these take precedence
            over the defaults chosen here (e.g. ``bf16=False``).

    Returns:
        The trainer, after training, saving and the optional Hub push.
    """
    import os

    # Explicit None test: truth-testing a DataLoader calls len() on it.
    has_eval = eval_dataloader is not None
    # "eval_c_index" only exists if a compute_metrics function produces it;
    # asking the Trainer to select on a missing metric makes it raise.
    select_on_c_index = has_eval and compute_metrics is not None

    arguments = dict(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        # max(1, ...) keeps tiny batch sizes from producing a 0 batch.
        per_device_train_batch_size=max(1, batch_size // 4),  # adjust for multi-GPU
        per_device_eval_batch_size=max(1, batch_size // 2),
        gradient_accumulation_steps=4,  # effective batch size = batch_size
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=0.1,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        eval_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=has_eval,
        metric_for_best_model="c_index" if select_on_c_index else "loss",
        # Higher is better for C-index, lower is better for loss.
        greater_is_better=select_on_c_index,
        # Batches carry survival labels under these keys rather than
        # "labels"; naming them lets evaluation compute a loss and pass the
        # labels on to compute_metrics.
        label_names=["survival_times", "event_observed"],
        bf16=True,
        dataloader_drop_last=True,
        report_to="trackio",
        run_name=f"mulgit_lr{learning_rate}_bs{batch_size}",
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        disable_tqdm=True,
        ddp_find_unused_parameters=False,
    )
    # Caller-supplied arguments override the defaults instead of colliding
    # with them as duplicate keyword arguments.
    arguments.update(kwargs)
    training_args = TrainingArguments(**arguments)

    # Setup Trackio
    if trackio_project:
        os.environ["TRACKIO_PROJECT"] = trackio_project
    if trackio_space_id:
        os.environ["TRACKIO_SPACE_ID"] = trackio_space_id

    # Create trainer
    trainer = MuLGITTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=eval_dataloader.dataset if has_eval else None,
        compute_metrics=compute_metrics,
        callbacks=[SurvivalMetricsCallback(eval_dataloader)],
    )

    # Train
    trainer.train()

    # Save and push
    trainer.save_model()
    if push_to_hub:
        trainer.push_to_hub()

    trackio.alert(
        title="Training Complete",
        text=f"MuLGIT training finished. "
        f"Steps: {trainer.state.global_step}, "
        f"Best metric: {trainer.state.best_metric}",
        level="INFO",
    )
    return trainer
# ─── Standalone Training Loop (without HF Trainer) ──────────────────────────
def train_mulgit_standalone(
    model: nn.Module,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader] = None,
    num_epochs: int = 100,
    learning_rate: float = 5.8e-4,
    weight_decay: float = 0.00598,
    device: str = "cuda",
    project: str = "mulgit",
    space_id: Optional[str] = None,
    push_to_hub: bool = False,
    hub_model_id: Optional[str] = None,
) -> nn.Module:
    """
    Train MuLGIT with a direct PyTorch training loop + Trackio.
    Useful when HF Trainer integration is not needed.

    The function starts its own Trackio run for ``project`` and finishes it
    before returning, including when training raises.

    Args:
        model: called as ``model(methylation, cnv, mrna, mirna)``; the result
            must expose the risk scores under ``"risk"``.
        train_dataloader: yields dict batches with the keys ``methylation``,
            ``cnv``, ``mrna``, ``mirna``, ``survival_times`` and
            ``event_observed``.
        eval_dataloader: if given, ``evaluate_survival_model`` runs after
            every epoch and its ``"c_index"`` entry is tracked.
        num_epochs: maximum number of epochs; training stops early when an
            epoch's mean loss is NaN.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        device: device the model and every batch are moved to.
        project: Trackio project name.
        space_id: optional Trackio Space ID.
        push_to_hub: upload the final weights when ``hub_model_id`` is set.
        hub_model_id: target Hub repo ID.

    Returns:
        The trained model (on ``device``).
    """
    import os

    from .losses import cox_negative_log_likelihood

    os.environ["TRACKIO_PROJECT"] = project
    if space_id:
        os.environ["TRACKIO_SPACE_ID"] = space_id

    model = model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    global_step = 0
    best_c_index = 0.0

    # trackio.log / trackio.alert need an active run, and setting the
    # environment variables alone does not create one.
    trackio.init(project=project, space_id=space_id)
    try:
        for epoch in range(num_epochs):
            # Re-enter training mode every epoch in case the end-of-epoch
            # evaluation switched the model to eval mode.
            model.train()
            epoch_losses = []
            for batch in train_dataloader:
                methylation = batch["methylation"].to(device)
                cnv = batch["cnv"].to(device)
                mrna = batch["mrna"].to(device)
                mirna = batch["mirna"].to(device)
                times = batch["survival_times"].to(device)
                events = batch["event_observed"].to(device)

                outputs = model(methylation, cnv, mrna, mirna)
                loss = cox_negative_log_likelihood(outputs["risk"], times, events)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() call forces a
                # host/device synchronisation.
                loss_value = loss.item()
                epoch_losses.append(loss_value)
                global_step += 1
                if global_step % 10 == 1:
                    trackio.log({"loss": loss_value, "step": global_step})

            avg_loss = np.mean(epoch_losses)
            trackio.log({"epoch": epoch, "epoch_loss": avg_loss})

            # Alert on anomalies
            if np.isnan(avg_loss):
                trackio.alert(
                    "NaN Loss", f"Epoch {epoch}: NaN loss. Reduce LR.", "ERROR"
                )
                break

            # Evaluate
            if eval_dataloader is not None:
                from .losses import evaluate_survival_model

                metrics = evaluate_survival_model(model, eval_dataloader, device)
                trackio.log(metrics)
                if metrics["c_index"] > best_c_index:
                    best_c_index = metrics["c_index"]
                    trackio.alert(
                        "New Best", f"C-index={best_c_index:.4f} at epoch {epoch}", "INFO"
                    )

        # Push to Hub
        if push_to_hub and hub_model_id:
            from huggingface_hub import HfApi

            torch.save(model.state_dict(), "model.pt")
            api = HfApi()
            api.upload_file(
                path_or_fileobj="model.pt",
                path_in_repo="pytorch_model.bin",
                repo_id=hub_model_id,
            )

        trackio.alert("Complete", f"Training done. Best C-index: {best_c_index:.4f}", "INFO")
    finally:
        # Always close the run so buffered logs are flushed.
        trackio.finish()
    return model