File size: 4,710 Bytes
5841e58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
============================================================
์๋ ํ์ต ๋ฐ์ฒ (train_cli.py ์์)
------------------------------------------------------------
์ด ์คํฌ๋ฆฝํธ๋ CSV๋ฅผ ์ฝ์ด ์๋์ผ๋ก ์ปฌ๋ผ ๋งคํ โ ํผ์ฒ ์์ฑ โ
๋ชจ๋ธ ํ๋ณด ํ์ต(์ต์
: Optuna ํ๋) โ ์ํฐํฉํธ/๋ชจ๋ธ ์ ์ฅ์
ํ ๋ฒ์ ์ํํฉ๋๋ค.
[์ฌ์ฉ ์]
python train_cli.py --data ./data/sample_sales.csv \
--project . \
--valid_ratio 0.2 \
--use_optuna --optuna_trials 20
ํ์:
--data ํ์ต์ ์ฌ์ฉํ CSV ํ์ผ ๊ฒฝ๋ก
์ ํ:
--project ์์
๋ฃจํธ ํด๋(๊ธฐ๋ณธ: ํ์ฌ ํด๋ ".")
--valid_ratio ๊ฒ์ฆ ๋น์จ(0.05~0.4 ๊ถ์ฅ, ๊ธฐ๋ณธ 0.2)
--use_optuna Optuna ํ๋ ์ฌ์ฉ ํ๋๊ทธ(์ง์ ์ on)
--optuna_trials Optuna ์๋ ํ์(๊ธฐ๋ณธ 15)
์ถ๋ ฅ:
ํ๋ก์ ํธ ํด๋ ์๋์
artifacts/ (๋ก๊ทธ/๋ฆฌ๋๋ณด๋ ๋ฑ ์ค๊ฐ ์ฐ์ถ๋ฌผ)
models/ (best_model.pkl ๋ฑ ๋ชจ๋ธ ํ์ผ)
์ด ์์ฑ๋ฉ๋๋ค.
============================================================
"""
import os
import argparse
import pandas as pd # (reserved for later use if needed; import only for now)
from utils_io import read_csv_flexible, save_utf8sig, ensure_dirs, auto_map_columns
from preprocess import make_matrix
from train_core import train_and_score, save_artifacts
def main():
    """Run the end-to-end training pipeline from the command line.

    Steps, in order:
      1) parse the command-line arguments,
      2) change the working directory to the project root,
      3) load the CSV and auto-map its columns,
      4) build the training dataset (X, y, feature names),
      5) prepare the ``artifacts/`` and ``models/`` output folders,
      6) train the model (optionally with Optuna tuning),
      7) save the results (model / metadata / leaderboard),
      8) print a short summary to the console.

    Note:
        The working directory is switched to ``--project`` *before* the CSV
        is read, so a relative ``--data`` path is resolved against the
        project root, not against the directory the script was started from.

    Raises:
        SystemExit: raised by ``argparse`` when the arguments are invalid
            or ``--data`` is missing.
        OSError: raised by ``os.chdir`` when the project folder does not
            exist or is not accessible.
    """
    # --------------------------------------------------------
    # 1) Define / parse the command-line options
    # --------------------------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--data",
        required=True,
        help="학습에 사용할 CSV 경로 (예: ./data/sales.csv)",
    )
    ap.add_argument(
        "--project",
        default=".",
        help="작업 루트 폴더(artifacts/models 생성 위치). 기본값='.'",
    )
    ap.add_argument(
        "--valid_ratio",
        type=float,
        default=0.2,
        help="검증 데이터 비율(기본 0.2)",
    )
    ap.add_argument(
        "--use_optuna",
        action="store_true",
        help="Optuna 튜닝 사용 여부(플래그 지정 시 사용)",
    )
    ap.add_argument(
        "--optuna_trials",
        type=int,
        default=15,
        help="Optuna 시도 횟수(기본 15)",
    )
    args = ap.parse_args()

    # --------------------------------------------------------
    # 2) Move to the working root (avoids relative-path confusion)
    # --------------------------------------------------------
    proj = os.path.abspath(args.project)  # convert to an absolute path
    os.chdir(proj)  # files are read / saved relative to this folder

    # --------------------------------------------------------
    # 3) Load the CSV + map the columns automatically
    # --------------------------------------------------------
    data = read_csv_flexible(args.data)
    mapping = auto_map_columns(data)

    # --------------------------------------------------------
    # 4) Build the features (X, y, feat_names)
    # --------------------------------------------------------
    # The first returned value is not needed by this launcher.
    _df, X, y, feat_names = make_matrix(data, mapping)

    # --------------------------------------------------------
    # 5) Prepare the output folders (created if missing)
    # --------------------------------------------------------
    artifacts = os.path.join(proj, "artifacts")  # leaderboard / logs, etc.
    models_dir = os.path.join(proj, "models")  # where best_model.pkl is saved
    ensure_dirs(artifacts, models_dir)

    # --------------------------------------------------------
    # 6) Train the model (+ optional Optuna) & obtain the leaderboard
    # --------------------------------------------------------
    best_model, lb = train_and_score(
        X, y,
        valid_ratio=args.valid_ratio,
        use_optuna=args.use_optuna,
        optuna_trials=args.optuna_trials,
    )

    # --------------------------------------------------------
    # 7) Save the outputs (model / metadata / leaderboard)
    # --------------------------------------------------------
    save_artifacts([artifacts, models_dir], best_model, feat_names, mapping, lb)

    # --------------------------------------------------------
    # 8) Console log (summary)
    # --------------------------------------------------------
    print("✅ training done.")
    print(" - artifacts:", artifacts)
    print(" - models :", models_dir)
    # Best effort: show only the head when the leaderboard supports it,
    # otherwise fall back to printing the object as-is.
    try:
        print(lb.head())
    except Exception:
        print(lb)
# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|