| """ |
| HuggingFace Trainer Integration for MuLGIT |
| |
| Wraps MuLGIT custom models to work with HuggingFace's Trainer API, |
| enabling: |
| - Push to Hub |
| - Trackio monitoring |
| - Mixed precision training |
| - Gradient accumulation |
| - Learning rate scheduling |
| - Checkpointing |
| |
| Pattern: Custom model + HF Trainer via a custom forward-compatible wrapper |
| that accepts batch dicts and returns loss-compatible outputs. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from typing import Optional, Dict, Any, Tuple |
| import numpy as np |
|
|
| from transformers import ( |
| Trainer, |
| TrainingArguments, |
| TrainerCallback, |
| PreTrainedModel, |
| PretrainedConfig, |
| ) |
| import trackio |
|
|
|
|
| |
|
|
class MuLGITHFConfig(PretrainedConfig):
    """Configuration container for MuLGIT models in the HuggingFace format.

    Holds one ``dim_*`` setting per omics modality plus the latent size,
    dropout rate and number of classes. Any additional keyword arguments
    are handed to ``PretrainedConfig`` unchanged.
    """

    model_type = "mulgit"

    def __init__(
        self,
        dim_methylation: int = 20000,
        dim_cnv: int = 20000,
        dim_mrna: int = 20000,
        dim_mirna: int = 2000,
        latent_dim: int = 48,
        dropout: float = 0.1058,
        num_classes: int = 1,
        **kwargs,
    ):
        # Let the base class consume the generic HF options first.
        super().__init__(**kwargs)

        # Per-modality dimension settings.
        self.dim_methylation, self.dim_cnv = dim_methylation, dim_cnv
        self.dim_mrna, self.dim_mirna = dim_mrna, dim_mirna

        # Remaining model hyper-parameters.
        self.latent_dim, self.dropout = latent_dim, dropout
        self.num_classes = num_classes
|
|
|
|
| |
|
|
class MuLGITForSurvival(PreTrainedModel):
    """
    ``PreTrainedModel`` wrapper that exposes MuLGIT to the HF Trainer API
    for survival prediction.

    The wrapped network is called with the four omics tensors. When survival
    targets accompany the batch, the Cox negative log-likelihood is computed
    here and stored in the returned dict under ``"loss"``.
    """

    config_class = MuLGITHFConfig

    def __init__(self, config: MuLGITHFConfig):
        """Build the wrapped ``MuLGITModel`` from the values held in ``config``."""
        super().__init__(config)

        # Local import of the sibling module that defines the network.
        from .models import MuLGITModel

        backbone_kwargs = {
            "dim_methylation": config.dim_methylation,
            "dim_cnv": config.dim_cnv,
            "dim_mrna": config.dim_mrna,
            "dim_mirna": config.dim_mirna,
            "latent_dim": config.latent_dim,
            "dropout": config.dropout,
            "num_classes": config.num_classes,
        }
        self.mulgit = MuLGITModel(**backbone_kwargs)

        self.loss_fn = None

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        survival_times: Optional[torch.Tensor] = None,
        event_observed: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the wrapped model on one batch and optionally attach the loss.

        The Trainer passes every batch key as a keyword argument. The four
        omics tensors are forwarded to the wrapped model; ``labels`` and any
        extra keyword arguments are accepted but not used.

        Returns:
            The dict produced by the wrapped model. A ``"loss"`` entry is
            added only when both ``survival_times`` and ``event_observed``
            are given.
        """
        outputs = self.mulgit(
            methylation=methylation,
            cnv=cnv,
            mrna=mrna,
            mirna=mirna,
        )

        # Without both survival targets there is nothing to compute a loss from.
        targets_missing = survival_times is None or event_observed is None
        if targets_missing:
            return outputs

        from .losses import cox_negative_log_likelihood

        outputs["loss"] = cox_negative_log_likelihood(
            outputs["risk"], survival_times, event_observed
        )
        return outputs

    def _init_weights(self, module):
        """Weight-initialisation hook: only ``nn.Linear`` modules are touched."""
        if not isinstance(module, nn.Linear):
            return

        module.weight.data.normal_(mean=0.0, std=0.02)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()
|
|
|
|
| |
|
|
class MuLGITTrainer(Trainer):
    """
    Trainer subclass for MuLGIT multi-omics survival prediction.

    Overrides ``compute_loss`` so that the Cox loss is always available,
    whether or not the model's forward pass already produced it.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Return the Cox loss for one batch.

        The model is called with the batch unpacked as keyword arguments.
        If its output has no ``"loss"`` entry, the loss is computed here from
        ``outputs["risk"]`` and the batch's ``survival_times`` and
        ``event_observed`` entries, then stored back into the outputs.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        outputs = model(**inputs)

        if outputs.get("loss", None) is None:
            # The forward pass did not attach a loss; compute it from the batch.
            from .losses import cox_negative_log_likelihood

            outputs["loss"] = cox_negative_log_likelihood(
                outputs["risk"],
                inputs["survival_times"],
                inputs["event_observed"],
            )

        loss = outputs["loss"]
        if return_outputs:
            return loss, outputs
        return loss
|
|
|
|
| |
|
|
class SurvivalMetricsCallback(TrainerCallback):
    """
    Fire Trackio alerts for survival-training health signals.

    Watches the values reported by the Trainer and alerts on:
    - a NaN training loss,
    - a training loss more than 3x the mean of the 10 previously logged losses,
    - a new best ``eval_c_index``,
    - an ``eval_c_index`` below 0.5.

    Alerts are sent from the main process only, so a distributed run does
    not emit one alert per rank. The callback does not compute or log any
    metric itself; it only reacts to what is already in ``logs``/``metrics``.
    """

    # Number of previously logged losses that form the spike baseline.
    _SPIKE_WINDOW = 10
    # A loss larger than this multiple of the baseline mean counts as a spike.
    _SPIKE_FACTOR = 3

    def __init__(self, eval_dataloader=None):
        """
        Create the callback.

        Args:
            eval_dataloader: stored on the instance; not read by the hooks
                defined in this class.
        """
        self.eval_dataloader = eval_dataloader
        self.best_c_index = 0.0
        self.loss_history = []

    @staticmethod
    def _alert(state, title, text, level):
        """Send a Trackio alert, but only from the main process."""
        # getattr keeps this usable with a state object lacking the flag.
        if getattr(state, "is_world_process_zero", True):
            trackio.alert(title=title, text=text, level=level)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record the logged training loss and alert on NaN or a loss spike."""
        if logs is None:
            return

        loss = logs.get("loss", None)
        if loss is None:
            return

        if np.isnan(loss):
            self._alert(
                state,
                title="NaN Loss",
                text=f"NaN loss at step {state.global_step}. "
                f"Training diverged. Reduce learning rate by 10x and restart.",
                level="ERROR",
            )
            # Keep NaN out of the history: one NaN would turn every later
            # baseline mean into NaN and silence spike detection.
            return

        # The baseline is the window of losses logged BEFORE this one, so the
        # value under test cannot inflate its own threshold.
        baseline = self.loss_history[-self._SPIKE_WINDOW:]
        self.loss_history.append(loss)

        if len(baseline) < self._SPIKE_WINDOW:
            return

        recent_mean = np.mean(baseline)
        if loss > recent_mean * self._SPIKE_FACTOR:
            self._alert(
                state,
                title="Loss Spike Detected",
                text=f"loss={loss:.4f} at step {state.global_step} — "
                f"{self._SPIKE_FACTOR}x above recent mean {recent_mean:.4f}. "
                f"Consider reducing learning rate.",
                level="WARN",
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Track the best ``eval_c_index`` and alert on new bests or sub-0.5 values."""
        if metrics is None:
            return

        c_index = metrics.get("eval_c_index", None)
        if c_index is None:
            return

        if c_index > self.best_c_index:
            self.best_c_index = c_index
            self._alert(
                state,
                title="New Best C-index",
                text=f"C-index={c_index:.4f} at step {state.global_step}. "
                f"New best model.",
                level="INFO",
            )

        if c_index < 0.5:
            self._alert(
                state,
                title="C-index Below Random",
                text=f"C-index={c_index:.4f} below 0.5. Model is not learning. "
                f"Check data preprocessing or increase model capacity.",
                level="WARN",
            )
|
|
|
|
| |
|
|
def train_mulgit(
    model: MuLGITForSurvival,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader] = None,
    output_dir: str = "./mulgit_output",
    num_epochs: int = 100,
    learning_rate: float = 5.8e-4,
    batch_size: int = 256,
    weight_decay: float = 0.00598,
    push_to_hub: bool = True,
    hub_model_id: Optional[str] = None,
    trackio_project: str = "mulgit",
    trackio_space_id: Optional[str] = None,
    **kwargs,
) -> Trainer:
    """
    Train a MuLGIT model with the HF Trainer and Trackio reporting.

    Args:
        model: MuLGITForSurvival model instance
        train_dataloader: training data; only its ``.dataset`` is handed to
            the trainer
        eval_dataloader: evaluation data (optional); only its ``.dataset``
            is handed to the trainer
        output_dir: where to save checkpoints
        num_epochs: number of training epochs
        learning_rate: from SeNMo paper (5.8e-4)
        batch_size: from SeNMo paper (256); split into a per-device batch of
            ``batch_size // 4`` with 4 gradient-accumulation steps
        weight_decay: from SeNMo paper (0.00598)
        push_to_hub: whether to push model to Hub
        hub_model_id: Hub repo ID
        trackio_project: Trackio project name
        trackio_space_id: Trackio Space for dashboard
        **kwargs: additional ``TrainingArguments`` fields. These override
            the defaults chosen below (for example ``bf16=False``).

    Returns:
        The trainer, after training, saving and (optionally) pushing.
    """
    import os

    # Explicit None check so that an empty or unsized eval dataloader is not
    # mistaken for "no evaluation requested".
    has_eval = eval_dataloader is not None

    training_defaults = {
        "output_dir": output_dir,
        "num_train_epochs": num_epochs,
        # Floor at 1 so a small ``batch_size`` cannot yield a batch size of 0.
        "per_device_train_batch_size": max(1, batch_size // 4),
        "per_device_eval_batch_size": max(1, batch_size // 2),
        "gradient_accumulation_steps": 4,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": 0.1,
        "logging_strategy": "steps",
        "logging_steps": 10,
        "logging_first_step": True,
        "eval_strategy": "epoch" if has_eval else "no",
        "save_strategy": "epoch",
        "save_total_limit": 3,
        "load_best_model_at_end": has_eval,
        # NOTE(review): "c_index" requires an `eval_c_index` entry in the
        # evaluation metrics, but no `compute_metrics` is passed to the
        # trainer below — confirm where that metric is produced.
        "metric_for_best_model": "c_index" if has_eval else "loss",
        # C-index is better when higher; a loss is better when lower.
        "greater_is_better": has_eval,
        "bf16": True,
        "dataloader_drop_last": True,
        "report_to": "trackio",
        "run_name": f"mulgit_lr{learning_rate}_bs{batch_size}",
        "push_to_hub": push_to_hub,
        "hub_model_id": hub_model_id,
        "disable_tqdm": True,
        "ddp_find_unused_parameters": False,
    }
    # Caller-supplied values win. Passing both as keywords to
    # TrainingArguments would raise "got multiple values for keyword argument".
    training_args = TrainingArguments(**{**training_defaults, **kwargs})

    # Trackio settings are handed over through the environment.
    if trackio_project:
        os.environ["TRACKIO_PROJECT"] = trackio_project
    if trackio_space_id:
        os.environ["TRACKIO_SPACE_ID"] = trackio_space_id

    trainer = MuLGITTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        eval_dataset=eval_dataloader.dataset if has_eval else None,
        callbacks=[SurvivalMetricsCallback(eval_dataloader)],
    )

    trainer.train()

    trainer.save_model()
    if push_to_hub:
        trainer.push_to_hub()

    trackio.alert(
        title="Training Complete",
        text=f"MuLGIT training finished. "
        f"Steps: {trainer.state.global_step}, "
        f"Best metric: {trainer.state.best_metric}",
        level="INFO",
    )

    return trainer
|
|
|
|
| |
|
|
def train_mulgit_standalone(
    model: nn.Module,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader] = None,
    num_epochs: int = 100,
    learning_rate: float = 5.8e-4,
    weight_decay: float = 0.00598,
    device: str = "cuda",
    project: str = "mulgit",
    space_id: Optional[str] = None,
    push_to_hub: bool = False,
    hub_model_id: Optional[str] = None,
) -> nn.Module:
    """
    Train MuLGIT with a direct PyTorch training loop + Trackio.
    Useful when HF Trainer integration is not needed.

    Each batch is read as a dict with the keys ``methylation``, ``cnv``,
    ``mrna``, ``mirna``, ``survival_times`` and ``event_observed``. The model
    is called with the four omics tensors positionally and must return a
    dict containing ``"risk"``.

    Args:
        model: network to optimise; moved to ``device``
        train_dataloader: yields the batch dicts described above
        eval_dataloader: if given, ``evaluate_survival_model`` is run on it
            after every epoch and its metrics are logged
        num_epochs: number of passes over ``train_dataloader``
        learning_rate: AdamW learning rate
        weight_decay: AdamW weight decay
        device: device the model and batches are moved to
        project: Trackio project name
        space_id: Trackio Space ID (optional)
        push_to_hub: upload the final state dict when ``hub_model_id`` is set
        hub_model_id: target Hub repo ID

    Returns:
        The trained model (on ``device``).
    """
    import os

    from .losses import cox_negative_log_likelihood

    if eval_dataloader is not None:
        # Imported once up front rather than on every epoch.
        from .losses import evaluate_survival_model

    os.environ["TRACKIO_PROJECT"] = project
    if space_id:
        os.environ["TRACKIO_SPACE_ID"] = space_id

    model = model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    # Start the run explicitly: nothing in this function starts one, and the
    # log/alert calls below are made against the active run.
    trackio.init(project=project, space_id=space_id)

    global_step = 0
    best_c_index = 0.0

    try:
        for epoch in range(num_epochs):
            # Re-enter training mode every epoch: the evaluation helper may
            # have left the model in eval mode (e.g. dropout disabled).
            model.train()
            epoch_losses = []

            for batch in train_dataloader:
                methylation = batch["methylation"].to(device)
                cnv = batch["cnv"].to(device)
                mrna = batch["mrna"].to(device)
                mirna = batch["mirna"].to(device)
                times = batch["survival_times"].to(device)
                events = batch["event_observed"].to(device)

                outputs = model(methylation, cnv, mrna, mirna)
                loss = cox_negative_log_likelihood(outputs["risk"], times, events)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once; it is used for both the epoch mean
                # and the periodic log entry.
                loss_value = loss.item()
                epoch_losses.append(loss_value)
                global_step += 1

                # Log step 1 and every 10th step after it.
                if global_step % 10 == 1:
                    trackio.log({"loss": loss_value, "step": global_step})

            avg_loss = np.mean(epoch_losses)
            trackio.log({"epoch": epoch, "epoch_loss": avg_loss})

            # Stop as soon as an epoch's mean loss is NaN.
            if np.isnan(avg_loss):
                trackio.alert(
                    "NaN Loss", f"Epoch {epoch}: NaN loss. Reduce LR.", "ERROR"
                )
                break

            if eval_dataloader is not None:
                metrics = evaluate_survival_model(model, eval_dataloader, device)
                trackio.log(metrics)

                if metrics["c_index"] > best_c_index:
                    best_c_index = metrics["c_index"]
                    trackio.alert(
                        "New Best", f"C-index={best_c_index:.4f} at epoch {epoch}", "INFO"
                    )

        if push_to_hub and hub_model_id:
            from huggingface_hub import HfApi

            torch.save(model.state_dict(), "model.pt")
            api = HfApi()
            api.upload_file(
                path_or_fileobj="model.pt",
                path_in_repo="pytorch_model.bin",
                repo_id=hub_model_id,
            )

        trackio.alert("Complete", f"Training done. Best C-index: {best_c_index:.4f}", "INFO")
    finally:
        # Close the run even if training or the upload raised.
        trackio.finish()

    return model
|
|