# Provenance (web-page residue from the original upload, kept as comments):
#   Troyejcan's picture
#   Upload 175 files
#   35c7764 verified
import os
import pandas as pd
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import optuna
import functools
import time
from .model import GeneformerMultiTask
from .utils import (
calculate_metrics,
get_layer_freeze_range,
set_seed,
initialize_wandb,
create_model,
setup_optimizer_and_scheduler,
save_model,
save_hyperparameters,
prepare_training_environment,
log_training_step,
log_validation_metrics,
save_validation_predictions,
setup_logging,
setup_distributed_environment,
train_distributed
)
class Trainer:
    """Trainer class for multi-task learning.

    Drives (optionally distributed) training and validation of a multi-task
    model. All knobs come from the ``config`` dict; keys read by this class
    include ``"epochs"``, ``"task_names"``, ``"batch_size"``,
    ``"gradient_clipping"``, ``"max_grad_norm"``, and optionally
    ``"gradient_accumulation_steps"``, ``"distributed_training"``,
    ``"local_rank"``, ``"validate_each_epoch"``, and ``"use_wandb"``.
    """

    def __init__(self, config):
        # config: dict of training settings (see class docstring for keys).
        self.config = config
        self.device = None       # torch.device; set by setup() or assigned by the caller
        self.model = None        # model under training (DDP-wrapped when distributed)
        self.optimizer = None
        self.scheduler = None
        self.writer = None       # TensorBoard SummaryWriter, bound inside setup_logging()
        self.is_distributed = config.get("distributed_training", False)
        self.local_rank = config.get("local_rank", 0)
        # Only rank 0 (or the sole process) prints, logs, and drives progress bars.
        self.is_main_process = not self.is_distributed or self.local_rank == 0

    def train_epoch(self, train_loader, epoch):
        """Train the model for one epoch and return the (globally averaged) loss.

        Supports gradient accumulation and, when distributed, skips gradient
        synchronization (``no_sync``) on non-boundary accumulation steps.
        """
        epoch_start = time.time()
        self.model.train()
        # For distributed training, we need to be aware of the global batch count.
        if self.is_distributed:
            # World size is used only for reporting here.
            world_size = dist.get_world_size()
            # Total batches across all GPUs (rank 0 reports the global figure).
            total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader)
        else:
            world_size = 1
            total_batches_global = len(train_loader)
        progress_bar = None
        if self.is_main_process:
            # Progress bar tracks this rank's batches; the global count is printed below.
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}",
                                total=len(train_loader))
            iterator = progress_bar
            # Report distributed training information once per epoch.
            if self.is_distributed:
                print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, "
                      f"{total_batches_global} total batches globally")
        else:
            iterator = train_loader
        # Per-phase wall-clock timings for the epoch summary.
        batch_times = []
        forward_times = []
        backward_times = []
        optimizer_times = []
        # Gradient accumulation steps (default 1 = update every batch).
        accumulation_steps = self.config.get("gradient_accumulation_steps", 1)
        # Zero gradients at the beginning of the epoch.
        self.optimizer.zero_grad()
        # Track loss for the entire epoch.
        total_loss = 0.0
        num_batches = 0
        accumulated_loss = 0.0
        for batch_idx, batch in enumerate(iterator):
            batch_start = time.time()
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            # One label tensor per task, in config["task_names"] order.
            labels = [
                batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"]
            ]
            forward_start = time.time()
            loss, _, _ = self.model(input_ids, attention_mask, labels)
            # Scale loss so accumulated gradients average (not sum) over micro-batches.
            if accumulation_steps > 1:
                loss = loss / accumulation_steps
            forward_end = time.time()
            forward_times.append(forward_end - forward_start)
            # Track the UNSCALED loss for reporting (undo the accumulation scaling).
            unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps)
            total_loss += unscaled_loss
            num_batches += 1
            accumulated_loss += loss.item()  # For gradient accumulation tracking
            backward_start = time.time()
            # Use no_sync() for all but the last accumulation step to avoid
            # unnecessary inter-GPU gradient communication.
            if self.is_distributed and accumulation_steps > 1:
                # Not an accumulation boundary and not the final batch -> skip sync.
                if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader):
                    with self.model.no_sync():
                        loss.backward()
                else:
                    loss.backward()
            else:
                # Non-distributed training or accumulation_steps == 1.
                loss.backward()
            backward_end = time.time()
            backward_times.append(backward_end - backward_start)
            # Only step the optimizer at accumulation boundaries or at epoch end.
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                if self.config["gradient_clipping"]:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
                optimizer_start = time.time()
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                optimizer_end = time.time()
                optimizer_times.append(optimizer_end - optimizer_start)
                # Log only after an actual optimizer step.
                if self.is_main_process:
                    # Running average loss over the epoch so far.
                    avg_loss = total_loss / num_batches
                    log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx)
                    # Update progress bar with just the running average loss.
                    progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
                accumulated_loss = 0.0
            else:
                optimizer_times.append(0)  # No optimizer step taken this batch
            batch_end = time.time()
            batch_times.append(batch_end - batch_start)
        epoch_end = time.time()
        # Average loss for the epoch on this rank.
        epoch_avg_loss = total_loss / num_batches
        # If distributed, gather per-rank averages to compute the global average.
        if self.is_distributed:
            # Tensor wrapper so the loss can go through collective ops.
            loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
            # Gather losses from all processes.
            all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(all_losses, loss_tensor)
            # Global (unweighted) mean across ranks.
            epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()
        if self.is_main_process:
            # NOTE(review): double-check whether batch_size has already been
            # adjusted for world_size in the config; otherwise the effective
            # batch size below would be double-counted.
            per_gpu_batch_size = self.config['batch_size']
            total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size
            print(f"Epoch {epoch+1} timing:")
            print(f"  Total epoch time: {epoch_end - epoch_start:.2f}s")
            print(f"  Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
            print(f"  Average forward time: {sum(forward_times)/len(forward_times):.4f}s")
            print(f"  Average backward time: {sum(backward_times)/len(backward_times):.4f}s")
            # Average over steps where an optimizer update actually happened.
            print(f"  Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s")
            print(f"  Gradient accumulation steps: {accumulation_steps}")
            print(f"  Batch size per GPU: {per_gpu_batch_size}")
            print(f"  Effective global batch size: {total_effective_batch}")
            print(f"  Average training loss: {epoch_avg_loss:.4f}")
            if self.is_distributed:
                print(f"  Total batches processed across all GPUs: {total_batches_global}")
                print(f"  Communication optimization: Using no_sync() for gradient accumulation")
        return epoch_avg_loss  # Average loss for the epoch (global when distributed)

    def validate_model(self, val_loader):
        """Run validation and collect per-task labels/predictions.

        Returns ``(val_loss, task_true_labels, task_pred_labels,
        task_pred_probs, val_cell_ids)`` where the three ``task_*`` dicts map
        task name -> per-sample lists, and ``val_cell_ids`` maps this rank's
        running sample index -> cell id (only when batches carry "cell_id").
        """
        val_start = time.time()
        self.model.eval()
        val_loss = 0.0
        task_true_labels = {task_name: [] for task_name in self.config["task_names"]}
        task_pred_labels = {task_name: [] for task_name in self.config["task_names"]}
        task_pred_probs = {task_name: [] for task_name in self.config["task_names"]}
        val_cell_ids = {}
        sample_counter = 0
        batch_times = []
        # Print validation dataset size.
        if self.is_main_process:
            print(f"Validation dataset size: {len(val_loader.dataset)} samples")
            print(f"Number of validation batches: {len(val_loader)}")
            if self.is_distributed:
                world_size = dist.get_world_size()
                print(f"Distributed validation: {world_size} GPUs")
                if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
                    samples_per_gpu = val_loader.sampler.num_samples
                    print(f"Each GPU processes {samples_per_gpu} validation samples")
                    print(f"Total validation samples processed: {samples_per_gpu * world_size}")
        with torch.no_grad():
            for batch in val_loader:
                batch_start = time.time()
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = [
                    batch["labels"][task_name].to(self.device)
                    for task_name in self.config["task_names"]
                ]
                loss, logits, _ = self.model(input_ids, attention_mask, labels)
                val_loss += loss.item()
                if "cell_id" in batch:
                    for i, cell_id in enumerate(batch["cell_id"]):
                        # Store the running index for later mapping to unique_cell_id.
                        val_cell_ids[sample_counter + i] = cell_id.item()
                for sample_idx in range(len(batch["input_ids"])):
                    for i, task_name in enumerate(self.config["task_names"]):
                        true_label = batch["labels"][task_name][sample_idx].item()
                        pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                        # Store the full probability distribution, not just argmax.
                        pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
                        task_true_labels[task_name].append(true_label)
                        task_pred_labels[task_name].append(pred_label)
                        task_pred_probs[task_name].append(pred_prob)
                # Advance the running index used for cell ID tracking.
                sample_counter += len(batch["input_ids"])
                batch_end = time.time()
                batch_times.append(batch_end - batch_start)
        # Normalize validation loss by the number of batches.
        val_loss /= len(val_loader)
        # When distributed, average the per-rank losses across all processes.
        if self.is_distributed:
            loss_tensor = torch.tensor([val_loss], device=self.device)
            gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered_losses, loss_tensor)
            # Mean loss across all processes.
            val_loss = torch.mean(torch.cat(gathered_losses)).item()
            # NOTE(review): world_size is recomputed here but not used afterwards.
            world_size = dist.get_world_size()
            if self.is_main_process:
                print(f"Collected predictions from rank {self.local_rank}")
                print(f"Number of samples processed by this rank: {sample_counter}")
        val_end = time.time()
        if self.is_main_process:
            print(f"Validation timing:")
            print(f"  Total validation time: {val_end - val_start:.2f}s")
            print(f"  Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
            print(f"  Collected {len(val_cell_ids)} cell indices from validation data")
            print(f"  Processed {sample_counter} total samples during validation")
            # Per-task sample counts.
            for task_name in self.config["task_names"]:
                print(f"  Task {task_name}: {len(task_true_labels[task_name])} samples")
        return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids

    def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
        """Train the model and return ``(val_loss, trained_model)``.

        Builds model/optimizer/scheduler, runs the epoch loop (optionally
        validating each epoch), then performs a final validation and logs
        metrics/predictions on the main process.
        """
        if self.config.get("use_wandb", False) and self.is_main_process:
            initialize_wandb(self.config)
        # Create model. NOTE(review): here create_model receives
        # (is_distributed, local_rank) while setup() calls it with three
        # arguments — confirm create_model's signature supports both.
        self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)
        # Setup optimizer and scheduler over the full training horizon.
        total_steps = len(train_loader) * self.config["epochs"]
        self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)
        # Training loop (progress bar only on the main process).
        if self.is_main_process:
            epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
        else:
            epoch_progress = range(self.config["epochs"])
        best_val_loss = float('inf')
        train_losses = []
        with setup_logging(self.config) as self.writer:
            for epoch in epoch_progress:
                if self.is_distributed:
                    # Reshuffle the distributed sampler deterministically per epoch.
                    train_loader.sampler.set_epoch(epoch)
                train_loss = self.train_epoch(train_loader, epoch)
                train_losses.append(train_loss)
                # Run validation after each epoch if configured to do so.
                if self.config.get("validate_each_epoch", False):
                    val_loss, _, _, _, _ = self.validate_model(val_loader)
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                    if self.is_main_process:
                        epoch_progress.set_postfix({
                            "train_loss": f"{train_loss:.4f}",
                            "val_loss": f"{val_loss:.4f}",
                            "best_val_loss": f"{best_val_loss:.4f}"
                        })
                else:
                    if self.is_main_process:
                        epoch_progress.set_postfix({
                            "train_loss": f"{train_loss:.4f}"
                        })
            # Final validation after all epochs (all ranks participate).
            val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
            task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
            if self.is_main_process:
                log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])
                # Save validation predictions (includes the cell-id mapping).
                save_validation_predictions(
                    val_cell_ids,
                    task_true_labels,
                    task_pred_labels,
                    task_pred_probs,
                    {**self.config, "val_cell_mapping": val_cell_id_mapping}  # Include the mapping
                )
                if self.config.get("use_wandb", False):
                    import wandb
                    wandb.finish()
                print(f"\nTraining Summary:")
                print(f"  Final Training Loss: {train_losses[-1]:.4f}")
                print(f"  Final Validation Loss: {val_loss:.4f}")
                for task_name, metrics in task_metrics.items():
                    print(f"  {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
        return val_loss, self.model  # Return both the validation loss and the trained model

    def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
        """Prepare device, model (DDP-wrapped if distributed), optimizer and scheduler.

        Returns the loaders/mappings unchanged so callers can continue a
        pass-through pipeline.
        """
        if self.is_distributed:
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = create_model(self.config, num_labels_list, self.device)
        # Wrap model with DDP when training on multiple processes.
        if self.is_distributed:
            self.model = DDP(self.model, device_ids=[self.local_rank])
            # Register a communication hook to optimize gradient synchronization;
            # the default allreduce hook maintains full precision.
            from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
            self.model.register_comm_hook(
                state=None,
                hook=comm_hooks.allreduce_hook
            )
            print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")
            print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
            print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
            if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
                print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")
            if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
                print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")
        # Set up optimizer and scheduler.
        # NOTE(review): total_steps here is len(train_loader), whereas train()
        # uses len(train_loader) * epochs — confirm which horizon the scheduler expects.
        self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
            self.model, self.config, len(train_loader)
        )
        if self.is_main_process and self.config.get("use_wandb", False):
            initialize_wandb(self.config)
        return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
    """Train a model with the given configuration and data.

    Dispatches to distributed training when requested and more than one GPU
    is available; otherwise (or when the distributed run yields no result)
    trains on a single device and returns ``(val_loss, trained_model)``.
    """
    if config.get("distributed_training", False):
        if torch.cuda.device_count() > 1:
            # Hand the whole run to the distributed launcher.
            distributed_result = train_distributed(
                Trainer,
                config,
                train_loader,
                val_loader,
                train_cell_id_mapping,
                val_cell_id_mapping,
                num_labels_list
            )
            # A None result means we must fall through to the local path.
            if distributed_result is not None:
                return distributed_result
        else:
            print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
            config["distributed_training"] = False
    # Single-process (non-distributed) training path.
    single_trainer = Trainer(config)
    single_trainer.device = device
    return single_trainer.train(
        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
    )
def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    """Objective function for Optuna hyperparameter optimization.

    Samples hyperparameters for this trial from ``config["hyperparameters"]``,
    trains (distributed or local), stores the trained state dict / task
    weights as trial user attributes, and returns the validation loss to
    minimize.
    """
    set_seed(config["seed"])
    initialize_wandb(config)
    # Shallow copy: trial-specific values must not leak into the shared config.
    trial_config = config.copy()
    # Suggest hyperparameters for this trial.
    for param_name, param_config in config["hyperparameters"].items():
        if param_name == "lr_scheduler_type":
            # Categorical choice among scheduler names.
            trial_config[param_name] = trial.suggest_categorical(
                param_name, param_config["choices"]
            )
        elif param_name == "task_weights" and config["use_task_weights"]:
            # One weight per task, then normalized to sum to 1.
            weights = [
                trial.suggest_float(
                    f"task_weight_{i}",
                    param_config["low"],
                    param_config["high"],
                )
                for i in range(len(num_labels_list))
            ]
            weight_sum = sum(weights)
            trial_config[param_name] = [w / weight_sum for w in weights]
        elif "log" in param_config and param_config["log"]:
            # Log-uniform float (e.g. learning rate).
            trial_config[param_name] = trial.suggest_float(
                param_name, param_config["low"], param_config["high"], log=True
            )
        else:
            # Uniform float; assumes param_config provides "low"/"high".
            trial_config[param_name] = trial.suggest_float(
                param_name, param_config["low"], param_config["high"]
            )
    # Override with an integer freeze depth bounded by the pretrained model's layers.
    if "max_layers_to_freeze" in trial_config:
        freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
        trial_config["max_layers_to_freeze"] = int(trial.suggest_int(
            "max_layers_to_freeze",
            freeze_range["min"],
            freeze_range["max"]
        ))
    trial_config["run_name"] = f"trial_{trial.number}"
    # Distributed path: workers publish results via a multiprocessing-managed dict.
    if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
        manager = mp.Manager()
        shared_dict = manager.dict()
        train_distributed(
            Trainer,
            trial_config,
            train_loader,
            val_loader,
            train_cell_id_mapping,
            val_cell_id_mapping,
            num_labels_list,
            trial.number,
            shared_dict
        )
        # Defaults guard against a worker that never wrote its results.
        val_loss = shared_dict.get('val_loss', float('inf'))
        task_metrics = shared_dict.get('task_metrics', {})
        trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
        trial.set_user_attr("task_weights", trial_config["task_weights"])
        if config.get("use_wandb", False):
            import wandb
            wandb.log({
                "trial_number": trial.number,
                "val_loss": val_loss,
                **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
                **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
            })
            wandb.finish()
        return val_loss
    # Local (single-process) path: reuse Trainer directly.
    with setup_logging(trial_config) as writer:
        trainer = Trainer(trial_config)
        trainer.device = device
        trainer.writer = writer
        # Create model with trial hyperparameters.
        trainer.model = create_model(trial_config, num_labels_list, device)
        total_steps = len(train_loader) * config["epochs"]
        trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)
        # Training loop.
        for epoch in range(config["epochs"]):
            trainer.train_epoch(train_loader, epoch)
        val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
        task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
        # Log metrics.
        log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])
        # Save validation predictions for this trial.
        save_validation_predictions(
            val_cell_ids,
            task_true_labels,
            task_pred_labels,
            task_pred_probs,
            {**trial_config, "val_cell_mapping": val_cell_id_mapping},
            trial.number,
        )
        # Store model state dict and task weights in trial user attributes
        # so the best trial can be re-materialized later.
        trial.set_user_attr("model_state_dict", trainer.model.state_dict())
        trial.set_user_attr("task_weights", trial_config["task_weights"])
        # Report the final value to Optuna and honor pruning.
        trial.report(val_loss, config["epochs"])
        if trial.should_prune():
            raise optuna.TrialPruned()
        if config.get("use_wandb", False):
            import wandb
            wandb.log(
                {
                    "trial_number": trial.number,
                    "val_loss": val_loss,
                    **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
                    **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
                    **{k: v for k, v in trial_config.items() if k in [
                        "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
                        "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
                    ]},
                }
            )
            wandb.finish()
        return val_loss
def run_manual_tuning(config):
    """Run training with manually specified hyperparameters.

    Reads ``config["manual_hyperparameters"]``, merges them into the config,
    trains once, and (outside distributed mode) persists the model plus the
    hyperparameters used. Returns the validation loss.
    """
    (
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    ) = prepare_training_environment(config)

    manual_params = config["manual_hyperparameters"]
    print("\nManual hyperparameters being used:")
    for key, value in manual_params.items():
        print(f"{key}: {value}")
    print()

    # Promote the manual hyperparameters into the top-level config.
    config.update(manual_params)

    # Train the model with the merged configuration.
    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )
    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    # Persist model + hyperparameters only in non-distributed mode
    # (distributed training saves the model inside the worker).
    if not config.get("distributed_training", False):
        model_save_directory = os.path.join(
            config["model_save_path"], "GeneformerMultiTask"
        )
        save_model(trained_model, model_save_directory)
        hyperparams_to_save = {
            **manual_params,
            "dropout_rate": config["dropout_rate"],
            "use_task_weights": config["use_task_weights"],
            "task_weights": config["task_weights"],
            "max_layers_to_freeze": config["max_layers_to_freeze"],
            "use_attention_pooling": config["use_attention_pooling"],
        }
        save_hyperparameters(model_save_directory, hyperparams_to_save)

    return val_loss
def run_optuna_study(config):
    """Run hyperparameter optimization using Optuna.

    Falls back to manual tuning when ``use_manual_hyperparameters`` is set.
    Otherwise minimizes the validation loss over ``n_trials`` trials, then
    rebuilds, loads, and saves the best model with its hyperparameters.
    Returns the best trial's validation loss.
    """
    (
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    ) = prepare_training_environment(config)

    # Manual hyperparameters short-circuit the Optuna search entirely.
    if config.get("use_manual_hyperparameters", False):
        return run_manual_tuning(config)

    # Bind the data/config arguments so Optuna only supplies the trial.
    objective_fn = functools.partial(
        objective,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cell_id_mapping=train_cell_id_mapping,
        val_cell_id_mapping=val_cell_id_mapping,
        num_labels_list=num_labels_list,
        config=config,
        device=device,
    )

    study = optuna.create_study(
        direction="minimize",  # Minimize validation loss
        study_name=config["study_name"],
        # storage=config["storage"],
        load_if_exists=True,
    )
    study.optimize(objective_fn, n_trials=config["n_trials"])

    # Re-materialize the best model from the winning trial's attributes.
    best_trial = study.best_trial
    best_params = best_trial.params
    best_task_weights = best_trial.user_attrs["task_weights"]
    print("Saving the best model and its hyperparameters...")
    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_params["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=best_task_weights,
        max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
        use_attention_pooling=best_params.get("use_attention_pooling", False),
    )

    # Strip any DDP "module." prefixes before loading the saved weights.
    raw_state_dict = best_trial.user_attrs["model_state_dict"]
    cleaned_state_dict = {
        key.replace("module.", ""): value for key, value in raw_state_dict.items()
    }
    best_model.load_state_dict(cleaned_state_dict, strict=False)

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(best_model, model_save_directory)
    save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights})

    return best_trial.value  # Best validation loss found by the study