# cstore / quick_train_runner.py
# leedami's picture
# Upload 7 files
# 5841e58 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
============================================================
์ž๋™ ํ•™์Šต ๋Ÿฐ์ฒ˜ (train_cli.py ์˜ˆ์‹œ)
------------------------------------------------------------
์ด ์Šคํฌ๋ฆฝํŠธ๋Š” CSV๋ฅผ ์ฝ์–ด ์ž๋™์œผ๋กœ ์ปฌ๋Ÿผ ๋งคํ•‘ โ†’ ํ”ผ์ฒ˜ ์ƒ์„ฑ โ†’
๋ชจ๋ธ ํ›„๋ณด ํ•™์Šต(์˜ต์…˜: Optuna ํŠœ๋‹) โ†’ ์•„ํ‹ฐํŒฉํŠธ/๋ชจ๋ธ ์ €์žฅ์„
ํ•œ ๋ฒˆ์— ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
[์‚ฌ์šฉ ์˜ˆ]
python train_cli.py --data ./data/sample_sales.csv \
--project . \
--valid_ratio 0.2 \
--use_optuna --optuna_trials 20
ํ•„์ˆ˜:
--data ํ•™์Šต์— ์‚ฌ์šฉํ•  CSV ํŒŒ์ผ ๊ฒฝ๋กœ
์„ ํƒ:
--project ์ž‘์—… ๋ฃจํŠธ ํด๋”(๊ธฐ๋ณธ: ํ˜„์žฌ ํด๋” ".")
--valid_ratio ๊ฒ€์ฆ ๋น„์œจ(0.05~0.4 ๊ถŒ์žฅ, ๊ธฐ๋ณธ 0.2)
--use_optuna Optuna ํŠœ๋‹ ์‚ฌ์šฉ ํ”Œ๋ž˜๊ทธ(์ง€์ • ์‹œ on)
--optuna_trials Optuna ์‹œ๋„ ํšŸ์ˆ˜(๊ธฐ๋ณธ 15)
์ถœ๋ ฅ:
ํ”„๋กœ์ ํŠธ ํด๋” ์•„๋ž˜์—
artifacts/ (๋กœ๊ทธ/๋ฆฌ๋”๋ณด๋“œ ๋“ฑ ์ค‘๊ฐ„ ์‚ฐ์ถœ๋ฌผ)
models/ (best_model.pkl ๋“ฑ ๋ชจ๋ธ ํŒŒ์ผ)
์ด ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.
============================================================
"""
import os
import argparse
import pandas as pd # (ํ•„์š”ํ•˜๋ฉด ์ถ”ํ›„ ์‚ฌ์šฉ, ์ง€๊ธˆ์€ ์ž„ํฌํŠธ๋งŒ)
from utils_io import read_csv_flexible, save_utf8sig, ensure_dirs, auto_map_columns
from preprocess import make_matrix
from train_core import train_and_score, save_artifacts
def main(argv=None):
    """Run the end-to-end training pipeline from the command line.

    Steps, in order:
      1. Parse the command-line options.
      2. Change the working directory to the project root.
      3. Load the CSV and auto-map its columns.
      4. Build the training matrix (X, y, feature names).
      5. Make sure the ``artifacts/`` and ``models/`` folders exist.
      6. Train and score the models (optionally with Optuna tuning).
      7. Save the resulting artifacts.
      8. Print a short summary to the console.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Note:
        The working directory is switched to ``--project`` *before*
        ``--data`` is handed to ``read_csv_flexible``, so a relative data
        path is presumably resolved against the project folder rather than
        the directory the script was launched from (this ordering is kept
        from the original on purpose).
    """
    # --------------------------------------------------------
    # 1) Define and parse the command-line options
    # --------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        required=True,
        help="학습에 사용할 CSV 경로 (예: ./data/sales.csv)",
    )
    parser.add_argument(
        "--project",
        default=".",
        help="작업 루트 폴더(artifacts/models 생성 위치). 기본값='.'",
    )
    parser.add_argument(
        "--valid_ratio",
        type=float,
        default=0.2,
        help="검증 데이터 비율(기본 0.2)",
    )
    parser.add_argument(
        "--use_optuna",
        action="store_true",
        help="Optuna 튜닝 사용 여부(플래그 지정 시 사용)",
    )
    parser.add_argument(
        "--optuna_trials",
        type=int,
        default=15,
        help="Optuna 시도 횟수(기본 15)",
    )
    args = parser.parse_args(argv)

    # --------------------------------------------------------
    # 2) Move into the project root so that relative paths used from here
    #    on are anchored to a single, known location
    # --------------------------------------------------------
    proj = os.path.abspath(args.project)
    os.chdir(proj)

    # --------------------------------------------------------
    # 3) Load the CSV and auto-map its columns
    # --------------------------------------------------------
    data = read_csv_flexible(args.data)
    mapping = auto_map_columns(data)

    # --------------------------------------------------------
    # 4) Build the features (X, y, feat_names)
    # --------------------------------------------------------
    # make_matrix also returns a frame as its first element; it is not
    # needed by this runner, so it is discarded.
    _, X, y, feat_names = make_matrix(data, mapping)

    # --------------------------------------------------------
    # 5) Prepare the output folders
    # --------------------------------------------------------
    artifacts = os.path.join(proj, "artifacts")
    models_dir = os.path.join(proj, "models")
    ensure_dirs(artifacts, models_dir)

    # --------------------------------------------------------
    # 6) Train the models (+ optional Optuna) and get the leaderboard
    # --------------------------------------------------------
    best_model, lb = train_and_score(
        X,
        y,
        valid_ratio=args.valid_ratio,
        use_optuna=args.use_optuna,
        optuna_trials=args.optuna_trials,
    )

    # --------------------------------------------------------
    # 7) Save the outputs
    # --------------------------------------------------------
    save_artifacts([artifacts, models_dir], best_model, feat_names, mapping, lb)

    # --------------------------------------------------------
    # 8) Console summary
    # --------------------------------------------------------
    print("✅ training done.")
    print(" - artifacts:", artifacts)
    print(" - models :", models_dir)
    # Best effort: show only the top rows when the leaderboard supports
    # .head(); otherwise fall back to printing the object as-is. Training
    # is already saved at this point, so a display problem must not crash.
    try:
        print(lb.head())
    except Exception:
        print(lb)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()