# Source: SillokBert/scripts/hpo_result_analyzer_universal.py
# Uploaded by ddokbaro ("Upload 2 files"), commit b6e3c85 (verified).
# Import required libraries; missing third-party packages are installed automatically below.
import os
import subprocess
import sys
import logging
import pandas as pd
import argparse
def install_package(package_name, import_name=None):
    """Ensure a package is importable, installing it with pip if it is not.

    Args:
        package_name: Distribution name passed to ``pip install``.
        import_name: Module name used for the import check. Defaults to
            ``package_name``; set it when the two differ (for example the pip
            distribution ``scikit-learn`` is imported as ``sklearn``).

    Returns:
        None.

    Raises:
        SystemExit: via ``sys.exit`` when the pip installation fails.
    """
    import importlib

    module_name = import_name or package_name
    # Use a named logger rather than the root-level logging.info()/error():
    # those helpers implicitly call logging.basicConfig() when the root logger
    # has no handlers, and this function runs before the script sets up logging.
    logger = logging.getLogger(__name__)

    try:
        importlib.import_module(module_name)
        return  # already importable, nothing to do
    except ImportError:
        pass

    logger.info(f"{package_name} ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์–ด ์„ค์น˜ํ•ฉ๋‹ˆ๋‹ค...")
    try:
        # Install into the interpreter that is running this script.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name, "--quiet"])
    except subprocess.CalledProcessError as e:
        logger.error(f"์˜ค๋ฅ˜: {package_name} ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜ ์‹คํŒจ. (pip exit code {e.returncode})")
        sys.exit(f"{package_name} ์„ค์น˜ ์‹คํŒจ๋กœ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
    # Modules installed after the interpreter started may be missed by cached
    # path finders until the caches are invalidated.
    importlib.invalidate_caches()
    logger.info(f"{package_name} ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜ ์„ฑ๊ณต.")
# --- Package installation ---
# Install any missing third-party dependency before it is imported below.
# NOTE(review): pandas is imported at the top of the file before this check
# runs, so a missing pandas install fails there rather than being installed here.
required_packages = ["optuna", "pandas", "plotly"]
for pkg in required_packages:
    install_package(pkg)
import optuna

# --- 1. Paths and logging setup ---
# Project root. It can be overridden with the HPO_PROJECT_BASE_DIR environment
# variable; otherwise the original hard-coded location is used.
PROJECT_BASE_DIR = os.environ.get("HPO_PROJECT_BASE_DIR") or "/home/work/baro/sillok25060103"
LOG_DIR = os.path.join(PROJECT_BASE_DIR, "logs")
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE_PATH = os.path.join(LOG_DIR, "hpo_universal_analysis.log")

# Any earlier root-level logging.info()/error() call (e.g. while installing a
# missing package) implicitly runs logging.basicConfig(). That attaches a handler
# to the root logger and would make the basicConfig() call below a silent no-op,
# so no log file would be written. Remove such handlers first.
_root_logger = logging.getLogger()
for _handler in list(_root_logger.handlers):
    _root_logger.removeHandler(_handler)
    _handler.close()

# Log to a fresh file (mode='w' truncates it on every run) and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(LOG_FILE_PATH, mode='w', encoding='utf-8'),
        logging.StreamHandler(sys.stdout)
    ]
)
def analyze_hpo_results(db_name, study_name, output_dir, file_prefix):
    """Analyze one Optuna study stored in SQLite and save its visualizations.

    Loads ``study_name`` from ``PROJECT_BASE_DIR/optuna_db/<db_name>``, logs the
    five best completed trials, and writes parameter-importance,
    optimization-history and slice plots as HTML files into ``output_dir``.

    Args:
        db_name: SQLite file name inside ``PROJECT_BASE_DIR/optuna_db``. An
            absolute path also works because ``os.path.join`` keeps it as-is.
        study_name: Name of the study to load from that database.
        output_dir: Directory for the HTML figures; created if missing.
        file_prefix: Prefix prepended to every generated HTML file name.

    Returns:
        None. Problems are logged, not raised. These are a missing DB file, a
        load error, a multi-objective study, no completed trials, and
        individual plot failures.
    """
    logging.info("="*60)
    logging.info(f"'{study_name}' ์—ฐ๊ตฌ ๋ถ„์„ ์‹œ์ž‘")
    logging.info(f"DB ํŒŒ์ผ: {db_name}")
    logging.info("="*60)
    db_path = os.path.join(PROJECT_BASE_DIR, "optuna_db", db_name)
    storage_name = f"sqlite:///{db_path}"
    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"๋ถ„์„ ๊ฒฐ๊ณผ๋ฌผ(๊ทธ๋ž˜ํ”„) ์ €์žฅ ๊ฒฝ๋กœ: {output_dir}")

    # Opening a missing SQLite path creates a new, empty database file. A
    # mistyped --db_name would otherwise leave a stray DB behind before failing.
    if not os.path.isfile(db_path):
        logging.error(f"DB file not found: {db_path}")
        return

    try:
        study = optuna.load_study(study_name=study_name, storage=storage_name)
        logging.info(f"์—ฐ๊ตฌ๋ฅผ ์„ฑ๊ณต์ ์œผ๋กœ ๋ถˆ๋Ÿฌ์™”์Šต๋‹ˆ๋‹ค. (์ด Trial: {len(study.trials)})")
    except Exception as e:
        logging.error(f"์—ฐ๊ตฌ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return

    # Trial.value / Study.direction (used below) only exist for single-objective studies.
    if len(study.directions) > 1:
        logging.error("Multi-objective studies are not supported by this analyzer.")
        return

    completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if not completed_trials:
        logging.warning("์™„๋ฃŒ๋œ Trial์ด ์—†์Šต๋‹ˆ๋‹ค. ๋ถ„์„์„ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
        return

    _log_top_trials(study, completed_trials)
    _save_visualizations(study, output_dir, file_prefix)
    logging.info(f"\n'{study_name}' ์—ฐ๊ตฌ ๋ถ„์„ ์™„๋ฃŒ.\n")


def _log_top_trials(study, completed_trials):
    """Log the five best completed trials as a table, best first."""
    logging.info("\n--- ์ƒ์œ„ 5๊ฐœ Trial ์š”์•ฝ (eval_loss ๊ธฐ์ค€) ---")
    # Build the DataFrame by hand rather than via study.trials_dataframe().
    # Runs that stored their parameters in user_attrs['predefined_params']
    # (e.g. the stage-3 HPO) use those; all others fall back to sampled params.
    records = [
        {
            'number': t.number,
            'eval_loss': t.value,
            'duration': t.duration,
            'params': t.user_attrs.get('predefined_params', t.params),
        }
        for t in completed_trials
    ]
    # "Best" depends on the optimization direction: lowest first when
    # minimizing, highest first when maximizing.
    ascending = study.direction != optuna.study.StudyDirection.MAXIMIZE
    df_results = pd.DataFrame(records).sort_values(by='eval_loss', ascending=ascending)
    with pd.option_context('display.max_rows', 5, 'display.max_columns', None, 'display.width', 120, 'display.max_colwidth', None):
        table_text = str(df_results.head())
    # Log rather than print so the table also reaches the log file.
    logging.info("\n" + table_text)


def _save_visualizations(study, output_dir, file_prefix):
    """Write each Optuna plot to HTML; a failing plot does not block the others."""
    logging.info("\n--- ๊ฒฐ๊ณผ ์‹œ๊ฐํ™” ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘ ---")
    plots = (
        ("1. ํŒŒ๋ผ๋ฏธํ„ฐ ์ค‘์š”๋„ ๊ทธ๋ž˜ํ”„ ์ €์žฅ ์™„๋ฃŒ", optuna.visualization.plot_param_importances, "param_importances.html"),
        ("2. ์ตœ์ ํ™” ๊ณผ์ • ๊ทธ๋ž˜ํ”„ ์ €์žฅ ์™„๋ฃŒ", optuna.visualization.plot_optimization_history, "optimization_history.html"),
        ("3. ๊ฐœ๋ณ„ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ด€๊ณ„๋„ ๊ทธ๋ž˜ํ”„ ์ €์žฅ ์™„๋ฃŒ", optuna.visualization.plot_slice, "slice_plot.html"),
    )
    for done_message, plot_fn, file_name in plots:
        path = os.path.join(output_dir, f"{file_prefix}{file_name}")
        try:
            plot_fn(study).write_html(path)
        except Exception as e:
            # Plots such as param importances can fail on small or unusual
            # studies; report it and continue with the remaining plots.
            logging.error(f"์‹œ๊ฐํ™” ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ({file_name}): {e}")
            continue
        logging.info(f"{done_message}: {path}")
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the analysis.
    # Each option's destination name matches a parameter of
    # analyze_hpo_results, so the parsed namespace is forwarded as keywords.
    default_output_dir = os.path.join(PROJECT_BASE_DIR, "hpo_analysis_results")
    cli = argparse.ArgumentParser(description="Optuna HPO ๊ฒฐ๊ณผ๋ฅผ ๋ถ„์„ํ•˜๊ณ  ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.")
    cli.add_argument("--db_name", type=str, required=True, help="๋ถ„์„ํ•  Optuna DB ํŒŒ์ผ๋ช… (optuna_db ํด๋” ๋‚ด ์œ„์น˜)")
    cli.add_argument("--study_name", type=str, required=True, help="๋ถ„์„ํ•  Study์˜ ์ด๋ฆ„")
    cli.add_argument("--output_dir", type=str, default=default_output_dir, help="์‹œ๊ฐํ™” ๊ฒฐ๊ณผ๋ฌผ์„ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ")
    cli.add_argument("--file_prefix", type=str, default="analysis_", help="์ƒ์„ฑ๋  HTML ํŒŒ์ผ๋ช…์˜ ์ ‘๋‘์‚ฌ (์˜ˆ: 'stage2_')")
    analyze_hpo_results(**vars(cli.parse_args()))