# geneformer/mtl/utils.py
from typing import Dict, List

import json
import os
import pickle
import random
from contextlib import contextmanager

import numpy as np
import optuna
import pandas as pd
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoConfig,
    BertConfig,
    BertModel,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade some speed for reproducible results
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def initialize_wandb(config):
    """Initialize a Weights & Biases run if enabled in the config."""
    if config.get("use_wandb", False):
        import wandb

        wandb.init(
            project=config.get("wandb_project", "geneformer_multitask"),
            name=config.get("run_name", "experiment"),
            config=config,
            reinit=True,
        )
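
# Illustrative usage (assumes wandb is installed and the user is logged in):
# initialize_wandb({"use_wandb": True, "wandb_project": "geneformer_multitask", "run_name": "exp1"})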

def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0):
    """Create and initialize the model based on configuration."""
    from .model import GeneformerMultiTask

    model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=config.get("dropout_rate", 0.1),
        use_task_weights=config.get("use_task_weights", False),
        task_weights=config.get("task_weights", None),
        max_layers_to_freeze=config.get("max_layers_to_freeze", 0),
        use_attention_pooling=config.get("use_attention_pooling", False),
    )

    # Move model to device
    model.to(device)

    if is_distributed:
        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )

    return model
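
# Illustrative usage (single process; the checkpoint path is a placeholder):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = create_model({"pretrained_path": "/path/to/Geneformer"}, num_labels_list=[3, 5], device=device)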

def setup_optimizer_and_scheduler(model, config, total_steps):
    """Set up optimizer and learning rate scheduler."""
    # Exclude biases and LayerNorm weights from weight decay, as is standard for BERT-style models
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay) and p.requires_grad
            ],
            "weight_decay": config["weight_decay"],
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay) and p.requires_grad
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=config["learning_rate"],
        eps=config.get("adam_epsilon", 1e-8),
    )

    # Prepare scheduler: warm up for a fraction of total steps, then decay
    warmup_steps = int(total_steps * config["warmup_ratio"])
    scheduler_map = {
        "linear": get_linear_schedule_with_warmup,
        "cosine": get_cosine_schedule_with_warmup,
    }
    scheduler_fn = scheduler_map.get(config["lr_scheduler_type"])
    if not scheduler_fn:
        raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}")
    scheduler = scheduler_fn(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    return optimizer, scheduler
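
# Illustrative usage (hyperparameter values are examples, not project defaults):
# optimizer, scheduler = setup_optimizer_and_scheduler(
#     model,
#     {"weight_decay": 0.01, "learning_rate": 5e-5, "warmup_ratio": 0.1,
#      "lr_scheduler_type": "linear"},
#     total_steps=len(train_loader) * num_epochs,
# )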

def save_model(model, model_save_directory):
    """Save model weights and configuration."""
    os.makedirs(model_save_directory, exist_ok=True)

    # Unwrap a DDP-wrapped model before saving
    if isinstance(model, DDP):
        model_to_save = model.module
    else:
        model_to_save = model

    model_state_dict = model_to_save.state_dict()
    model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
    torch.save(model_state_dict, model_save_path)

    # Save the model configuration
    model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json"))
    print(f"Model and configuration saved to {model_save_directory}")

def save_hyperparameters(model_save_directory, hyperparams):
    """Save hyperparameters to a JSON file."""
    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w") as f:
        json.dump(hyperparams, f)
    print(f"Hyperparameters saved to {hyperparams_path}")

def calculate_metrics(
    labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"
):
    """Calculate macro F1 and accuracy for single-task, task-specific, or combined evaluation."""
    if metric_type == "single":
        # Calculate metrics for a single task
        if labels is None or preds is None:
            raise ValueError("Labels and predictions must be provided for single task metrics")

        task_name = None
        if isinstance(labels, dict) and len(labels) == 1:
            task_name = list(labels.keys())[0]
            labels = labels[task_name]
            preds = preds[task_name]

        f1 = f1_score(labels, preds, average="macro")
        accuracy = accuracy_score(labels, preds)

        if return_format == "tuple":
            return f1, accuracy
        result = {"f1": f1, "accuracy": accuracy}
        if task_name:
            return {task_name: result}
        return result

    elif metric_type == "task_specific":
        # Calculate metrics for multiple tasks
        if task_data:
            result = {}
            for task_name, (task_labels, task_preds) in task_data.items():
                f1 = f1_score(task_labels, task_preds, average="macro")
                accuracy = accuracy_score(task_labels, task_preds)
                result[task_name] = {"f1": f1, "accuracy": accuracy}
            return result
        elif isinstance(labels, dict) and isinstance(preds, dict):
            result = {}
            for task_name in labels:
                if task_name in preds:
                    f1 = f1_score(labels[task_name], preds[task_name], average="macro")
                    accuracy = accuracy_score(labels[task_name], preds[task_name])
                    result[task_name] = {"f1": f1, "accuracy": accuracy}
            return result
        else:
            raise ValueError(
                "For task_specific metrics, either task_data or labels and preds dictionaries must be provided"
            )

    elif metric_type == "combined":
        # Calculate combined metrics across all tasks
        if labels is None or preds is None:
            raise ValueError("Labels and predictions must be provided for combined metrics")

        # Encode non-numeric labels so sklearn metrics can handle them
        if not all(isinstance(x, (int, float)) for x in labels + preds):
            le = LabelEncoder()
            le.fit(labels + preds)
            labels = le.transform(labels)
            preds = le.transform(preds)

        f1 = f1_score(labels, preds, average="macro")
        accuracy = accuracy_score(labels, preds)

        if return_format == "tuple":
            return f1, accuracy
        return {"f1": f1, "accuracy": accuracy}

    else:
        raise ValueError(f"Unknown metric_type: {metric_type}")

def get_layer_freeze_range(pretrained_path):
    """Return the valid range of encoder layers that can be frozen for the given checkpoint."""
    if not pretrained_path:
        return {"min": 0, "max": 0}
    config = AutoConfig.from_pretrained(pretrained_path)
    total_layers = config.num_hidden_layers
    return {"min": 0, "max": total_layers - 1}

def prepare_training_environment(config):
    """
    Prepare the training environment by setting the seed and loading data.

    Returns:
        tuple: (device, train_loader, val_loader, train_cell_id_mapping,
                val_cell_id_mapping, num_labels_list)
    """
    from .data import prepare_data_loaders

    # Set seed for reproducibility
    set_seed(config["seed"])

    # Set up device for non-distributed training
    if not config.get("distributed_training", False):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # For distributed training, the device is set per process
        device = None

    # Load data using the streaming dataset
    data = prepare_data_loaders(config)

    # For distributed training, samplers are set up later in the distributed worker;
    # don't create a DistributedSampler here because the process group isn't initialized yet.
    return (
        device,
        data["train_loader"],
        data["val_loader"],
        data["train_cell_mapping"],
        data["val_cell_mapping"],
        data["num_labels_list"],
    )
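
# Illustrative usage (the config keys are consumed by prepare_data_loaders in .data):
# device, train_loader, val_loader, train_map, val_map, num_labels_list = (
#     prepare_training_environment(config)
# )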

# Optuna hyperparameter optimization utilities
def save_trial_callback(study, trial, trials_result_path):
    """
    Callback to append Optuna trial results to a file.

    Args:
        study: Optuna study object
        trial: Current trial object
        trials_result_path: Path to save trial results
    """
    with open(trials_result_path, "a") as f:
        f.write(
            f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
        )

def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study:
    """Create and run an Optuna study with TensorBoard logging."""
    from optuna.integration import TensorBoardCallback

    study = optuna.create_study(direction="maximize")
    study.optimize(
        objective,
        n_trials=n_trials,
        callbacks=[
            lambda study, trial: save_trial_callback(study, trial, trials_result_path),
            TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro"),
        ],
    )
    return study
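
# Illustrative usage with a hypothetical objective (the real objective is defined by the caller):
# def objective(trial):
#     lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
#     ...  # train with lr, return validation F1 macro
# study = create_optuna_study(objective, n_trials=20, trials_result_path="trials.txt",
#                             tensorboard_log_dir="runs/optuna")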

@contextmanager
def setup_logging(config):
    """Context manager that yields a TensorBoard SummaryWriter and closes it on exit."""
    run_name = config.get("run_name", "manual_run")
    log_dir = os.path.join(config["tensorboard_log_dir"], run_name)
    writer = SummaryWriter(log_dir=log_dir)
    try:
        yield writer
    finally:
        writer.close()
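
# Illustrative usage:
# with setup_logging({"tensorboard_log_dir": "runs", "run_name": "exp1"}) as writer:
#     writer.add_scalar("Training Loss", 0.5, 0)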

def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx):
    """Log training step metrics to TensorBoard and optionally W&B."""
    writer.add_scalar("Training Loss", loss, epoch * steps_per_epoch + batch_idx)
    if config.get("use_wandb", False):
        import wandb

        wandb.log({"Training Loss": loss})

def log_validation_metrics(task_metrics, val_loss, config, writer, epoch):
    """Log validation metrics to console, TensorBoard, and optionally W&B."""
    for task_name, metrics in task_metrics.items():
        print(
            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, "
            f"Validation Accuracy: {metrics['accuracy']:.4f}"
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log(
                {
                    f"{task_name} Validation F1 Macro": metrics["f1"],
                    f"{task_name} Validation Accuracy": metrics["accuracy"],
                }
            )

    writer.add_scalar("Validation Loss", val_loss, epoch)
    for task_name, metrics in task_metrics.items():
        writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch)
        writer.add_scalar(f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch)

def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]:
    """Load or initialize task label mappings."""
    label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl")
    if os.path.exists(label_mappings_path):
        with open(label_mappings_path, "rb") as f:
            return pickle.load(f)
    return {task_name: {} for task_name in task_names}

def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict,
                          task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str],
                          inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict:
    """Create a row for validation predictions."""
    batch_cell_idx = val_cell_indices.get(sample_idx)
    cell_id = (
        val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}")
        if batch_cell_idx is not None
        else f"unknown_cell_{sample_idx}"
    )
    row = {"Cell ID": cell_id}

    for task_name in task_names:
        if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]):
            true_idx = task_true_labels[task_name][sample_idx]
            pred_idx = task_pred_labels[task_name][sample_idx]
            true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}")
            pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}")
            row.update({
                f"{task_name}_true_idx": true_idx,
                f"{task_name}_pred_idx": pred_idx,
                f"{task_name}_true_label": true_label,
                f"{task_name}_pred_label": pred_label,
            })

        if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]):
            probs = task_pred_probs[task_name][sample_idx]
            if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, "__iter__") and not isinstance(probs, str)):
                prob_list = list(probs) if not isinstance(probs, list) else probs
                row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list))
                # Also record one column per class, keyed by the human-readable class label
                for class_idx, prob in enumerate(prob_list):
                    class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}")
                    row[f"{task_name}_prob_{class_label}"] = prob
            else:
                row[f"{task_name}_all_probs"] = str(probs)

    return row

def save_validation_predictions(
    val_cell_indices,
    task_true_labels,
    task_pred_labels,
    task_pred_probs,
    config,
    trial_number=None,
):
    """Save validation predictions to a CSV file with class labels and probabilities."""
    os.makedirs(config["results_dir"], exist_ok=True)
    if trial_number is not None:
        os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True)
        val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}", "val_preds.csv")
    else:
        val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")

    # Nothing to save: write an empty CSV so downstream consumers still find the file
    if not val_cell_indices or not task_true_labels:
        pd.DataFrame().to_csv(val_preds_file, index=False)
        return

    try:
        label_mappings = load_label_mappings(config["results_dir"], config["task_names"])
        inverted_mappings = {
            task: {idx: label for label, idx in mapping.items()}
            for task, mapping in label_mappings.items()
        }
        val_cell_mapping = config.get("val_cell_mapping", {})

        # Determine the maximum number of samples across all tasks
        max_samples = max(
            [len(val_cell_indices)]
            + [len(task_true_labels[task]) for task in task_true_labels]
        )

        rows = [
            create_prediction_row(
                sample_idx, val_cell_indices, task_true_labels, task_pred_labels,
                task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping,
            )
            for sample_idx in range(max_samples)
        ]
        pd.DataFrame(rows).to_csv(val_preds_file, index=False)
    except Exception as e:
        # Record the error in the output file rather than crashing the run
        pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False)

def setup_distributed_environment(rank, world_size, config):
    """
    Set up the distributed training environment.

    Args:
        rank (int): The rank of the current process
        world_size (int): Total number of processes
        config (dict): Configuration dictionary
    """
    os.environ["MASTER_ADDR"] = config.get("master_addr", "localhost")
    os.environ["MASTER_PORT"] = config.get("master_port", "12355")

    # Initialize the process group (NCCL backend for multi-GPU training)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )

    # Bind this process to its GPU
    torch.cuda.set_device(rank)
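
# Illustrative usage inside a worker spawned by mp.spawn (rank is supplied per process):
# setup_distributed_environment(rank, world_size=torch.cuda.device_count(),
#                               config={"master_addr": "localhost", "master_port": "12355"})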

def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
    """Run distributed training across multiple GPUs, falling back to a single GPU if needed."""
    world_size = torch.cuda.device_count()
    if world_size <= 1:
        print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
        config["distributed_training"] = False

        trainer = trainer_class(config)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        trainer.device = device
        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
            train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
        )
        val_loss, model = trainer.train(
            train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
        )

        model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
        save_model(model, model_save_directory)
        save_hyperparameters(model_save_directory, {
            **get_config_value(config, "manual_hyperparameters", {}),
            "dropout_rate": config["dropout_rate"],
            "use_task_weights": config["use_task_weights"],
            "task_weights": config["task_weights"],
            "max_layers_to_freeze": config["max_layers_to_freeze"],
            "use_attention_pooling": config["use_attention_pooling"],
        })

        if shared_dict is not None:
            shared_dict["val_loss"] = val_loss
            task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
                model, val_loader, device, config
            )
            shared_dict["task_metrics"] = calculate_metrics(
                labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific"
            )
            shared_dict["model_state_dict"] = {k: v.cpu() for k, v in model.state_dict().items()}
        return val_loss, model

    print(f"Using distributed training with {world_size} GPUs")
    mp.spawn(
        _distributed_worker,
        args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict),
        nprocs=world_size,
        join=True,
    )

    # After the workers finish, reload the saved model for manual (non-Optuna) runs
    if trial_number is None and shared_dict is None:
        model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
        model_path = os.path.join(model_save_directory, "pytorch_model.bin")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = create_model(config, num_labels_list, device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        return 0.0, model
    return None

def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
    """Worker function for distributed training."""
    setup_distributed_environment(rank, world_size, config)
    config["local_rank"] = rank

    # Set up distributed samplers so each rank sees a disjoint shard of the data
    from torch.utils.data import DistributedSampler
    from .data import get_data_loader

    train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False)
    val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False)

    if rank == 0:
        print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples")
        print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}")

    # Create and set up the trainer
    trainer = trainer_class(config)
    train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
    )

    # Train the model
    val_loss, model = trainer.train(
        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
    )

    # Save model and results only from the main process
    if rank == 0:
        model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
        save_model(model, model_save_directory)
        save_hyperparameters(model_save_directory, {
            **get_config_value(config, "manual_hyperparameters", {}),
            "dropout_rate": config["dropout_rate"],
            "use_task_weights": config["use_task_weights"],
            "task_weights": config["task_weights"],
            "max_layers_to_freeze": config["max_layers_to_freeze"],
            "use_attention_pooling": config["use_attention_pooling"],
        })

        # For Optuna trials, store results in the shared dictionary
        if shared_dict is not None:
            shared_dict["val_loss"] = val_loss

            # Run validation on the full dataset from rank 0 for consistent metrics
            full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False)

            # Get validation predictions using our utility function
            task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
                model, full_val_loader, trainer.device, config
            )

            # Calculate metrics
            task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
            shared_dict["task_metrics"] = task_metrics

            # Store the model state dict (unwrapped from DDP) on CPU
            if isinstance(model, DDP):
                model_state_dict = model.module.state_dict()
            else:
                model_state_dict = model.state_dict()
            shared_dict["model_state_dict"] = {k: v.cpu() for k, v in model_state_dict.items()}

    # Clean up the distributed environment
    dist.destroy_process_group()

def save_model_without_heads(model_directory):
    """
    Save a version of the fine-tuned model without classification heads.

    Args:
        model_directory (str): Path to the directory containing the fine-tuned model
    """
    # Locate the full model
    model_path = os.path.join(model_directory, "pytorch_model.bin")
    config_path = os.path.join(model_directory, "config.json")

    if not os.path.exists(model_path) or not os.path.exists(config_path):
        raise FileNotFoundError(f"Model files not found in {model_directory}")

    # Load the configuration
    config = BertConfig.from_json_file(config_path)

    # Load the model state dict
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    # Create a new model without heads
    base_model = BertModel(config)

    # Keep only parameters that belong to the base model (not the classification heads)
    base_model_state_dict = {}
    for key, value in state_dict.items():
        if not key.startswith("classification_heads") and not key.startswith("attention_pool"):
            base_model_state_dict[key] = value

    # Load the filtered state dict into the base model
    base_model.load_state_dict(base_model_state_dict, strict=False)

    # Save the model without heads
    output_dir = os.path.join(model_directory, "model_without_heads")
    os.makedirs(output_dir, exist_ok=True)

    # Save the model weights
    torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

    # Save the configuration
    base_model.config.to_json_file(os.path.join(output_dir, "config.json"))

    print(f"Model without classification heads saved to {output_dir}")
    return output_dir
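
# Illustrative usage (the path is a placeholder for a fine-tuned model directory):
# base_dir = save_model_without_heads("/path/to/model_save_path/GeneformerMultiTask")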

def get_config_value(config: Dict, key: str, default=None):
    """Thin wrapper around dict.get for config lookups."""
    return config.get(key, default)

def collect_validation_predictions(model, val_loader, device, config) -> tuple:
    """Run the model over the validation loader and collect per-task labels, predictions, and probabilities."""
    task_true_labels = {}
    task_pred_labels = {}
    task_pred_probs = {}

    model.eval()  # disable dropout for deterministic predictions
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]]
            _, logits, _ = model(input_ids, attention_mask, labels)

            for sample_idx in range(len(batch["input_ids"])):
                for i, task_name in enumerate(config["task_names"]):
                    if task_name not in task_true_labels:
                        task_true_labels[task_name] = []
                        task_pred_labels[task_name] = []
                        task_pred_probs[task_name] = []

                    true_label = batch["labels"][task_name][sample_idx].item()
                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                    pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
                    task_true_labels[task_name].append(true_label)
                    task_pred_labels[task_name].append(pred_label)
                    task_pred_probs[task_name].append(pred_prob)

    return task_true_labels, task_pred_labels, task_pred_probs
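
# Illustrative usage (assumes config["task_names"] matches the model's task heads):
# true, pred, probs = collect_validation_predictions(model, val_loader, device, config)
# task_metrics = calculate_metrics(labels=true, preds=pred, metric_type="task_specific")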