ddokbaro committed on
Commit
b6e3c85
·
verified ·
1 Parent(s): b7dc9d7

Upload 2 files

Browse files
scripts/hpo_result_analyzer_universal.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 필요한 라이브러리 임포트 및 자동 설치
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ import logging
6
+ import pandas as pd
7
+ import argparse
8
+
9
def install_package(package_name):
    """Make sure ``package_name`` is importable, installing it with pip if not.

    The package is first imported by name. Only when that raises
    ``ImportError`` is ``pip install <package_name> --quiet`` run through the
    current interpreter (``sys.executable -m pip``). A failing pip run is
    logged and terminates the script via ``sys.exit``.
    """
    # NOTE(review): the non-ASCII log/exit messages below look mis-encoded in
    # this copy of the file; they are runtime strings and are kept verbatim.
    try:
        __import__(package_name)
        return  # already available, nothing to install
    except ImportError:
        pass

    logging.info(f"{package_name} ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์–ด ์„ค์น˜ํ•ฉ๋‹ˆ๋‹ค...")
    pip_command = [sys.executable, "-m", "pip", "install", package_name, "--quiet"]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError:
        logging.error(f"์˜ค๋ฅ˜: {package_name} ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜ ์‹คํŒจ.")
        sys.exit(f"{package_name} ์„ค์น˜ ์‹คํŒจ๋กœ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
    logging.info(f"{package_name} ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜ ์„ฑ๊ณต.")
21
+
22
# --- 1. Paths and logging setup ---
# Logging is configured BEFORE the package check below. install_package()
# logs through the root logger, and a logging call on an unconfigured root
# logger installs a default handler; after that logging.basicConfig() does
# nothing, so the file/stdout handlers and the INFO level would never apply.
PROJECT_BASE_DIR = "/home/work/baro/sillok25060103"
LOG_DIR = os.path.join(PROJECT_BASE_DIR, "logs")
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE_PATH = os.path.join(LOG_DIR, "hpo_universal_analysis.log")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        # mode='w' truncates the log file on every run.
        logging.FileHandler(LOG_FILE_PATH, mode='w', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
)

# --- Package installation ---
# NOTE(review): pandas is already imported at the top of the file, so a
# missing pandas fails there before this loop gets a chance to install it.
required_packages = ["optuna", "pandas", "plotly"]
for pkg in required_packages:
    install_package(pkg)

# Imported only after the install step so a fresh environment can still run.
import optuna  # noqa: E402
43
+
44
def analyze_hpo_results(db_name, study_name, output_dir, file_prefix):
    """Analyze an Optuna study and save its visualizations as HTML files.

    The study is loaded from the SQLite file ``<PROJECT_BASE_DIR>/optuna_db/
    <db_name>``. A table of the completed trials with the lowest objective
    value is printed, then three Plotly figures (parameter importances,
    optimization history, slice plot) are written into ``output_dir``.

    Args:
        db_name: File name of the Optuna SQLite database inside the
            ``optuna_db`` folder of the project directory.
        study_name: Name of the study to load from that database.
        output_dir: Directory for the generated HTML files; created if it
            does not exist.
        file_prefix: Prefix prepended to every generated HTML file name.

    Returns:
        None. Loading failures and the "no completed trials" case are logged
        and end the analysis early instead of raising.
    """
    # NOTE(review): the non-ASCII log messages in this function look
    # mis-encoded in this copy of the file; they are runtime strings and are
    # kept verbatim.
    logging.info("="*60)
    logging.info(f"'{study_name}' ์—ฐ๊ตฌ ๋ถ„์„ ์‹œ์ž‘")
    logging.info(f"DB ํŒŒ์ผ: {db_name}")
    logging.info("="*60)

    db_path = os.path.join(PROJECT_BASE_DIR, "optuna_db", db_name)
    storage_name = f"sqlite:///{db_path}"

    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"๋ถ„์„ ๊ฒฐ๊ณผ๋ฌผ(๊ทธ๋ž˜ํ”„) ์ €์žฅ ๊ฒฝ๋กœ: {output_dir}")

    try:
        study = optuna.load_study(study_name=study_name, storage=storage_name)
        logging.info(f"์—ฐ๊ตฌ๋ฅผ ์„ฑ๊ณต์ ์œผ๋กœ ๋ถˆ๋Ÿฌ์™”์Šต๋‹ˆ๋‹ค. (์ด Trial: {len(study.trials)})")
    except Exception as e:
        logging.error(f"์—ฐ๊ตฌ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return

    # --- Summary table ---
    logging.info("\n--- ์ƒ์œ„ 5๊ฐœ Trial ์š”์•ฝ (eval_loss ๊ธฐ์ค€) ---")

    # The DataFrame is assembled by hand from the completed trials.
    completed_trials = [
        t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
    ]
    if not completed_trials:
        logging.warning("์™„๋ฃŒ๋œ Trial์ด ์—†์Šต๋‹ˆ๋‹ค. ๋ถ„์„์„ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
        return

    records = [
        {
            'number': t.number,
            'eval_loss': t.value,
            'duration': t.duration,
            # Prefer parameters stored under user_attrs['predefined_params']
            # (the original comment mentions a stage-3 HPO doing this) and
            # fall back to the regular trial parameters otherwise.
            'params': t.user_attrs.get('predefined_params', t.params),
        }
        for t in completed_trials
    ]

    df_results = pd.DataFrame(records).sort_values(by='eval_loss', ascending=True)

    with pd.option_context('display.max_rows', 5, 'display.max_columns', None, 'display.width', 120, 'display.max_colwidth', None):
        print(df_results.head())

    # --- Visualizations ---
    logging.info("\n--- ๊ฒฐ๊ณผ ์‹œ๊ฐํ™” ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘ ---")
    # (optuna.visualization function name, output file suffix, success label)
    plots = (
        ("plot_param_importances", "param_importances.html",
         "1. ํŒŒ๋ผ๋ฏธํ„ฐ ์ค‘์š”๋„ ๊ทธ๋ž˜ํ”„ ์ €์žฅ ์™„๋ฃŒ"),
        ("plot_optimization_history", "optimization_history.html",
         "2. ์ตœ์ ํ™” ๊ณผ์ • ๊ทธ๋ž˜ํ”„ ์ €์žฅ ์™„๋ฃŒ"),
        ("plot_slice", "slice_plot.html",
         "3. ๊ฐœ๋ณ„ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ด€๊ณ„๋„ ๊ทธ๋ž˜ํ”„ ์ €์žฅ ์™„๋ฃŒ"),
    )
    for plot_name, file_suffix, done_label in plots:
        target_path = os.path.join(output_dir, f"{file_prefix}{file_suffix}")
        # Each figure gets its own try block so that one failing plot does
        # not prevent the remaining ones from being written.
        try:
            figure = getattr(optuna.visualization, plot_name)(study)
            figure.write_html(target_path)
        except Exception as e:
            logging.error(f"์‹œ๊ฐํ™” ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        else:
            logging.info(f"{done_label}: {target_path}")

    logging.info(f"\n'{study_name}' ์—ฐ๊ตฌ ๋ถ„์„ ์™„๋ฃŒ.\n")
113
+
114
+
115
if __name__ == '__main__':
    # Command-line entry point: read the analysis options from argv and hand
    # them to analyze_hpo_results().
    # NOTE(review): the non-ASCII description/help texts look mis-encoded in
    # this copy of the file; they are runtime strings and are kept verbatim.
    default_output_dir = os.path.join(PROJECT_BASE_DIR, "hpo_analysis_results")

    cli = argparse.ArgumentParser(description="Optuna HPO ๊ฒฐ๊ณผ๋ฅผ ๋ถ„์„ํ•˜๊ณ  ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.")
    cli.add_argument(
        "--db_name",
        type=str,
        required=True,
        help="๋ถ„์„ํ•  Optuna DB ํŒŒ์ผ๋ช… (optuna_db ํด๋” ๋‚ด ์œ„์น˜)",
    )
    cli.add_argument(
        "--study_name",
        type=str,
        required=True,
        help="๋ถ„์„ํ•  Study์˜ ์ด๋ฆ„",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default=default_output_dir,
        help="์‹œ๊ฐํ™” ๊ฒฐ๊ณผ๋ฌผ์„ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ",
    )
    cli.add_argument(
        "--file_prefix",
        type=str,
        default="analysis_",
        help="์ƒ์„ฑ๋  HTML ํŒŒ์ผ๋ช…์˜ ์ ‘๋‘์‚ฌ (์˜ˆ: 'stage2_')",
    )

    options = cli.parse_args()

    # The parsed option names match the function's keyword parameters exactly.
    analyze_hpo_results(**vars(options))