import os
from copy import deepcopy
from typing import Optional, Dict, List, Any

import numpy as np
import torch
from huggingface_hub import HfApi
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

try:
    from probes.hybrid_probe import HybridProbe, HybridProbeConfig
    from probes.export_packaged_model import export_packaged_model_to_hub
    from data.dataset_classes import (
        EmbedsLabelsDatasetFromDisk,
        PairEmbedsLabelsDatasetFromDisk,
        EmbedsLabelsDataset,
        PairEmbedsLabelsDataset,
        StringLabelDataset,
        PairStringLabelDataset,
        MultiEmbedsLabelsDatasetFromDisk,
        MultiEmbedsLabelsDataset,
    )
except ImportError:
    from .hybrid_probe import HybridProbe, HybridProbeConfig
    from .export_packaged_model import export_packaged_model_to_hub
    from ..data.dataset_classes import (
        EmbedsLabelsDatasetFromDisk,
        PairEmbedsLabelsDatasetFromDisk,
        EmbedsLabelsDataset,
        PairEmbedsLabelsDataset,
        StringLabelDataset,
        PairStringLabelDataset,
        MultiEmbedsLabelsDatasetFromDisk,
        MultiEmbedsLabelsDataset,
    )

try:
    from data.data_collators import (
        EmbedsLabelsCollator,
        PairEmbedsLabelsCollator,
        PairCollator_input_ids,
        StringLabelsCollator,
    )
    from visualization.ci_plots import regression_ci_plot, classification_ci_plot
    from utils import print_message
    from metrics import get_compute_metrics
    from seed_utils import set_global_seed
    from probes.get_probe import get_probe
except ImportError:
    from ..data.data_collators import (
        EmbedsLabelsCollator,
        PairEmbedsLabelsCollator,
        PairCollator_input_ids,
        StringLabelsCollator,
    )
    from ..visualization.ci_plots import regression_ci_plot, classification_ci_plot
    from ..utils import print_message
    from ..metrics import get_compute_metrics
    from ..seed_utils import set_global_seed
    from .get_probe import get_probe


class TrainerArguments:
    """Container for training hyperparameters; call the instance to build HF TrainingArguments."""

    def __init__(
        self,
        model_save_dir: str,
        num_epochs: int = 200,
        probe_batch_size: int = 64,
        base_batch_size: int = 4,
        probe_grad_accum: int = 1,
        base_grad_accum: int = 1,
        lr: float = 1e-4,
        weight_decay: float = 0.00,
        task_type: str = 'regression',
        patience: int = 3,
        read_scaler: int = 100,
        save_model: bool = False,
        seed: int = 42,
        train_data_size: int = 100,
        plots_dir: Optional[str] = None,
        full_finetuning: bool = False,
        hybrid_probe: bool = False,
        num_workers: int = 0,
        make_plots: bool = True,
        num_runs: int = 1,
        **kwargs,
    ):
        self.model_save_dir = model_save_dir
        self.num_epochs = num_epochs
        self.probe_batch_size = probe_batch_size
        self.base_batch_size = base_batch_size
        self.probe_grad_accum = probe_grad_accum
        self.base_grad_accum = base_grad_accum
        self.lr = lr
        self.weight_decay = weight_decay
        self.task_type = task_type
        self.patience = patience
        self.save = save_model
        self.read_scaler = read_scaler
        self.seed = seed
        self.train_data_size = train_data_size
        self.plots_dir = plots_dir
        self.full_finetuning = full_finetuning
        self.hybrid_probe = hybrid_probe
        self.num_workers = num_workers
        self.make_plots = make_plots
        self.num_runs = num_runs

    def __call__(self, probe: Optional[bool] = True):
        # Large datasets evaluate/save on a step schedule; smaller ones per epoch.
        if self.train_data_size > 350000:
            eval_strats = {
                'eval_strategy': 'steps',
                'eval_steps': 5000,
                'save_strategy': 'steps',
                'save_steps': 5000,
            }
        else:
            eval_strats = {
                'eval_strategy': 'epoch',
                'save_strategy': 'epoch',
            }

        if '/' in self.model_save_dir:
            save_dir = self.model_save_dir.split('/')[-1]
        else:
            save_dir = self.model_save_dir

        batch_size = self.probe_batch_size if probe else self.base_batch_size
        grad_accum = self.probe_grad_accum if probe else self.base_grad_accum
        warmup_steps = 100 if probe else 1000
        return TrainingArguments(
            output_dir=save_dir,
            num_train_epochs=self.num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=grad_accum,
            learning_rate=float(self.lr),
            lr_scheduler_type='cosine',
            weight_decay=float(self.weight_decay),
            warmup_steps=warmup_steps,
            save_total_limit=3,
            logging_steps=1000,
            report_to='none',
            load_best_model_at_end=True,
            metric_for_best_model='eval_loss',
            greater_is_better=False,
            seed=self.seed,
            label_names=['labels'],
            dataloader_num_workers=self.num_workers,
            dataloader_prefetch_factor=2 if self.num_workers > 0 else None,
            # Explicitly disable mixed precision training to prevent automatic fp16 conversion
            fp16=False,
            bf16=False,
            **eval_strats,
        )
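# Usage sketch (illustrative, not executed; the save directory below is hypothetical):
# materialize HF TrainingArguments for a probe run vs. a full-finetuning run.
#
#   trainer_args = TrainerArguments(model_save_dir='runs/demo', lr=1e-4, task_type='regression')
#   probe_hf_args = trainer_args(probe=True)   # probe batch size / grad accum, 100 warmup steps
#   base_hf_args = trainer_args(probe=False)   # base batch size / grad accum, 1000 warmup steps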
class TrainerMixin:
    def __init__(self, trainer_args: Optional[TrainerArguments] = None):
        self.trainer_args = trainer_args

    def _format_metric_value(self, value: Any) -> str:
        if isinstance(value, float):
            return f"{value:.6f}"
        return str(value)

    def _format_metrics_markdown(self, metrics: Dict[str, Any]) -> str:
        if metrics is None or len(metrics) == 0:
            return "- No metrics recorded."
        lines = []
        for key in sorted(metrics.keys()):
            lines.append(f"- `{key}`: {self._format_metric_value(metrics[key])}")
        return "\n".join(lines)
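    # Formatting sketch: _format_metrics_markdown sorts keys and renders one bullet per
    # metric, so {'eval_loss': 0.25, 'spearman': '0.9100±0.0100'} becomes:
    #   - `eval_loss`: 0.250000
    #   - `spearman`: 0.9100±0.0100
    # (floats get six decimals; pre-formatted strings pass through unchanged).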
    def _build_model_card(
        self,
        repo_id: str,
        data_name: str,
        model_name: str,
        log_id: str,
        train_dataset,
        valid_dataset,
        test_dataset,
        valid_metrics: Dict[str, Any],
        test_metrics: Dict[str, Any],
    ) -> str:
        train_size = len(train_dataset)
        valid_size = "N/A" if valid_dataset is None else str(len(valid_dataset))
        test_size = len(test_dataset)
        task_type = self.trainer_args.task_type
        num_runs = self.trainer_args.num_runs
        validation_metrics_text = self._format_metrics_markdown(valid_metrics)
        test_metrics_text = self._format_metrics_markdown(test_metrics)
        return f"""---
library_name: transformers
tags: []
---

# {repo_id}

Fine-tuned with Protify.

## About Protify

Protify is an open source platform designed to simplify and democratize workflows for chemical language models. With Protify, deep learning models can be trained to predict chemical properties without requiring extensive coding knowledge or computational resources.

### Why Protify?

- Benchmark multiple models efficiently.
- Flexible for all skill levels.
- Accessible computing with support for precomputed embeddings.
- Cost-effective workflows for training and evaluation.

## Training Run

- `dataset`: {data_name}
- `model`: {model_name}
- `run_id`: {log_id}
- `task_type`: {task_type}
- `num_runs`: {num_runs}

## Dataset Statistics

- `train_size`: {train_size}
- `valid_size`: {valid_size}
- `test_size`: {test_size}

## Validation Metrics

{validation_metrics_text}

## Test Metrics

{test_metrics_text}
"""
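    # Rendering sketch: aggregated multi-run metrics arrive as strings and pass through
    # _format_metric_value unchanged, so test_metrics={'test_loss': '0.4321±0.0123'}
    # renders under "## Test Metrics" as:
    #   - `test_loss`: 0.4321±0.0123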
    def _train(
        self,
        model,
        train_dataset,
        valid_dataset,
        test_dataset,
        data_collator,
        tokenizer,
        log_id,
        model_name,
        data_name,
        source_model_name: Optional[str] = None,
        ppi: bool = False,
        probe: Optional[bool] = True,
        skip_plot: bool = False,
    ):
        task_type = self.trainer_args.task_type
        tokenwise = self.probe_args.tokenwise
        compute_metrics = get_compute_metrics(task_type, tokenwise=tokenwise)
        self.trainer_args.train_data_size = len(train_dataset)
        hf_trainer_args = self.trainer_args(probe=probe)
        ### TODO add options for optimizers and schedulers
        trainer = Trainer(
            model=model,
            args=hf_trainer_args,
            train_dataset=train_dataset,
            eval_dataset=valid_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=self.trainer_args.patience)],
        )
        trainer.can_return_loss = True
        metrics = trainer.evaluate(test_dataset)
        print_message(f'Initial metrics: {metrics}')
        train_output = trainer.train()
        train_runtime = train_output.metrics.get('train_runtime', 0.0)

        valid_metrics = trainer.evaluate(valid_dataset)
        print_message(f'Final validation metrics: {valid_metrics}')

        y_pred, y_true, test_metrics = trainer.predict(test_dataset)
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]
        if isinstance(y_true, tuple):
            y_true = y_true[0]
        y_pred, y_true = y_pred.astype(np.float32), y_true.astype(np.float32)
        # Remove singleton dimension if present
        if y_pred.ndim == 3 and y_pred.shape[1] == 1:
            y_pred = y_pred.squeeze(1)
        if y_true.ndim == 3 and y_true.shape[1] == 1:
            y_true = y_true.squeeze(1)
        test_metrics['training_time_seconds'] = train_runtime
        print_message(f'y_pred: {y_pred.shape}\ny_true: {y_true.shape}\nFinal test metrics: \n{test_metrics}\n')

        if self.trainer_args.make_plots and self.trainer_args.plots_dir is not None and not skip_plot:
            output_dir = os.path.join(self.trainer_args.plots_dir, log_id)
            os.makedirs(output_dir, exist_ok=True)
            save_path = os.path.join(output_dir, f"{data_name}_{model_name}_{log_id}.png")
            title = f"{data_name} {model_name} {log_id}"
            if task_type == 'regression':
                regression_ci_plot(y_true, y_pred, save_path, title)
            else:
                classification_ci_plot(y_true, y_pred, save_path, title)

        if source_model_name is None:
            source_model_name = model_name

        if self.trainer_args.save:
            try:
                hf_username = self.full_args.hf_username
                if hf_username is None or hf_username == "":
                    print_message("Warning: hf_username is not set. Cannot save model to HuggingFace Hub.")
                else:
                    repo_id = f"{hf_username}/{data_name}_{model_name}_{log_id}"
                    hf_token = self.full_args.hf_token
                    if hf_token is None:
                        hf_token = os.environ.get("HF_TOKEN")
                    model_card = self._build_model_card(
                        repo_id=repo_id,
                        data_name=data_name,
                        model_name=model_name,
                        log_id=log_id,
                        train_dataset=train_dataset,
                        valid_dataset=valid_dataset,
                        test_dataset=test_dataset,
                        valid_metrics=valid_metrics,
                        test_metrics=test_metrics,
                    )
                    packaged_export_succeeded = False
                    if probe or isinstance(trainer.model, HybridProbe):
                        try:
                            packaged_export_succeeded, export_message = export_packaged_model_to_hub(
                                trained_model=trainer.model,
                                source_model_name=source_model_name,
                                probe_args=self.probe_args,
                                embedding_args=self.embedding_args,
                                tokenizer=tokenizer,
                                repo_id=repo_id,
                                model_card=model_card,
                                ppi=ppi,
                                private=True,
                                hf_token=hf_token,
                            )
                            print_message(export_message)
                        except Exception as packaged_error:
                            print_message(f"Warning: packaged export failed for {repo_id}: {packaged_error}")
                    if not packaged_export_succeeded:
                        print_message(f"Falling back to direct model push_to_hub for {repo_id}")
                        if hf_token is not None:
                            trainer.model.push_to_hub(repo_id, private=True, token=hf_token)
                            api = HfApi(token=hf_token)
                        else:
                            trainer.model.push_to_hub(repo_id, private=True)
                            api = HfApi()
                        api.upload_file(
                            path_or_fileobj=model_card.encode("utf-8"),
                            path_in_repo="README.md",
                            repo_id=repo_id,
                            repo_type="model",
                        )
                    print_message(f"Successfully saved model to HuggingFace Hub: {repo_id}")
            except Exception as e:
                import traceback
                error_trace = traceback.format_exc()
                print_message(f"Error saving model to HuggingFace Hub: {e}")
                print_message(f"Error traceback: {error_trace}")
                print_message(f"save_model flag: {self.trainer_args.save}")

        model = trainer.model.cpu()
        trainer.accelerator.free_memory()
        torch.cuda.empty_cache()
        return model, valid_metrics, test_metrics, y_pred, y_true
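    # Shape note (sketch): some probes emit (N, 1, C) predictions; the singleton squeeze
    # in _train reduces them to (N, C) before plotting, following numpy semantics:
    #   np.zeros((8, 1, 3)).squeeze(1).shape == (8, 3)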
    def _aggregate_metrics(self, metrics_list: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate metrics across multiple runs, computing mean ± std for each metric."""
        if not metrics_list:
            return {}
        # Collect all metric keys
        all_keys = set()
        for m in metrics_list:
            all_keys.update(m.keys())
        aggregated = {}
        for key in all_keys:
            values = [m.get(key) for m in metrics_list if key in m and m[key] is not None]
            if not values:
                continue
            # Check if all values are numeric
            if all(isinstance(v, (int, float)) for v in values):
                mean_val = np.mean(values)
                std_val = np.std(values)
                # Store as formatted string with mean±std
                aggregated[key] = f"{mean_val:.4f}±{std_val:.4f}"
                # Also store raw mean for sorting/comparison purposes
                aggregated[f"{key}_mean"] = float(mean_val)
                aggregated[f"{key}_std"] = float(std_val)
            else:
                # For non-numeric values, just take the first one
                aggregated[key] = values[0]
        return aggregated
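    # Aggregation sketch: two runs with {'test_loss': 0.5} and {'test_loss': 0.7}
    # aggregate (population std, numpy's default) to:
    #   {'test_loss': '0.6000±0.1000', 'test_loss_mean': 0.6, 'test_loss_std': 0.1}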
    def trainer_probe(
        self,
        model,
        tokenizer,
        model_name,
        data_name,
        train_dataset,
        valid_dataset,
        test_dataset,
        emb_dict=None,
        ppi=False,
        log_id=None,
        skip_plot=False,
        source_model_name: Optional[str] = None,
    ):
        batch_size = self.trainer_args.probe_batch_size
        read_scaler = self.trainer_args.read_scaler
        input_size = self.probe_args.input_size
        task_type = self.probe_args.task_type
        tokenwise = self.probe_args.tokenwise
        num_runs = getattr(self.trainer_args, 'num_runs', 1)
        base_seed = self.trainer_args.seed
        print_message(f'task_type: {task_type}')
        full = self.embedding_args.matrix_embed
        db_path = os.path.join(self.embedding_args.embedding_save_dir, f'{model_name}_{full}.db')
        use_multi = getattr(self.full_args, 'multi_column', None)

        if self.embedding_args.sql:
            print_message('SQL enabled')
            if ppi:
                if full:
                    raise ValueError('Full matrix embeddings not currently supported for SQL and PPI')  # TODO: Implement
                DatasetClass = PairEmbedsLabelsDatasetFromDisk
                CollatorClass = PairEmbedsLabelsCollator
            elif use_multi:
                DatasetClass = MultiEmbedsLabelsDatasetFromDisk
                CollatorClass = EmbedsLabelsCollator
            else:
                DatasetClass = EmbedsLabelsDatasetFromDisk
                CollatorClass = EmbedsLabelsCollator
        else:
            print_message('SQL disabled')
            if ppi:
                DatasetClass = PairEmbedsLabelsDataset
                CollatorClass = PairEmbedsLabelsCollator
            elif use_multi:
                DatasetClass = MultiEmbedsLabelsDataset
                CollatorClass = EmbedsLabelsCollator
            else:
                DatasetClass = EmbedsLabelsDataset
                CollatorClass = EmbedsLabelsCollator

        # For the collator we pass tokenizer, full, task_type; for the dataset we pass
        # hf_dataset, col_a, col_b, label_col, input_size, task_type, db_path, emb_dict,
        # batch_size, read_scaler, full, train.
        add_token_ids = getattr(self.probe_args, 'add_token_ids', False)
        data_collator = CollatorClass(tokenizer=tokenizer, full=full, task_type=task_type, tokenwise=tokenwise, add_token_ids=add_token_ids)
        common_kwargs = dict(
            hf_dataset=train_dataset,
            input_size=input_size,
            task_type=task_type,
            db_path=db_path,
            emb_dict=emb_dict,
            batch_size=batch_size,
            read_scaler=read_scaler,
            full=full,
            train=True,
            random_pair_flipping=self.full_args.random_pair_flipping,
        )
        if use_multi:
            train_ds = DatasetClass(seq_cols=use_multi, **deepcopy(common_kwargs))
        else:
            train_ds = DatasetClass(**deepcopy(common_kwargs))

        # BUG FIX: Update hf_dataset in common_kwargs before creating validation and test
        # datasets. Previously, common_kwargs['hf_dataset'] was set to train_dataset and
        # never updated, causing valid_dataset and test_dataset to incorrectly use training
        # data. This resulted in valid_metrics and test_metrics being identical since they
        # were computed on the same (training) dataset. The fix ensures each dataset uses
        # the correct HuggingFace dataset. We use deepcopy so each dataset gets an
        # independent copy of the kwargs dictionary, preventing shared-state issues.
        common_kwargs['train'] = False
        common_kwargs['hf_dataset'] = valid_dataset
        if use_multi:
            valid_ds = DatasetClass(seq_cols=use_multi, **deepcopy(common_kwargs))
        else:
            valid_ds = DatasetClass(**deepcopy(common_kwargs))
        common_kwargs['hf_dataset'] = test_dataset
        if use_multi:
            test_ds = DatasetClass(seq_cols=use_multi, **deepcopy(common_kwargs))
        else:
            test_ds = DatasetClass(**deepcopy(common_kwargs))

        # Single run - original behavior
        if num_runs == 1:
            return self._train(
                model=model,
                train_dataset=train_ds,
                valid_dataset=valid_ds,
                test_dataset=test_ds,
                data_collator=data_collator,
                tokenizer=tokenizer,
                log_id=log_id,
                model_name=model_name,
                data_name=data_name,
                source_model_name=source_model_name,
                ppi=ppi,
                probe=True,
                skip_plot=skip_plot,
            )

        # Multi-run mode: train multiple times with different seeds, reusing datasets
        print_message(f"Running {num_runs} training runs with different seeds for {data_name}/{model_name}")
        all_valid_metrics = []
        all_test_metrics = []
        run_results = []  # Store (run_idx, test_loss, y_pred, y_true, seed, model) for plotting best

        for run_idx in range(num_runs):
            run_seed = base_seed + run_idx
            self.trainer_args.seed = run_seed
            set_global_seed(run_seed)
            print_message(f"=== Run {run_idx + 1}/{num_runs} with seed {run_seed} ===")
            # Create a fresh probe for each run
            probe = get_probe(self.probe_args)
            run_model, valid_metrics, test_metrics, y_pred, y_true = self._train(
                model=probe,
                train_dataset=train_ds,
                valid_dataset=valid_ds,
                test_dataset=test_ds,
                data_collator=data_collator,
                tokenizer=tokenizer,
                log_id=f"{log_id}_run{run_idx}",
                model_name=model_name,
                data_name=data_name,
                source_model_name=source_model_name,
                ppi=ppi,
                probe=True,
                skip_plot=True,  # Skip plots during individual runs
            )
            all_valid_metrics.append(valid_metrics)
            all_test_metrics.append(test_metrics)
            # Track test loss for determining best run
            test_loss = test_metrics.get('test_loss', test_metrics.get('eval_loss', float('inf')))
            run_results.append((run_idx, test_loss, y_pred, y_true, run_seed, run_model))

        # Restore original seed
        self.trainer_args.seed = base_seed

        # Compute aggregated metrics (mean ± std)
        aggregated_valid = self._aggregate_metrics(all_valid_metrics)
        aggregated_test = self._aggregate_metrics(all_test_metrics)

        # Find the best run (lowest test loss)
        best_run = min(run_results, key=lambda x: x[1])
        best_run_idx, best_loss, best_y_pred, best_y_true, best_seed, best_model = best_run
        print_message(f"Best run: {best_run_idx + 1} (seed={best_seed}, test_loss={best_loss:.4f})")

        # Generate plot for best run (unless skip_plot is True)
        if not skip_plot:
            output_dir = os.path.join(self.trainer_args.plots_dir, log_id)
            os.makedirs(output_dir, exist_ok=True)
            save_path = os.path.join(output_dir, f"{data_name}_{model_name}_{log_id}_best.png")
            title = f"{data_name} {model_name} (best of {num_runs} runs, seed={best_seed})"
            if task_type == 'regression':
                regression_ci_plot(best_y_true, best_y_pred, save_path, title)
            else:
                classification_ci_plot(best_y_true, best_y_pred, save_path, title)

        # Return the best model along with aggregated metrics
        return best_model, aggregated_valid, aggregated_test, best_y_pred, best_y_true
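    # Selection summary (descriptive, mirrors the branches in trainer_probe above):
    #   sql=True,  ppi        -> PairEmbedsLabelsDatasetFromDisk + PairEmbedsLabelsCollator
    #   sql=True,  multi_col  -> MultiEmbedsLabelsDatasetFromDisk + EmbedsLabelsCollator
    #   sql=True,  default    -> EmbedsLabelsDatasetFromDisk      + EmbedsLabelsCollator
    #   sql=False, ppi        -> PairEmbedsLabelsDataset          + PairEmbedsLabelsCollator
    #   sql=False, multi_col  -> MultiEmbedsLabelsDataset         + EmbedsLabelsCollator
    #   sql=False, default    -> EmbedsLabelsDataset              + EmbedsLabelsCollator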
    def trainer_base_model(
        self,
        model,
        tokenizer,
        model_name,
        data_name,
        train_dataset,
        valid_dataset,
        test_dataset,
        ppi=False,
        log_id=None,
        skip_plot=False,
        model_factory=None,
        source_model_name: Optional[str] = None,
    ):
        task_type = self.probe_args.task_type
        tokenwise = self.probe_args.tokenwise
        num_runs = getattr(self.trainer_args, 'num_runs', 1)
        base_seed = self.trainer_args.seed

        if ppi:
            DatasetClass = PairStringLabelDataset
            CollatorClass = PairCollator_input_ids
        else:
            DatasetClass = StringLabelDataset
            CollatorClass = StringLabelsCollator
        data_collator = CollatorClass(tokenizer=tokenizer, task_type=task_type, tokenwise=tokenwise)
        train_ds = DatasetClass(hf_dataset=train_dataset, train=True, random_pair_flipping=self.full_args.random_pair_flipping)
        valid_ds = DatasetClass(hf_dataset=valid_dataset, train=False, random_pair_flipping=self.full_args.random_pair_flipping)
        test_ds = DatasetClass(hf_dataset=test_dataset, train=False, random_pair_flipping=self.full_args.random_pair_flipping)

        # Single run - original behavior
        if num_runs == 1:
            return self._train(
                model=model,
                train_dataset=train_ds,
                valid_dataset=valid_ds,
                test_dataset=test_ds,
                data_collator=data_collator,
                tokenizer=tokenizer,
                log_id=log_id,
                model_name=model_name,
                data_name=data_name,
                source_model_name=source_model_name,
                ppi=ppi,
                probe=False,
                skip_plot=skip_plot,
            )

        # Multi-run mode: train multiple times with different seeds
        print_message(f"Running {num_runs} full finetuning runs with different seeds for {data_name}/{model_name}")
        all_valid_metrics = []
        all_test_metrics = []
        run_results = []  # Store (run_idx, test_loss, y_pred, y_true, seed, model) for plotting best

        for run_idx in range(num_runs):
            run_seed = base_seed + run_idx
            self.trainer_args.seed = run_seed
            set_global_seed(run_seed)
            print_message(f"=== Run {run_idx + 1}/{num_runs} with seed {run_seed} ===")
            # Create a fresh model for each run using the factory. BUG FIX: without a
            # factory, run_model was previously undefined; fall back to the passed-in
            # model (which then keeps its trained weights across runs).
            run_model = model_factory() if model_factory is not None else model
            trained_model, valid_metrics, test_metrics, y_pred, y_true = self._train(
                model=run_model,
                train_dataset=train_ds,
                valid_dataset=valid_ds,
                test_dataset=test_ds,
                data_collator=data_collator,
                tokenizer=tokenizer,
                log_id=f"{log_id}_run{run_idx}",
                model_name=model_name,
                data_name=data_name,
                source_model_name=source_model_name,
                ppi=ppi,
                probe=False,
                skip_plot=True,  # Skip plots during individual runs
            )
            all_valid_metrics.append(valid_metrics)
            all_test_metrics.append(test_metrics)
            # Track test loss for determining best run
            test_loss = test_metrics.get('test_loss', test_metrics.get('eval_loss', float('inf')))
            run_results.append((run_idx, test_loss, y_pred, y_true, run_seed, trained_model))

        # Restore original seed
        self.trainer_args.seed = base_seed

        # Compute aggregated metrics (mean ± std)
        aggregated_valid = self._aggregate_metrics(all_valid_metrics)
        aggregated_test = self._aggregate_metrics(all_test_metrics)

        # Find the best run (lowest test loss)
        best_run = min(run_results, key=lambda x: x[1])
        best_run_idx, best_loss, best_y_pred, best_y_true, best_seed, best_model = best_run
        print_message(f"Best run: {best_run_idx + 1} (seed={best_seed}, test_loss={best_loss:.4f})")

        # Generate plot for best run (unless skip_plot is True)
        if not skip_plot:
            output_dir = os.path.join(self.trainer_args.plots_dir, log_id)
            os.makedirs(output_dir, exist_ok=True)
            save_path = os.path.join(output_dir, f"{data_name}_{model_name}_{log_id}_best.png")
            title = f"{data_name} {model_name} (best of {num_runs} runs, seed={best_seed})"
            if task_type == 'regression':
                regression_ci_plot(best_y_true, best_y_pred, save_path, title)
            else:
                classification_ci_plot(best_y_true, best_y_pred, save_path, title)

        # Return the best model along with aggregated metrics
        return best_model, aggregated_valid, aggregated_test, best_y_pred, best_y_true
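    # Multi-run note (sketch): for truly independent full-finetuning runs, supply a
    # model_factory that reloads fresh weights each time, e.g. (model_path is a placeholder):
    #   model_factory = lambda: AutoModelForSequenceClassification.from_pretrained(model_path)
    # Without a factory, run 2+ resumes from the previous run's trained weights.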
    def trainer_hybrid_model(
        self,
        model,
        tokenizer,
        probe,
        model_name,
        data_name,
        train_dataset,
        valid_dataset,
        test_dataset,
        emb_dict=None,
        ppi=False,
        log_id=None,
        skip_plot=False,
        model_factory=None,
        probe_factory=None,
        source_model_name: Optional[str] = None,
    ):
        num_runs = getattr(self.trainer_args, 'num_runs', 1)
        base_seed = self.trainer_args.seed

        # Single run - original behavior
        if num_runs == 1:
            return self._train_hybrid_single_run(
                model=model,
                tokenizer=tokenizer,
                probe=probe,
                model_name=model_name,
                data_name=data_name,
                train_dataset=train_dataset,
                valid_dataset=valid_dataset,
                test_dataset=test_dataset,
                emb_dict=emb_dict,
                ppi=ppi,
                log_id=log_id,
                skip_plot=skip_plot,
                source_model_name=source_model_name,
            )

        # Multi-run mode for hybrid probe.
        # For hybrid probe, we only care about final metrics, not intermediate probe
        # metrics; training_time_seconds should sum both probe and model+probe training times.
        print_message(f"Running {num_runs} hybrid probe runs with different seeds for {data_name}/{model_name}")
        all_valid_metrics = []
        all_test_metrics = []
        run_results = []  # Store (run_idx, test_loss, y_pred, y_true, seed, model) for plotting best

        for run_idx in range(num_runs):
            run_seed = base_seed + run_idx
            self.trainer_args.seed = run_seed
            set_global_seed(run_seed)
            print_message(f"=== Hybrid Run {run_idx + 1}/{num_runs} with seed {run_seed} ===")
            # Create fresh probe and model for each run using factories. BUG FIX: without
            # factories, run_probe/run_model were previously undefined; fall back to the
            # passed-in instances.
            run_probe = probe_factory() if probe_factory is not None else probe
            run_model = model_factory() if model_factory is not None else model
            trained_model, valid_metrics, test_metrics, y_pred, y_true = self._train_hybrid_single_run(
                model=run_model,
                tokenizer=tokenizer,
                probe=run_probe,
                model_name=model_name,
                data_name=data_name,
                train_dataset=train_dataset,
                valid_dataset=valid_dataset,
                test_dataset=test_dataset,
                emb_dict=emb_dict,
                ppi=ppi,
                log_id=f"{log_id}_run{run_idx}",
                skip_plot=True,  # Skip plots during individual runs
                source_model_name=source_model_name,
            )
            # Only collect final metrics (not intermediate probe metrics)
            all_valid_metrics.append(valid_metrics)
            all_test_metrics.append(test_metrics)
            # Track test loss for determining best run
            test_loss = test_metrics.get('test_loss', test_metrics.get('eval_loss', float('inf')))
            run_results.append((run_idx, test_loss, y_pred, y_true, run_seed, trained_model))

        # Restore original seed
        self.trainer_args.seed = base_seed

        # Compute aggregated metrics (mean ± std). This includes training_time_seconds,
        # which already has probe + base time summed per run.
        aggregated_valid = self._aggregate_metrics(all_valid_metrics)
        aggregated_test = self._aggregate_metrics(all_test_metrics)

        # Find the best run (lowest test loss)
        best_run = min(run_results, key=lambda x: x[1])
        best_run_idx, best_loss, best_y_pred, best_y_true, best_seed, best_model = best_run
        print_message(f"Best hybrid run: {best_run_idx + 1} (seed={best_seed}, test_loss={best_loss:.4f})")

        # Generate plot for best run (unless skip_plot is True)
        task_type = self.probe_args.task_type
        if not skip_plot:
            output_dir = os.path.join(self.trainer_args.plots_dir, log_id)
            os.makedirs(output_dir, exist_ok=True)
            save_path = os.path.join(output_dir, f"{data_name}_{model_name}_{log_id}_best.png")
            title = f"{data_name} {model_name} hybrid (best of {num_runs} runs, seed={best_seed})"
            if task_type == 'regression':
                regression_ci_plot(best_y_true, best_y_pred, save_path, title)
            else:
                classification_ci_plot(best_y_true, best_y_pred, save_path, title)

        # Return the best model along with aggregated metrics
        return best_model, aggregated_valid, aggregated_test, best_y_pred, best_y_true
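    # Factory sketch for multi-run hybrid training: fresh components per run avoid
    # warm-started weights leaking across seeds, e.g.
    #   probe_factory = lambda: get_probe(self.probe_args)   # get_probe is imported above
    #   model_factory = lambda: load_base_model(model_name)  # hypothetical user-supplied loader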
    def _train_hybrid_single_run(
        self,
        model,
        tokenizer,
        probe,
        model_name,
        data_name,
        train_dataset,
        valid_dataset,
        test_dataset,
        emb_dict=None,
        ppi=False,
        log_id=None,
        skip_plot=False,
        source_model_name: Optional[str] = None,
    ):
        """Single run of hybrid probe training (probe first, then model+probe)."""
        # Store original num_runs and temporarily set to 1 for the probe phase
        original_num_runs = getattr(self.trainer_args, 'num_runs', 1)
        self.trainer_args.num_runs = 1
        probe, _, probe_test_metrics, _, _ = self.trainer_probe(
            model=probe,
            tokenizer=tokenizer,
            model_name=model_name,
            data_name=data_name,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            test_dataset=test_dataset,
            emb_dict=emb_dict,
            ppi=ppi,
            log_id=log_id,
            skip_plot=True,  # Always skip plot for probe phase in hybrid
            source_model_name=source_model_name,
        )
        # Restore num_runs
        self.trainer_args.num_runs = original_num_runs

        probe_time = probe_test_metrics.get('training_time_seconds')
        if not isinstance(probe_time, (int, float)):
            raise ValueError(f"Probe time is not a number: {probe_time}")  # ensure we are capturing the time correctly

        config = HybridProbeConfig(
            tokenwise=self.probe_args.tokenwise,
            matrix_embed=self.embedding_args.matrix_embed,
            pooling_types=self.embedding_args.pooling_types,
        )
        hybrid_model = HybridProbe(config=config, model=model, probe=probe)

        # Temporarily set num_runs to 1 for the base model phase
        self.trainer_args.num_runs = 1
        base_model, base_valid_metrics, base_test_metrics, y_pred, y_true = self.trainer_base_model(
            model=hybrid_model,
            tokenizer=tokenizer,
            model_name=model_name,
            data_name=data_name,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            test_dataset=test_dataset,
            ppi=ppi,
            log_id=log_id,
            skip_plot=skip_plot,
            source_model_name=source_model_name,
        )
        # Restore num_runs
        self.trainer_args.num_runs = original_num_runs

        # Sum probe time and base time for total training time (probe_time is guaranteed
        # numeric by the check above, so no further None guard is needed)
        base_time = base_test_metrics.get('training_time_seconds')
        if isinstance(base_time, (int, float)):
            base_test_metrics['training_time_seconds'] = base_time + probe_time
        elif base_time is None:
            base_test_metrics['training_time_seconds'] = probe_time
        return base_model, base_valid_metrics, base_test_metrics, y_pred, y_true
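# Host-class sketch: TrainerMixin assumes its subclass also provides self.full_args
# (hf_username, hf_token, random_pair_flipping, multi_column), self.probe_args, and
# self.embedding_args, as used throughout the methods above. A minimal hypothetical host:
#
#   class ProtifyTrainer(TrainerMixin):
#       def __init__(self, full_args, probe_args, embedding_args, trainer_args):
#           super().__init__(trainer_args)
#           self.full_args = full_args
#           self.probe_args = probe_args
#           self.embedding_args = embedding_args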