""" HuggingFace Trainer Integration for MuLGIT Wraps MuLGIT custom models to work with HuggingFace's Trainer API, enabling: - Push to Hub - Trackio monitoring - Mixed precision training - Gradient accumulation - Learning rate scheduling - Checkpointing Pattern: Custom model + HF Trainer via a custom forward-compatible wrapper that accepts batch dicts and returns loss-compatible outputs. """ import torch import torch.nn as nn from torch.utils.data import DataLoader from typing import Optional, Dict, Any, Tuple import numpy as np from transformers import ( Trainer, TrainingArguments, TrainerCallback, PreTrainedModel, PretrainedConfig, ) import trackio # ─── Custom HF-Style Config ───────────────────────────────────────────────── class MuLGITHFConfig(PretrainedConfig): """HuggingFace-compatible config for MuLGIT models.""" model_type = "mulgit" def __init__( self, dim_methylation: int = 20000, dim_cnv: int = 20000, dim_mrna: int = 20000, dim_mirna: int = 2000, latent_dim: int = 48, dropout: float = 0.1058, num_classes: int = 1, **kwargs, ): super().__init__(**kwargs) self.dim_methylation = dim_methylation self.dim_cnv = dim_cnv self.dim_mrna = dim_mrna self.dim_mirna = dim_mirna self.latent_dim = latent_dim self.dropout = dropout self.num_classes = num_classes # ─── HF-Compatible Model Wrapper ───────────────────────────────────────────── class MuLGITForSurvival(PreTrainedModel): """ HuggingFace-compatible wrapper around MuLGIT that integrates with the Trainer API for survival prediction. Handles the custom forward pass (multi-omics input dict → risk score) and loss computation internally. """ config_class = MuLGITHFConfig def __init__(self, config: MuLGITHFConfig): super().__init__(config) # Import here to avoid circular dependency from .models import MuLGITModel self.mulgit = MuLGITModel( dim_methylation=config.dim_methylation, dim_cnv=config.dim_cnv, dim_mrna=config.dim_mrna, dim_mirna=config.dim_mirna, latent_dim=config.latent_dim, dropout=config.dropout, num_classes=config.num_classes, ) self.loss_fn = None # set after init if needed def forward( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, survival_times: Optional[torch.Tensor] = None, event_observed: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, **kwargs, ) -> Dict[str, torch.Tensor]: """ Forward pass compatible with HF Trainer. HF Trainer passes all batch keys as kwargs. We extract the omics features and optional labels. """ outputs = self.mulgit( methylation=methylation, cnv=cnv, mrna=mrna, mirna=mirna, ) # Compute loss during training if survival_times is not None and event_observed is not None: from .losses import cox_negative_log_likelihood loss = cox_negative_log_likelihood( outputs["risk"], survival_times, event_observed ) outputs["loss"] = loss return outputs def _init_weights(self, module): """Initialize weights (called by HF).""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=0.02) if module.bias is not None: module.bias.data.zero_() # ─── Custom Trainer for Multi-Omics ────────────────────────────────────────── class MuLGITTrainer(Trainer): """ Custom HF Trainer for MuLGIT multi-omics survival prediction. Extends Trainer to: - Accept custom batch format (multi-omics dicts) - Compute Cox loss correctly - Track survival-specific metrics (C-index) """ def compute_loss(self, model, inputs, return_outputs=False, **kwargs): """ Compute Cox loss from multi-omics batch. HF Trainer calls this internally. Our model forward pass returns a dict with "loss" key if labels are provided. """ outputs = model(**inputs) loss = outputs.get("loss", None) if loss is None: # If model didn't compute loss, do it here from .losses import cox_negative_log_likelihood loss = cox_negative_log_likelihood( outputs["risk"], inputs["survival_times"], inputs["event_observed"], ) outputs["loss"] = loss return (loss, outputs) if return_outputs else loss # ─── Trackio Callback for Survival Metrics ─────────────────────────────────── class SurvivalMetricsCallback(TrainerCallback): """ Logs survival-specific metrics (C-index, AUC) to Trackio. Also fires alerts for training divergence or convergence. """ def __init__(self, eval_dataloader=None): self.eval_dataloader = eval_dataloader self.best_c_index = 0.0 self.loss_history = [] def on_log(self, args, state, control, logs=None, **kwargs): """Log training metrics and fire alerts if needed.""" if logs is None: return loss = logs.get("loss", None) if loss is not None: self.loss_history.append(loss) # Alert on loss spikes if len(self.loss_history) > 10: recent_mean = np.mean(self.loss_history[-10:]) if loss > recent_mean * 3: trackio.alert( title="Loss Spike Detected", text=f"loss={loss:.4f} at step {state.global_step} — " f"3x above recent mean {recent_mean:.4f}. " f"Consider reducing learning rate.", level="WARN", ) # Alert on NaN if np.isnan(loss): trackio.alert( title="NaN Loss", text=f"NaN loss at step {state.global_step}. " f"Training diverged. Reduce learning rate by 10x and restart.", level="ERROR", ) def on_evaluate(self, args, state, control, metrics=None, **kwargs): """Log eval metrics and fire convergence alerts.""" if metrics is None: return c_index = metrics.get("eval_c_index", None) if c_index is not None: if c_index > self.best_c_index: self.best_c_index = c_index trackio.alert( title="New Best C-index", text=f"C-index={c_index:.4f} at step {state.global_step}. " f"New best model.", level="INFO", ) if c_index < 0.5: trackio.alert( title="C-index Below Random", text=f"C-index={c_index:.4f} below 0.5. Model is not learning. " f"Check data preprocessing or increase model capacity.", level="WARN", ) # ─── Training Script ───────────────────────────────────────────────────────── def train_mulgit( model: MuLGITForSurvival, train_dataloader: DataLoader, eval_dataloader: Optional[DataLoader] = None, output_dir: str = "./mulgit_output", num_epochs: int = 100, learning_rate: float = 5.8e-4, batch_size: int = 256, weight_decay: float = 0.00598, push_to_hub: bool = True, hub_model_id: Optional[str] = None, trackio_project: str = "mulgit", trackio_space_id: Optional[str] = None, **kwargs, ) -> Trainer: """ Train MuLGIT model with HF Trainer + Trackio monitoring. Args: model: MuLGITForSurvival model instance train_dataloader: training data eval_dataloader: evaluation data (optional) output_dir: where to save checkpoints num_epochs: number of training epochs learning_rate: from SeNMo paper (5.8e-4) batch_size: from SeNMo paper (256) weight_decay: from SeNMo paper (0.00598) push_to_hub: whether to push model to Hub hub_model_id: Hub repo ID trackio_project: Trackio project name trackio_space_id: Trackio Space for dashboard """ training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=num_epochs, per_device_train_batch_size=batch_size // 4, # adjust for multi-GPU per_device_eval_batch_size=batch_size // 2, gradient_accumulation_steps=4, # effective batch size = batch_size learning_rate=learning_rate, weight_decay=weight_decay, warmup_ratio=0.1, logging_strategy="steps", logging_steps=10, logging_first_step=True, eval_strategy="epoch" if eval_dataloader else "no", save_strategy="epoch", save_total_limit=3, load_best_model_at_end=True if eval_dataloader else False, metric_for_best_model="c_index" if eval_dataloader else "loss", greater_is_better=True, bf16=True, dataloader_drop_last=True, report_to="trackio", run_name=f"mulgit_lr{learning_rate}_bs{batch_size}", push_to_hub=push_to_hub, hub_model_id=hub_model_id, disable_tqdm=True, ddp_find_unused_parameters=False, **kwargs, ) # Setup Trackio if trackio_project: import os os.environ["TRACKIO_PROJECT"] = trackio_project if trackio_space_id: os.environ["TRACKIO_SPACE_ID"] = trackio_space_id # Create trainer trainer = MuLGITTrainer( model=model, args=training_args, train_dataset=train_dataloader.dataset, eval_dataset=eval_dataloader.dataset if eval_dataloader else None, callbacks=[SurvivalMetricsCallback(eval_dataloader)], ) # Train trainer.train() # Save and push trainer.save_model() if push_to_hub: trainer.push_to_hub() trackio.alert( title="Training Complete", text=f"MuLGIT training finished. " f"Steps: {trainer.state.global_step}, " f"Best metric: {trainer.state.best_metric}", level="INFO", ) return trainer # ─── Standalone Training Loop (without HF Trainer) ────────────────────────── def train_mulgit_standalone( model: nn.Module, train_dataloader: DataLoader, eval_dataloader: Optional[DataLoader] = None, num_epochs: int = 100, learning_rate: float = 5.8e-4, weight_decay: float = 0.00598, device: str = "cuda", project: str = "mulgit", space_id: Optional[str] = None, push_to_hub: bool = False, hub_model_id: Optional[str] = None, ) -> nn.Module: """ Train MuLGIT with a direct PyTorch training loop + Trackio. Useful when HF Trainer integration is not needed. """ import os os.environ["TRACKIO_PROJECT"] = project if space_id: os.environ["TRACKIO_SPACE_ID"] = space_id model = model.to(device) optimizer = torch.optim.AdamW( model.parameters(), lr=learning_rate, weight_decay=weight_decay, ) from .losses import cox_negative_log_likelihood model.train() global_step = 0 best_c_index = 0.0 for epoch in range(num_epochs): epoch_losses = [] for batch in train_dataloader: methylation = batch["methylation"].to(device) cnv = batch["cnv"].to(device) mrna = batch["mrna"].to(device) mirna = batch["mirna"].to(device) times = batch["survival_times"].to(device) events = batch["event_observed"].to(device) outputs = model(methylation, cnv, mrna, mirna) loss = cox_negative_log_likelihood(outputs["risk"], times, events) optimizer.zero_grad() loss.backward() optimizer.step() epoch_losses.append(loss.item()) global_step += 1 if global_step % 10 == 1: trackio.log({"loss": loss.item(), "step": global_step}) avg_loss = np.mean(epoch_losses) trackio.log({"epoch": epoch, "epoch_loss": avg_loss}) # Alert on anomalies if np.isnan(avg_loss): trackio.alert( "NaN Loss", f"Epoch {epoch}: NaN loss. Reduce LR.", "ERROR" ) break # Evaluate if eval_dataloader is not None: from .losses import evaluate_survival_model metrics = evaluate_survival_model(model, eval_dataloader, device) trackio.log(metrics) if metrics["c_index"] > best_c_index: best_c_index = metrics["c_index"] trackio.alert( "New Best", f"C-index={best_c_index:.4f} at epoch {epoch}", "INFO" ) # Push to Hub if push_to_hub and hub_model_id: from huggingface_hub import HfApi torch.save(model.state_dict(), "model.pt") api = HfApi() api.upload_file( path_or_fileobj="model.pt", path_in_repo="pytorch_model.bin", repo_id=hub_model_id, ) trackio.alert("Complete", f"Training done. Best C-index: {best_c_index:.4f}", "INFO") return model