| import os |
| import pandas as pd |
| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
| import optuna |
| import functools |
| import time |
|
|
| from .model import GeneformerMultiTask |
| from .utils import ( |
| calculate_metrics, |
| get_layer_freeze_range, |
| set_seed, |
| initialize_wandb, |
| create_model, |
| setup_optimizer_and_scheduler, |
| save_model, |
| save_hyperparameters, |
| prepare_training_environment, |
| log_training_step, |
| log_validation_metrics, |
| save_validation_predictions, |
| setup_logging, |
| setup_distributed_environment, |
| train_distributed |
| ) |
|
|
|
|
class Trainer:
    """Trainer class for multi-task learning.

    Holds the model, optimizer, scheduler and logging writer for one process
    and provides the per-epoch training loop, the validation pass and the
    end-to-end ``train`` driver. When ``distributed_training`` is enabled in
    the config, only the process whose ``local_rank`` is 0 prints, shows
    progress bars and writes logs.
    """

    def __init__(self, config):
        """Store the configuration and derive the process-role flags.

        Args:
            config: Mapping of training options. ``distributed_training`` and
                ``local_rank`` are read here; the remaining keys are read by
                the individual methods when they need them.
        """
        self.config = config
        # Filled in later by ``setup``/``train`` or assigned by the caller.
        self.device = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.writer = None
        self.is_distributed = config.get("distributed_training", False)
        self.local_rank = config.get("local_rank", 0)
        self.is_main_process = not self.is_distributed or self.local_rank == 0

    @staticmethod
    def _average(values):
        """Return the arithmetic mean of ``values``, or 0.0 if it is empty."""
        return sum(values) / len(values) if values else 0.0

    def _total_optimizer_steps(self, train_loader):
        """Return how many optimizer/scheduler steps the whole run performs.

        ``train_epoch`` steps once per accumulation window plus once for a
        trailing partial window, i.e. ``ceil(batches / accumulation_steps)``
        times per epoch, so the scheduler must be sized with that count
        rather than with the raw number of batches.
        """
        accumulation_steps = self.config.get("gradient_accumulation_steps", 1)
        steps_per_epoch = (len(train_loader) + accumulation_steps - 1) // accumulation_steps
        return steps_per_epoch * self.config["epochs"]

    def train_epoch(self, train_loader, epoch):
        """Train the model for one epoch.

        Args:
            train_loader: Sized iterable of batches. Each batch is indexed
                with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
                (the latter keyed by task name).
            epoch: Zero-based epoch index, used for display and logging.

        Returns:
            The mean un-scaled batch loss of this epoch, averaged over all
            ranks when training is distributed. ``nan`` when the loader
            yields no batches.
        """
        epoch_start = time.time()
        self.model.train()

        num_loader_batches = len(train_loader)
        world_size = dist.get_world_size() if self.is_distributed else 1
        # Only ever reported by the main process.
        total_batches_global = num_loader_batches * world_size

        progress_bar = None
        iterator = train_loader
        if self.is_main_process:
            progress_bar = tqdm(
                train_loader,
                desc=f"Epoch {epoch+1}/{self.config['epochs']}",
                total=num_loader_batches,
            )
            iterator = progress_bar

            if self.is_distributed:
                print(f"Distributed training: {world_size} GPUs, {num_loader_batches} batches per GPU, "
                      f"{total_batches_global} total batches globally")

        batch_times = []
        forward_times = []
        backward_times = []
        optimizer_times = []  # one entry per actual optimizer step

        accumulation_steps = self.config.get("gradient_accumulation_steps", 1)

        self.optimizer.zero_grad()

        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(iterator):
            batch_start = time.time()

            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = [
                batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"]
            ]

            forward_start = time.time()
            loss, _, _ = self.model(input_ids, attention_mask, labels)

            # Scale so that gradients summed over the window average out.
            if accumulation_steps > 1:
                loss = loss / accumulation_steps

            forward_times.append(time.time() - forward_start)

            # Undo the accumulation scaling so the reported loss is per batch.
            total_loss += loss.item() * accumulation_steps
            num_batches += 1

            # The optimizer steps at the end of every accumulation window and
            # on the very last batch (a possibly shorter trailing window).
            is_update_step = (
                (batch_idx + 1) % accumulation_steps == 0
                or (batch_idx + 1) == num_loader_batches
            )

            backward_start = time.time()
            if self.is_distributed and accumulation_steps > 1 and not is_update_step:
                # Skip the gradient all-reduce until the window is complete.
                with self.model.no_sync():
                    loss.backward()
            else:
                loss.backward()
            backward_times.append(time.time() - backward_start)

            if is_update_step:
                if self.config["gradient_clipping"]:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])

                optimizer_start = time.time()
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                optimizer_times.append(time.time() - optimizer_start)

                if self.is_main_process:
                    avg_loss = total_loss / num_batches
                    log_training_step(avg_loss, self.writer, self.config, epoch, num_loader_batches, batch_idx)
                    progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})

            batch_times.append(time.time() - batch_start)

        epoch_end = time.time()

        # An empty loader has no defined mean loss; report nan instead of
        # failing with ZeroDivisionError.
        epoch_avg_loss = total_loss / num_batches if num_batches else float("nan")

        if self.is_distributed:
            # Average the per-rank means so every rank returns the same value.
            loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
            all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(all_losses, loss_tensor)
            epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()

        if self.is_main_process:
            per_gpu_batch_size = self.config['batch_size']
            total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size

            print(f"Epoch {epoch+1} timing:")
            print(f" Total epoch time: {epoch_end - epoch_start:.2f}s")
            print(f" Average batch time: {self._average(batch_times):.4f}s")
            print(f" Average forward time: {self._average(forward_times):.4f}s")
            print(f" Average backward time: {self._average(backward_times):.4f}s")
            print(f" Average optimizer time: {self._average(optimizer_times):.4f}s")
            print(f" Gradient accumulation steps: {accumulation_steps}")
            print(f" Batch size per GPU: {per_gpu_batch_size}")
            print(f" Effective global batch size: {total_effective_batch}")
            print(f" Average training loss: {epoch_avg_loss:.4f}")
            if self.is_distributed:
                print(f" Total batches processed across all GPUs: {total_batches_global}")
                # no_sync() is only entered when gradients are accumulated.
                if accumulation_steps > 1:
                    print(" Communication optimization: Using no_sync() for gradient accumulation")

        return epoch_avg_loss

    def validate_model(self, val_loader):
        """Evaluate the model on ``val_loader`` and collect per-task outputs.

        Args:
            val_loader: Sized iterable of batches with a ``dataset``
                attribute. Batches use the same keys as in training and may
                carry an optional ``"cell_id"`` entry.

        Returns:
            A tuple ``(val_loss, task_true_labels, task_pred_labels,
            task_pred_probs, val_cell_ids)``. ``val_loss`` is the mean batch
            loss (``nan`` for an empty loader), averaged over ranks when
            distributed. The three ``task_*`` dicts map each task name to a
            list with one entry per sample. ``val_cell_ids`` maps the running
            sample index to the batch's ``cell_id`` value. Only the loss is
            exchanged between ranks; the other collections cover just the
            samples this process iterated over.
        """
        val_start = time.time()
        self.model.eval()
        task_names = self.config["task_names"]

        val_loss_sum = 0.0
        task_true_labels = {task_name: [] for task_name in task_names}
        task_pred_labels = {task_name: [] for task_name in task_names}
        task_pred_probs = {task_name: [] for task_name in task_names}

        val_cell_ids = {}
        sample_counter = 0
        batch_times = []

        num_val_batches = len(val_loader)

        if self.is_main_process:
            print(f"Validation dataset size: {len(val_loader.dataset)} samples")
            print(f"Number of validation batches: {num_val_batches}")

            if self.is_distributed:
                world_size = dist.get_world_size()
                print(f"Distributed validation: {world_size} GPUs")
                if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
                    samples_per_gpu = val_loader.sampler.num_samples
                    print(f"Each GPU processes {samples_per_gpu} validation samples")
                    print(f"Total validation samples processed: {samples_per_gpu * world_size}")

        with torch.no_grad():
            for batch in val_loader:
                batch_start = time.time()
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = [
                    batch["labels"][task_name].to(self.device)
                    for task_name in task_names
                ]
                loss, logits, _ = self.model(input_ids, attention_mask, labels)
                val_loss_sum += loss.item()

                if "cell_id" in batch:
                    for offset, cell_id in enumerate(batch["cell_id"]):
                        val_cell_ids[sample_counter + offset] = cell_id.item()

                # Convert a whole batch per task at once; doing it sample by
                # sample forces one device synchronisation per element.
                for task_idx, task_name in enumerate(task_names):
                    task_logits = logits[task_idx]
                    task_true_labels[task_name].extend(
                        batch["labels"][task_name].reshape(-1).tolist()
                    )
                    task_pred_labels[task_name].extend(
                        torch.argmax(task_logits, dim=-1).tolist()
                    )
                    task_pred_probs[task_name].extend(
                        torch.softmax(task_logits, dim=-1).cpu().tolist()
                    )

                sample_counter += len(batch["input_ids"])
                batch_times.append(time.time() - batch_start)

        # An empty loader has no defined mean loss; report nan instead of
        # failing with ZeroDivisionError.
        val_loss = val_loss_sum / num_val_batches if num_val_batches else float("nan")

        if self.is_distributed:
            loss_tensor = torch.tensor([val_loss], device=self.device)
            gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered_losses, loss_tensor)
            val_loss = torch.mean(torch.cat(gathered_losses)).item()

            if self.is_main_process:
                print(f"Collected predictions from rank {self.local_rank}")
                print(f"Number of samples processed by this rank: {sample_counter}")

        val_end = time.time()

        if self.is_main_process:
            print("Validation timing:")
            print(f" Total validation time: {val_end - val_start:.2f}s")
            print(f" Average batch time: {self._average(batch_times):.4f}s")
            print(f" Collected {len(val_cell_ids)} cell indices from validation data")
            print(f" Processed {sample_counter} total samples during validation")

            for task_name in task_names:
                print(f" Task {task_name}: {len(task_true_labels[task_name])} samples")

        return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids

    def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
        """Train the model and return validation loss and trained model.

        Builds the model, optimizer and scheduler, runs ``config["epochs"]``
        epochs (optionally validating after each one when
        ``validate_each_epoch`` is set), then performs a final validation.
        The main process logs the metrics and saves the validation
        predictions.

        Args:
            train_loader: Training batches; in distributed mode its sampler
                must provide ``set_epoch``.
            val_loader: Validation batches.
            train_cell_id_mapping: Accepted for interface symmetry; not used
                by this method.
            val_cell_id_mapping: Passed to ``save_validation_predictions``
                under the ``"val_cell_mapping"`` key.
            num_labels_list: Forwarded to ``create_model``.

        Returns:
            ``(val_loss, model)`` from the final validation pass.
        """
        if self.config.get("use_wandb", False) and self.is_main_process:
            initialize_wandb(self.config)

        self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)

        # Size the schedule by optimizer steps, not by raw batches, so it
        # still completes when gradients are accumulated.
        total_steps = self._total_optimizer_steps(train_loader)
        self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)

        epoch_progress = range(self.config["epochs"])
        if self.is_main_process:
            epoch_progress = tqdm(epoch_progress, desc="Training Progress")

        best_val_loss = float('inf')
        train_losses = []

        with setup_logging(self.config) as self.writer:
            for epoch in epoch_progress:
                if self.is_distributed:
                    # Gives the distributed sampler a new shuffle each epoch.
                    train_loader.sampler.set_epoch(epoch)

                train_loss = self.train_epoch(train_loader, epoch)
                train_losses.append(train_loss)

                postfix = {"train_loss": f"{train_loss:.4f}"}
                if self.config.get("validate_each_epoch", False):
                    val_loss, _, _, _, _ = self.validate_model(val_loader)
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                    postfix["val_loss"] = f"{val_loss:.4f}"
                    postfix["best_val_loss"] = f"{best_val_loss:.4f}"

                if self.is_main_process:
                    epoch_progress.set_postfix(postfix)

            # Final validation runs on every rank because validate_model
            # performs a collective all_gather in distributed mode.
            val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
            task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")

            if self.is_main_process:
                log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])

                save_validation_predictions(
                    val_cell_ids,
                    task_true_labels,
                    task_pred_labels,
                    task_pred_probs,
                    {**self.config, "val_cell_mapping": val_cell_id_mapping}
                )

                if self.config.get("use_wandb", False):
                    import wandb
                    wandb.finish()

                print("\nTraining Summary:")
                # No epoch ran when ``epochs`` is 0, so there is no loss to show.
                if train_losses:
                    print(f" Final Training Loss: {train_losses[-1]:.4f}")
                print(f" Final Validation Loss: {val_loss:.4f}")
                for task_name, metrics in task_metrics.items():
                    print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")

        return val_loss, self.model

    def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
        """Select the device and build model, optimizer and scheduler.

        In distributed mode the model is wrapped in
        ``DistributedDataParallel`` on ``cuda:<local_rank>`` and the
        all-reduce communication hook is registered. The main process also
        initialises wandb when ``use_wandb`` is set.

        Returns:
            The five arguments, unchanged, in the order they were passed.
        """
        if self.is_distributed:
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = create_model(self.config, num_labels_list, self.device)

        if self.is_distributed:
            self.model = DDP(self.model, device_ids=[self.local_rank])

            from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

            self.model.register_comm_hook(
                state=None,
                hook=comm_hooks.allreduce_hook
            )

            print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")

            print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
            print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
            if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
                print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")

            if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
                print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")

        # Same step count as ``train`` uses, so both schedules agree.
        self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
            self.model, self.config, self._total_optimizer_steps(train_loader)
        )

        if self.is_main_process and self.config.get("use_wandb", False):
            initialize_wandb(self.config)

        return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
|
|
|
|
def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
    """Train a model with the given configuration and data.

    When ``distributed_training`` is requested and more than one CUDA device
    is visible, the work is handed to ``train_distributed`` and its result is
    returned unless it is ``None``. When it is requested with at most one
    device, the flag is switched off in ``config`` (in place) and training
    continues on ``device``. Every other path trains through a local
    ``Trainer`` and returns what ``Trainer.train`` returns.
    """
    data_args = (
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )

    wants_distributed = config.get("distributed_training", False)
    if wants_distributed and torch.cuda.device_count() > 1:
        outcome = train_distributed(Trainer, config, *data_args)
        if outcome is not None:
            return outcome
    elif wants_distributed:
        print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
        config["distributed_training"] = False

    # Single-process path (also reached when the distributed run gave None).
    local_trainer = Trainer(config)
    local_trainer.device = device
    return local_trainer.train(*data_args)
|
|
|
|
def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    """Objective function for Optuna hyperparameter optimization.

    Samples one value per entry of ``config["hyperparameters"]`` (plus
    ``max_layers_to_freeze`` when that key is present), trains for
    ``config["epochs"]`` epochs, validates once and returns the validation
    loss. The trained weights and the task weights are stored on the trial
    as user attributes ``model_state_dict`` and ``task_weights``.

    Args:
        trial: The Optuna trial used for sampling and reporting.
        train_loader: Training batches.
        val_loader: Validation batches.
        train_cell_id_mapping: Forwarded to ``train_distributed`` only.
        val_cell_id_mapping: Forwarded to ``train_distributed`` and stored
            with the saved validation predictions.
        num_labels_list: One entry per task; its length fixes how many task
            weights are sampled.
        config: Base configuration; it is shallow-copied per trial.
        device: Device used on the single-process path.

    Returns:
        The validation loss of this trial.

    Raises:
        optuna.TrialPruned: If the pruner rejects the reported loss.
    """
    set_seed(config["seed"])
    initialize_wandb(config)

    trial_config = config.copy()

    for param_name, param_config in config["hyperparameters"].items():
        if param_name == "lr_scheduler_type":
            trial_config[param_name] = trial.suggest_categorical(
                param_name, param_config["choices"]
            )
        elif param_name == "task_weights" and config["use_task_weights"]:
            weights = [
                trial.suggest_float(
                    f"task_weight_{i}",
                    param_config["low"],
                    param_config["high"],
                )
                for i in range(len(num_labels_list))
            ]
            # Normalise so the task weights sum to one.
            weight_sum = sum(weights)
            trial_config[param_name] = [w / weight_sum for w in weights]
        else:
            # A truthy "log" entry requests log-uniform sampling.
            use_log_scale = bool("log" in param_config and param_config["log"])
            trial_config[param_name] = trial.suggest_float(
                param_name, param_config["low"], param_config["high"], log=use_log_scale
            )

    if "max_layers_to_freeze" in trial_config:
        freeze_spec = trial_config["max_layers_to_freeze"]
        if isinstance(freeze_spec, (int, float)):
            # Already a concrete layer count: keep it rather than failing on
            # the ["min"]/["max"] lookups a range would need.
            trial_config["max_layers_to_freeze"] = int(freeze_spec)
        else:
            if freeze_spec is None:
                # No range configured: derive one from the pretrained model.
                freeze_spec = get_layer_freeze_range(trial_config["pretrained_path"])
            trial_config["max_layers_to_freeze"] = int(
                trial.suggest_int(
                    "max_layers_to_freeze",
                    freeze_spec["min"],
                    freeze_spec["max"],
                )
            )

    trial_config["run_name"] = f"trial_{trial.number}"

    # ``.get`` so a config without task weights stores None instead of
    # raising KeyError after the trial has already been trained.
    trial_task_weights = trial_config.get("task_weights")
    use_wandb = config.get("use_wandb", False)

    if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
        # Results are handed back by the workers through a shared dict.
        manager = mp.Manager()
        shared_dict = manager.dict()

        train_distributed(
            Trainer,
            trial_config,
            train_loader,
            val_loader,
            train_cell_id_mapping,
            val_cell_id_mapping,
            num_labels_list,
            trial.number,
            shared_dict
        )

        val_loss = shared_dict.get('val_loss', float('inf'))
        task_metrics = shared_dict.get('task_metrics', {})

        trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
        trial.set_user_attr("task_weights", trial_task_weights)

        if use_wandb:
            import wandb
            wandb.log({
                "trial_number": trial.number,
                "val_loss": val_loss,
                **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
                **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
            })
            wandb.finish()

        return val_loss

    with setup_logging(trial_config) as writer:
        trainer = Trainer(trial_config)
        trainer.device = device
        trainer.writer = writer

        trainer.model = create_model(trial_config, num_labels_list, device)
        total_steps = len(train_loader) * config["epochs"]
        trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)

        for epoch in range(config["epochs"]):
            trainer.train_epoch(train_loader, epoch)

        val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
        task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")

        log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])

        save_validation_predictions(
            val_cell_ids,
            task_true_labels,
            task_pred_labels,
            task_pred_probs,
            {**trial_config, "val_cell_mapping": val_cell_id_mapping},
            trial.number,
        )

        trial.set_user_attr("model_state_dict", trainer.model.state_dict())
        trial.set_user_attr("task_weights", trial_task_weights)

        trial.report(val_loss, config["epochs"])
        if trial.should_prune():
            # Close this trial's wandb run before bailing out so the next
            # trial does not continue logging into it.
            if use_wandb:
                import wandb
                wandb.finish()
            raise optuna.TrialPruned()

        if use_wandb:
            import wandb
            wandb.log(
                {
                    "trial_number": trial.number,
                    "val_loss": val_loss,
                    **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
                    **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
                    **{k: v for k, v in trial_config.items() if k in [
                        "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
                        "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
                    ]},
                }
            )
            wandb.finish()

    return val_loss
|
|
|
|
def run_manual_tuning(config):
    """Run training with manually specified hyperparameters.

    Prepares the training environment, prints the entries of
    ``config["manual_hyperparameters"]`` and copies them into the top level
    of ``config``, then trains through ``train_model``. Unless
    ``distributed_training`` is set in ``config`` after training, the model
    and the hyperparameters used are saved under
    ``<model_save_path>/GeneformerMultiTask``.

    Returns:
        The validation loss reported by ``train_model``.
    """
    environment = prepare_training_environment(config)
    (
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    ) = environment

    manual_params = config["manual_hyperparameters"]

    print("\nManual hyperparameters being used:")
    for name, setting in manual_params.items():
        print(f"{name}: {setting}")
    print()

    # Promote the manual values to top-level config entries.
    for name, setting in manual_params.items():
        config[name] = setting

    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )

    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    # Nothing is saved from here when the run was distributed.
    if config.get("distributed_training", False):
        return val_loss

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(trained_model, model_save_directory)

    hyperparams_to_save = dict(config["manual_hyperparameters"])
    for extra_key in (
        "dropout_rate",
        "use_task_weights",
        "task_weights",
        "max_layers_to_freeze",
        "use_attention_pooling",
    ):
        hyperparams_to_save[extra_key] = config[extra_key]
    save_hyperparameters(model_save_directory, hyperparams_to_save)

    return val_loss
|
|
|
|
def run_optuna_study(config):
    """Run hyperparameter optimization using Optuna.

    When ``use_manual_hyperparameters`` is set, the call is delegated to
    ``run_manual_tuning`` and its result is returned. Otherwise a study
    minimising the validation loss is run for ``config["n_trials"]`` trials.
    The best trial's weights are then loaded into a fresh
    ``GeneformerMultiTask`` and saved together with its hyperparameters
    under ``<model_save_path>/GeneformerMultiTask``.

    Returns:
        The best validation loss found by the study (or the manual-run
        validation loss on the delegated path).
    """
    # Decide this first: run_manual_tuning prepares its own environment, so
    # preparing one here beforehand would only duplicate that work.
    if config.get("use_manual_hyperparameters", False):
        return run_manual_tuning(config)

    (
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    ) = prepare_training_environment(config)

    objective_with_config_and_data = functools.partial(
        objective,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cell_id_mapping=train_cell_id_mapping,
        val_cell_id_mapping=val_cell_id_mapping,
        num_labels_list=num_labels_list,
        config=config,
        device=device,
    )

    study = optuna.create_study(
        direction="minimize",
        study_name=config["study_name"],
        load_if_exists=True,
    )

    study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

    best_params = study.best_trial.params
    best_task_weights = study.best_trial.user_attrs["task_weights"]
    print("Saving the best model and its hyperparameters...")

    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_params["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=best_task_weights,
        max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
        use_attention_pooling=best_params.get("use_attention_pooling", False),
    )

    # Weights saved from a DistributedDataParallel wrapper carry a leading
    # "module." on every key. Strip it only as a prefix: removing the
    # substring anywhere would also rewrite parameter paths that merely
    # contain it, and strict=False below would hide the resulting mismatch.
    ddp_prefix = "module."
    best_model_state_dict = {
        (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
        for key, value in study.best_trial.user_attrs["model_state_dict"].items()
    }

    best_model.load_state_dict(best_model_state_dict, strict=False)

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(best_model, model_save_directory)

    save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights})

    return study.best_trial.value
|
|