import argparse
import importlib
import logging
import os
import subprocess
import sys

import pandas as pd


def install_package(package_name, import_name=None):
    """Make sure a package is importable, installing it with pip if it is not.

    Args:
        package_name: Distribution name handed to ``pip install``.
        import_name: Module name used for the import probe. Defaults to
            ``package_name``; pass it when the two names differ.

    Raises:
        SystemExit: If ``pip install`` exits with a non-zero status.
    """
    module_name = package_name if import_name is None else import_name

    try:
        importlib.import_module(module_name)
    except ImportError:
        logging.info(f"{package_name} 라이브러리를 찾을 수 없어 설치합니다...")
    else:
        # Already importable: nothing to install.
        return

    try:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", package_name, "--quiet"]
        )
    except subprocess.CalledProcessError as e:
        # Keep pip's exit status in the log so the failure can be diagnosed.
        logging.error(
            f"오류: {package_name} 라이브러리 설치 실패. (exit code: {e.returncode})"
        )
        sys.exit(f"{package_name} 설치 실패로 스크립트를 종료합니다.")

    # The package was added to disk after the interpreter started; drop the
    # finders' cached directory listings so a following import can see it.
    importlib.invalidate_caches()
    logging.info(f"{package_name} 라이브러리 설치 성공.")


# Project root and the log file this script writes to.
PROJECT_BASE_DIR = "/home/work/baro/sillok25060103"
LOG_DIR = os.path.join(PROJECT_BASE_DIR, "logs")
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE_PATH = os.path.join(LOG_DIR, "hpo_universal_analysis.log")

# Configure logging before anything logs. install_package() uses the
# module-level logging functions, and the first such call on an unconfigured
# root logger installs a default handler, after which basicConfig() below
# would be a no-op and the log file would never be attached.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        # mode='w' truncates the log file on every run.
        logging.FileHandler(LOG_FILE_PATH, mode='w', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
)

# Install any missing third-party dependency, then import optuna.
# NOTE(review): pandas is already imported at the top of the file, so its
# entry here can only take effect if that import is moved below this loop.
required_packages = ["optuna", "pandas", "plotly"]
for pkg in required_packages:
    install_package(pkg)

import optuna  # noqa: E402  (must follow the install step above)


| def analyze_hpo_results(db_name, study_name, output_dir, file_prefix): |
| """ |
| ์ง์ ๋ Optuna DB์ Study๋ฅผ ๋ถ์ํ๊ณ ์๊ฐํ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํฉ๋๋ค. |
| """ |
| logging.info("="*60) |
| logging.info(f"'{study_name}' ์ฐ๊ตฌ ๋ถ์ ์์") |
| logging.info(f"DB ํ์ผ: {db_name}") |
| logging.info("="*60) |
|
|
| db_path = os.path.join(PROJECT_BASE_DIR, "optuna_db", db_name) |
| storage_name = f"sqlite:///{db_path}" |
| |
| os.makedirs(output_dir, exist_ok=True) |
| logging.info(f"๋ถ์ ๊ฒฐ๊ณผ๋ฌผ(๊ทธ๋ํ) ์ ์ฅ ๊ฒฝ๋ก: {output_dir}") |
|
|
| try: |
| study = optuna.load_study(study_name=study_name, storage=storage_name) |
| logging.info(f"์ฐ๊ตฌ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ๋ถ๋ฌ์์ต๋๋ค. (์ด Trial: {len(study.trials)})") |
| except Exception as e: |
| logging.error(f"์ฐ๊ตฌ๋ฅผ ๋ถ๋ฌ์ค๋ ์ค ์ค๋ฅ ๋ฐ์: {e}") |
| return |
|
|
| |
| logging.info("\n--- ์์ 5๊ฐ Trial ์์ฝ (eval_loss ๊ธฐ์ค) ---") |
| |
| |
| completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE] |
| if not completed_trials: |
| logging.warning("์๋ฃ๋ Trial์ด ์์ต๋๋ค. ๋ถ์์ ์ข
๋ฃํฉ๋๋ค.") |
| return |
|
|
| records = [] |
| for t in completed_trials: |
| |
| params = t.user_attrs.get('predefined_params', t.params) |
| record = { |
| 'number': t.number, |
| 'eval_loss': t.value, |
| 'duration': t.duration, |
| 'params': params |
| } |
| records.append(record) |
|
|
| df_results = pd.DataFrame(records).sort_values(by='eval_loss', ascending=True) |
|
|
| with pd.option_context('display.max_rows', 5, 'display.max_columns', None, 'display.width', 120, 'display.max_colwidth', None): |
| print(df_results.head()) |
|
|
| |
| logging.info("\n--- ๊ฒฐ๊ณผ ์๊ฐํ ๊ทธ๋ํ ์์ฑ ์ค ---") |
| try: |
| fig_importance = optuna.visualization.plot_param_importances(study) |
| importance_path = os.path.join(output_dir, f"{file_prefix}param_importances.html") |
| fig_importance.write_html(importance_path) |
| logging.info(f"1. ํ๋ผ๋ฏธํฐ ์ค์๋ ๊ทธ๋ํ ์ ์ฅ ์๋ฃ: {importance_path}") |
|
|
| fig_history = optuna.visualization.plot_optimization_history(study) |
| history_path = os.path.join(output_dir, f"{file_prefix}optimization_history.html") |
| fig_history.write_html(history_path) |
| logging.info(f"2. ์ต์ ํ ๊ณผ์ ๊ทธ๋ํ ์ ์ฅ ์๋ฃ: {history_path}") |
|
|
| fig_slice = optuna.visualization.plot_slice(study) |
| slice_path = os.path.join(output_dir, f"{file_prefix}slice_plot.html") |
| fig_slice.write_html(slice_path) |
| logging.info(f"3. ๊ฐ๋ณ ํ๋ผ๋ฏธํฐ ๊ด๊ณ๋ ๊ทธ๋ํ ์ ์ฅ ์๋ฃ: {slice_path}") |
| except Exception as e: |
| logging.error(f"์๊ฐํ ๊ทธ๋ํ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}") |
| |
| logging.info(f"\n'{study_name}' ์ฐ๊ตฌ ๋ถ์ ์๋ฃ.\n") |
|
|
|
|
| if __name__ == '__main__': |
| |
| parser = argparse.ArgumentParser(description="Optuna HPO ๊ฒฐ๊ณผ๋ฅผ ๋ถ์ํ๊ณ ์๊ฐํํฉ๋๋ค.") |
| parser.add_argument("--db_name", type=str, required=True, help="๋ถ์ํ Optuna DB ํ์ผ๋ช
(optuna_db ํด๋ ๋ด ์์น)") |
| parser.add_argument("--study_name", type=str, required=True, help="๋ถ์ํ Study์ ์ด๋ฆ") |
| parser.add_argument("--output_dir", type=str, default=os.path.join(PROJECT_BASE_DIR, "hpo_analysis_results"), help="์๊ฐํ ๊ฒฐ๊ณผ๋ฌผ์ ์ ์ฅํ ๋๋ ํ ๋ฆฌ") |
| parser.add_argument("--file_prefix", type=str, default="analysis_", help="์์ฑ๋ HTML ํ์ผ๋ช
์ ์ ๋์ฌ (์: 'stage2_')") |
| |
| args = parser.parse_args() |
|
|
| analyze_hpo_results( |
| db_name=args.db_name, |
| study_name=args.study_name, |
| output_dir=args.output_dir, |
| file_prefix=args.file_prefix |
| ) |
|
|