"""Script to automatically select best model from multiple runs."""
import logging
import argparse
from pathlib import Path
import json
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
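
# Each per-run result file is expected to be a flat JSON object containing the
# selection metric and a pointer to the trained weights. An illustrative (not
# enforced) example:
#   {"val_f1": 0.87, "model_path": "checkpoints/run_3/best.pt"}
# Files missing the requested metric are skipped; unreadable files only emit a
# warning.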


def select_best_model(
    results_dir: str,
    metric_name: str = "val_f1",
    mode: str = "max",
) -> dict:
    """
    Select the best model from a results directory.

    Args:
        results_dir: Directory containing per-run result JSON files
        metric_name: Metric to use for selection
        mode: "max" or "min"

    Returns:
        Dictionary with best model information
    """
    results_path = Path(results_dir)
    if not results_path.exists():
        raise ValueError(f"Results directory not found: {results_dir}")

    best_value = float("-inf") if mode == "max" else float("inf")
    best_model = None
    best_run = None

    # Search recursively for result files
    for result_file in results_path.rglob("*.json"):
        try:
            with open(result_file) as f:
                result = json.load(f)

            metric_value = result.get(metric_name)
            if metric_value is None:
                continue

            if mode == "max":
                is_best = metric_value > best_value
            else:
                is_best = metric_value < best_value

            if is_best:
                best_value = metric_value
                best_model = result.get("model_path")
                best_run = result_file.stem
        except Exception as e:
            logger.warning(f"Failed to read {result_file}: {e}")

    if best_model is None:
        raise ValueError("No valid results found")

    result = {
        "best_model_path": best_model,
        "best_metric_value": best_value,
        "best_run": best_run,
        "metric_name": metric_name,
    }

    logger.info("=" * 60)
    logger.info("Best Model Selection Results")
    logger.info("=" * 60)
    logger.info(f"Best model: {best_model}")
    logger.info(f"Best {metric_name}: {best_value:.4f}")
    logger.info(f"Best run: {best_run}")

    return result
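
# Minimal usage sketch (metric name and layout are illustrative): pick the run
# with the lowest validation loss instead of the default highest val_f1:
#   best = select_best_model("results/", metric_name="val_loss", mode="min")
#   print(best["best_model_path"])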


def select_from_optuna_study(
    study_path: str,
    output_path: Optional[str] = None,
) -> dict:
    """
    Select the best model from an Optuna study.

    Args:
        study_path: Path to a serialized Optuna study file
        output_path: Optional path to save best-model info as JSON

    Returns:
        Dictionary with best model information
    """
    import joblib  # local import: joblib is only needed for this method

    study = joblib.load(study_path)
    best_trial = study.best_trial
    best_params = study.best_params
    best_value = study.best_value

    result = {
        "best_trial": best_trial.number,
        "best_value": best_value,
        "best_params": best_params,
    }

    logger.info("=" * 60)
    logger.info("Optuna Study Results")
    logger.info("=" * 60)
    logger.info(f"Best trial: {best_trial.number}")
    logger.info(f"Best value: {best_value:.4f}")
    logger.info("Best parameters:")
    for key, value in best_params.items():
        logger.info(f"  {key}: {value}")

    if output_path:
        with open(output_path, 'w') as f:
            json.dump(result, f, indent=2)
        logger.info(f"Results saved to {output_path}")

    return result
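
# Note: joblib.load above assumes the study was serialized in-process with
# joblib.dump(study, path). Studies persisted in an RDB storage would instead
# be reloaded with optuna.load_study(study_name=..., storage=...).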


def select_from_wandb_sweep(
    project: str,
    sweep_id: str,
    entity: Optional[str] = None,
) -> dict:
    """
    Select the best model from a WandB sweep.

    Args:
        project: WandB project name
        sweep_id: Sweep ID
        entity: WandB entity (optional)

    Returns:
        Dictionary with best model information
    """
    try:
        import wandb
        api = wandb.Api()
    except ImportError:
        raise ImportError("wandb not installed. Install with: pip install wandb")

    # Sweep paths follow wandb's "entity/project/sweep_id" convention; the
    # lstrip drops the leading slash when no entity is given.
    sweep = api.sweep(f"{entity or ''}/{project}/{sweep_id}".lstrip('/'))

    # Rank runs by val_f1 (hardcoded here; higher is better) and take the top one
    runs = sorted(
        sweep.runs,
        key=lambda r: r.summary.get("val_f1", 0),
        reverse=True,
    )
    if not runs:
        raise ValueError("No runs found in sweep")

    best_run = runs[0]
    result = {
        "run_id": best_run.id,
        "run_name": best_run.name,
        "config": dict(best_run.config),
        "metrics": dict(best_run.summary),
    }

    logger.info("=" * 60)
    logger.info("WandB Sweep Results")
    logger.info("=" * 60)
    logger.info(f"Best run: {best_run.name}")
    logger.info(f"Best val_f1: {best_run.summary.get('val_f1', 'N/A')}")
    logger.info("Best config:")
    for key, value in best_run.config.items():
        logger.info(f"  {key}: {value}")

    return result
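
# Example invocations (paths, project names, and IDs are illustrative):
#   python scripts/select_best_model.py --method results --results-dir results/ --metric val_f1
#   python scripts/select_best_model.py --method optuna --study-path study.pkl --output best.json
#   python scripts/select_best_model.py --method wandb --project my-project --sweep-id abc123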


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Select best model")
    parser.add_argument(
        "--method",
        type=str,
        choices=["results", "optuna", "wandb"],
        default="results",
        help="Selection method",
    )
    parser.add_argument(
        "--results-dir",
        type=str,
        default="results/",
        help="Results directory (for results method)",
    )
    parser.add_argument(
        "--study-path",
        type=str,
        help="Path to Optuna study (for optuna method)",
    )
    parser.add_argument(
        "--project",
        type=str,
        help="WandB project (for wandb method)",
    )
    parser.add_argument(
        "--sweep-id",
        type=str,
        help="WandB sweep ID (for wandb method)",
    )
    parser.add_argument(
        "--entity",
        type=str,
        help="WandB entity (for wandb method)",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="val_f1",
        help="Metric name",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["max", "min"],
        default="max",
        help="Optimization mode",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output file path",
    )
    args = parser.parse_args()

    if args.method == "results":
        result = select_best_model(
            results_dir=args.results_dir,
            metric_name=args.metric,
            mode=args.mode,
        )
    elif args.method == "optuna":
        if not args.study_path:
            raise ValueError("--study-path required for optuna method")
        # The shared block below writes the output, so output_path is not
        # passed here; passing it too would save the same file twice.
        result = select_from_optuna_study(study_path=args.study_path)
    elif args.method == "wandb":
        if not args.project or not args.sweep_id:
            raise ValueError("--project and --sweep-id required for wandb method")
        result = select_from_wandb_sweep(
            project=args.project,
            sweep_id=args.sweep_id,
            entity=args.entity,
        )

    if args.output:
        with open(args.output, 'w') as f:
            json.dump(result, f, indent=2)
        logger.info(f"Results saved to {args.output}")