Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| ============================================================ | |
| Automated training launcher (train_cli.py example) | |
| ------------------------------------------------------------ | |
| This script reads a CSV and, in a single run, performs automatic | |
| column mapping → feature generation → candidate model training | |
| (optional: Optuna tuning) → saving of artifacts and models. | |
| [Usage example] | |
| python train_cli.py --data ./data/sample_sales.csv \ | |
| --project . \ | |
| --valid_ratio 0.2 \ | |
| --use_optuna --optuna_trials 20 | |
| Required: | |
| --data Path to the CSV file used for training | |
| Optional: | |
| --project Working root folder (default: current folder ".") | |
| --valid_ratio Validation ratio (0.05~0.4 recommended, default 0.2) | |
| --use_optuna Flag that enables Optuna tuning (on when specified) | |
| --optuna_trials Number of Optuna trials (default 15) | |
| Output: | |
| Under the project folder, | |
| artifacts/ (intermediate outputs such as logs and the leaderboard) | |
| models/ (model files such as best_model.pkl) | |
| are created. | |
| ============================================================ | |
| """ | |
| import os | |
| import argparse | |
| import pandas as pd # (ํ์ํ๋ฉด ์ถํ ์ฌ์ฉ, ์ง๊ธ์ ์ํฌํธ๋ง) | |
| from utils_io import read_csv_flexible, save_utf8sig, ensure_dirs, auto_map_columns | |
| from preprocess import make_matrix | |
| from train_core import train_and_score, save_artifacts | |
def main(argv=None):
    """Run the end-to-end training pipeline from the command line.

    Steps, in order:
      1. Parse the command-line options.
      2. Switch the working directory to the project root.
      3. Load the CSV and auto-map its columns.
      4. Build the training matrix (X, y, feature names).
      5. Make sure the ``artifacts/`` and ``models/`` folders exist.
      6. Train and score the models (optionally with Optuna tuning).
      7. Save the outputs via ``save_artifacts``.
      8. Print a short summary to the console.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the normal
            command-line behavior, so existing ``main()`` callers are
            unaffected.
    """
    # --------------------------------------------------------
    # 1) Define and parse the command-line options
    # --------------------------------------------------------
    # NOTE(review): the help texts (and the banner printed in step 8) look
    # mis-encoded -- probably UTF-8 Korean decoded with a Thai code page.
    # They are reproduced unchanged here; restore them from the original
    # source rather than guessing at the lost bytes.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="ํ์ต์ ์ฌ์ฉํ CSV ๊ฒฝ๋ก (์: ./data/sales.csv)")
    parser.add_argument("--project", default=".", help="์์ ๋ฃจํธ ํด๋(artifacts/models ์์ฑ ์์น). ๊ธฐ๋ณธ๊ฐ='.'")
    parser.add_argument("--valid_ratio", type=float, default=0.2, help="๊ฒ์ฆ ๋ฐ์ดํฐ ๋น์จ(๊ธฐ๋ณธ 0.2)")
    parser.add_argument("--use_optuna", action="store_true", help="Optuna ํ๋ ์ฌ์ฉ ์ฌ๋ถ(ํ๋๊ทธ ์ง์ ์ ์ฌ์ฉ)")
    parser.add_argument("--optuna_trials", type=int, default=15, help="Optuna ์๋ ํ์(๊ธฐ๋ณธ 15)")
    args = parser.parse_args(argv)

    # --------------------------------------------------------
    # 2) Move to the project root
    # --------------------------------------------------------
    # The absolute path is computed BEFORE chdir, so --project is resolved
    # against the caller's directory. After chdir, a relative --data path is
    # resolved against the project root.
    proj = os.path.abspath(args.project)
    os.chdir(proj)

    # --------------------------------------------------------
    # 3) Load the CSV and auto-map its columns
    # --------------------------------------------------------
    data = read_csv_flexible(args.data)
    mapping = auto_map_columns(data)

    # --------------------------------------------------------
    # 4) Build the features (X, y, feat_names)
    # --------------------------------------------------------
    # The first element returned by make_matrix is not used in this script.
    _, X, y, feat_names = make_matrix(data, mapping)

    # --------------------------------------------------------
    # 5) Prepare the output folders
    # --------------------------------------------------------
    artifacts = os.path.join(proj, "artifacts")
    models_dir = os.path.join(proj, "models")
    ensure_dirs(artifacts, models_dir)

    # --------------------------------------------------------
    # 6) Train the models (optionally with Optuna) and get the leaderboard
    # --------------------------------------------------------
    best_model, lb = train_and_score(
        X, y,
        valid_ratio=args.valid_ratio,
        use_optuna=args.use_optuna,
        optuna_trials=args.optuna_trials,
    )

    # --------------------------------------------------------
    # 7) Save the outputs
    # --------------------------------------------------------
    # Both output folders are passed together as a single list argument.
    save_artifacts([artifacts, models_dir], best_model, feat_names, mapping, lb)

    # --------------------------------------------------------
    # 8) Console summary
    # --------------------------------------------------------
    print("โ training done.")
    print(" - artifacts:", artifacts)
    print(" - models :", models_dir)
    # Deliberate best-effort: show a preview when the leaderboard offers
    # .head(), otherwise fall back to printing the object as-is.
    try:
        print(lb.head())
    except Exception:
        print(lb)
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()