|
|
"""Script to automatically select best model from multiple runs.""" |
|
|
|
|
|
import argparse
import json
import logging
from pathlib import Path
from typing import Optional
|
|
|
|
|
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

# Configure the root logger at import time so INFO-level messages are shown
# (basicConfig's default handler writes to stderr; the call is a no-op if the
# root logger already has handlers).
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
def select_best_model(
    results_dir: str,
    metric_name: str = "val_f1",
    mode: str = "max",
) -> dict:
    """
    Select the best model from a directory of JSON result files.

    Every ``*.json`` file below ``results_dir`` is scanned recursively, in
    sorted path order, so the outcome does not depend on directory-listing
    order. A file is a candidate when it contains a JSON object whose
    ``metric_name`` entry is a real number (``bool`` is rejected). The
    candidate with the largest (``mode="max"``) or smallest (``mode="min"``)
    value wins; on a tie the file that sorts first is kept. Files that cannot
    be read or decoded are logged and skipped.

    Args:
        results_dir: Directory containing model results
        metric_name: Metric to use for selection
        mode: "max" or "min"

    Returns:
        Dictionary with best model information: ``best_model_path`` (the
        winning file's ``"model_path"`` entry, or ``None`` if it has none),
        ``best_metric_value``, ``best_run`` (the file's path relative to
        ``results_dir`` without the ``.json`` suffix, e.g. ``"run1"`` or
        ``"run1/results"``) and ``metric_name``.

    Raises:
        ValueError: If ``mode`` is not "max" or "min", ``results_dir`` is not
            an existing directory, or no file holds a numeric ``metric_name``.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

    results_path = Path(results_dir)
    if not results_path.is_dir():
        raise ValueError(f"Results directory not found: {results_dir}")

    # Infinite sentinels: the first finite value always beats them, and NaN
    # never wins because every comparison against NaN is False.
    best_value = float("-inf") if mode == "max" else float("inf")
    best_model = None
    best_run = None

    for result_file in sorted(results_path.rglob("*.json")):
        try:
            with open(result_file, encoding="utf-8") as f:
                result = json.load(f)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning(f"Failed to read {result_file}: {e}")
            continue

        if not isinstance(result, dict):
            logger.warning(f"Skipping {result_file}: expected a JSON object")
            continue

        metric_value = result.get(metric_name)
        if metric_value is None:
            continue
        # bool is a subclass of int, but True/False is not a usable score.
        if isinstance(metric_value, bool) or not isinstance(metric_value, (int, float)):
            logger.warning(
                f"Skipping {result_file}: {metric_name}={metric_value!r} is not a number"
            )
            continue

        if mode == "max":
            is_best = metric_value > best_value
        else:
            is_best = metric_value < best_value

        if is_best:
            best_value = metric_value
            best_model = result.get("model_path")
            # Relative path (not just the stem) so runs that share a file
            # name in different sub-directories stay distinguishable.
            best_run = result_file.relative_to(results_path).with_suffix("").as_posix()

    # "Nothing found" is decided by best_run: the winning file may lack
    # "model_path", which must not be mistaken for "no valid results".
    if best_run is None:
        raise ValueError("No valid results found")
    if best_model is None:
        logger.warning(f"Best run {best_run} has no 'model_path' entry")

    result = {
        "best_model_path": best_model,
        "best_metric_value": best_value,
        "best_run": best_run,
        "metric_name": metric_name,
    }

    logger.info("=" * 60)
    logger.info("Best Model Selection Results")
    logger.info("=" * 60)
    logger.info(f"Best model: {best_model}")
    logger.info(f"Best {metric_name}: {best_value:.4f}")
    logger.info(f"Best run: {best_run}")

    return result
|
|
|
|
|
|
|
|
def select_from_optuna_study(
    study_path: str,
    output_path: Optional[str] = None,
) -> dict:
    """
    Report the best trial of an Optuna study stored on disk.

    Args:
        study_path: File handed to ``joblib.load`` to obtain the study object.
        output_path: If given, the result dict is also dumped there as JSON
            (indent=2).

    Returns:
        Dict with ``best_trial`` (the trial's ``number``), ``best_value`` and
        ``best_params``, all read from the loaded study.
    """
    import joblib

    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code -- only point this at study files you trust.
    study = joblib.load(study_path)

    # Attributes are read in the order trial, params, value.
    trial = study.best_trial
    params = study.best_params
    score = study.best_value

    summary = dict(best_trial=trial.number, best_value=score, best_params=params)

    banner = "=" * 60
    for heading in (banner, "Optuna Study Results", banner):
        logger.info(heading)
    logger.info(f"Best trial: {trial.number}")
    logger.info(f"Best value: {score:.4f}")
    logger.info("Best parameters:")
    for name, setting in params.items():
        logger.info(f" {name}: {setting}")

    if not output_path:
        return summary

    with open(output_path, 'w') as handle:
        json.dump(summary, handle, indent=2)
    logger.info(f"Results saved to {output_path}")
    return summary
|
|
|
|
|
|
|
|
def select_from_wandb_sweep(
    project: str,
    sweep_id: str,
    entity: Optional[str] = None,
    metric_name: str = "val_f1",
    mode: str = "max",
) -> dict:
    """
    Select the best run from a WandB sweep.

    A run is a candidate only when its summary holds a real-number
    ``metric_name`` (``bool`` and NaN excluded). A missing metric is not
    scored as 0. The candidate with the largest (``mode="max"``) or smallest
    (``mode="min"``) value wins. Ties go to the run listed first by
    ``sweep.runs``.

    Args:
        project: WandB project name
        sweep_id: Sweep ID
        entity: WandB entity. When omitted, the sweep path becomes
            ``"<project>/<sweep_id>"`` (presumably resolved against the API's
            default entity -- confirm).
        metric_name: Summary metric used for ranking (default "val_f1").
        mode: "max" or "min" (default "max").

    Returns:
        Dictionary with best model information: ``run_id``, ``run_name``,
        ``config`` and ``metrics`` (the run's full summary).

    Raises:
        ValueError: If ``mode`` is invalid, the sweep has no runs, or no run
            reports a numeric ``metric_name``.
        ImportError: If wandb is not installed.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

    import math

    try:
        import wandb
    except ImportError as err:
        raise ImportError("wandb not installed. Install with: pip install wandb") from err

    api = wandb.Api()
    sweep = api.sweep(f"{entity or ''}/{project}/{sweep_id}".lstrip('/'))

    runs = list(sweep.runs)
    if not runs:
        raise ValueError("No runs found in sweep")

    # Keep only runs that report a usable number for the metric; None,
    # strings, bools and NaN would otherwise break or skew the ranking.
    scored = []
    for run in runs:
        value = run.summary.get(metric_name)
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        if math.isnan(value):
            continue
        scored.append((value, run))

    if not scored:
        raise ValueError(f"No runs in sweep report a numeric '{metric_name}'")

    # max()/min() return the first extreme element, so ties keep sweep order.
    pick = max if mode == "max" else min
    best_value, best_run = pick(scored, key=lambda pair: pair[0])

    result = {
        "run_id": best_run.id,
        "run_name": best_run.name,
        "config": dict(best_run.config),
        "metrics": dict(best_run.summary),
    }

    logger.info("=" * 60)
    logger.info("WandB Sweep Results")
    logger.info("=" * 60)
    logger.info(f"Best run: {best_run.name}")
    logger.info(f"Best {metric_name}: {best_value}")
    logger.info("Best config:")
    for key, setting in best_run.config.items():
        logger.info(f" {key}: {setting}")

    return result
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
parser = argparse.ArgumentParser(description="Select best model") |
|
|
parser.add_argument( |
|
|
"--method", |
|
|
type=str, |
|
|
choices=["results", "optuna", "wandb"], |
|
|
default="results", |
|
|
help="Selection method" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--results-dir", |
|
|
type=str, |
|
|
default="results/", |
|
|
help="Results directory (for results method)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--study-path", |
|
|
type=str, |
|
|
help="Path to Optuna study (for optuna method)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--project", |
|
|
type=str, |
|
|
help="WandB project (for wandb method)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--sweep-id", |
|
|
type=str, |
|
|
help="WandB sweep ID (for wandb method)" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--metric", |
|
|
type=str, |
|
|
default="val_f1", |
|
|
help="Metric name" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--mode", |
|
|
type=str, |
|
|
choices=["max", "min"], |
|
|
default="max", |
|
|
help="Optimization mode" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output", |
|
|
type=str, |
|
|
help="Output file path" |
|
|
) |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.method == "results": |
|
|
result = select_best_model( |
|
|
results_dir=args.results_dir, |
|
|
metric_name=args.metric, |
|
|
mode=args.mode, |
|
|
) |
|
|
elif args.method == "optuna": |
|
|
if not args.study_path: |
|
|
raise ValueError("--study-path required for optuna method") |
|
|
result = select_from_optuna_study( |
|
|
study_path=args.study_path, |
|
|
output_path=args.output, |
|
|
) |
|
|
elif args.method == "wandb": |
|
|
if not args.project or not args.sweep_id: |
|
|
raise ValueError("--project and --sweep-id required for wandb method") |
|
|
result = select_from_wandb_sweep( |
|
|
project=args.project, |
|
|
sweep_id=args.sweep_id, |
|
|
) |
|
|
|
|
|
if args.output: |
|
|
with open(args.output, 'w') as f: |
|
|
json.dump(result, f, indent=2) |
|
|
logger.info(f"Results saved to {args.output}") |
|
|
|
|
|
|