| """ | |
| LightGBM ๊ธฐ๋ฐ ๊ทผํผ๋ก๋ ์ถ์ ํ์ดํ๋ผ์ธ | |
| - Hugging Face Dataset ๋ก๋ | |
| - ํน์ง ์์ฑ (ฮฑ/ฮฒ ๋ณด์ ๊ฐ + user_emb) | |
| - LightGBM ํ์ต ๋ฐ ํ๊ฐ | |
| """ | |
import os
import argparse
import json
from pathlib import Path
from typing import Dict, Iterable, Optional

import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split

from load_dataset import DEFAULT_DATASET_ID, load_dataset_from_hub

EMB_DIM = 12
FEATURES = ["rms_ratio", "freq_ratio"]
EMB_COLS = [f"user_emb_{i+1}" for i in range(EMB_DIM)]
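# Resulting model input layout (assuming EMB_DIM == 12):
#   FEATURES + EMB_COLS
#   == ["rms_ratio", "freq_ratio", "user_emb_1", ..., "user_emb_12"]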


def build_features(df: pd.DataFrame) -> pd.DataFrame:
    required = [
        "rms_acc",
        "rms_gyro",
        "mean_freq_acc",
        "mean_freq_gyro",
        "rms_base",
        "freq_base",
        "fatigue",
    ]
    missing = set(required) - set(df.columns)
    if missing:
        raise KeyError(f"Missing columns: {sorted(missing)}")

    data = df.copy()
    # Current RMS level relative to the per-user baseline (eps guards against /0).
    data["rms_ratio"] = (
        (data["rms_acc"] + data["rms_gyro"]) / 2.0
    ) / data["rms_base"].replace(0, np.finfo(float).eps)
    # Baseline frequency relative to the current mean frequency.
    freq_mean = (data["mean_freq_acc"] + data["mean_freq_gyro"]) / 2.0
    data["freq_ratio"] = data["freq_base"] / freq_mean.replace(
        0, np.finfo(float).eps
    )

    if "user_emb" not in data.columns:
        raise KeyError("The dataset must contain a user_emb column.")
    # Expand the list-valued user_emb column into EMB_DIM scalar columns.
    data[EMB_COLS] = pd.DataFrame(
        data["user_emb"].tolist(), index=data.index
    )
    return data
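
# A minimal sanity-check sketch for build_features (all values are made up):
#
#   demo = pd.DataFrame({
#       "rms_acc": [1.2], "rms_gyro": [0.8],
#       "mean_freq_acc": [50.0], "mean_freq_gyro": [30.0],
#       "rms_base": [1.0], "freq_base": [45.0],
#       "fatigue": [0.3], "user_emb": [[0.0] * EMB_DIM],
#   })
#   feats = build_features(demo)
#   feats["rms_ratio"].iloc[0]   # (1.2 + 0.8) / 2 / 1.0        -> 1.0
#   feats["freq_ratio"].iloc[0]  # 45.0 / ((50.0 + 30.0) / 2)   -> 1.125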


def train_lightgbm(
    data: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42,
) -> Dict[str, object]:
    train_cols = FEATURES + EMB_COLS
    X = data[train_cols]
    y = data["fatigue"]
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    lgb_train = lgb.Dataset(X_train, label=y_train)
    lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train)
    params = {
        "objective": "regression",
        "metric": "rmse",
        "learning_rate": 0.1,
        "num_leaves": 31,
        "verbose": -1,
    }
    callbacks = [lgb.early_stopping(stopping_rounds=10, verbose=True)]
    model = lgb.train(
        params,
        lgb_train,
        valid_sets=[lgb_train, lgb_val],
        num_boost_round=100,
        callbacks=callbacks,
    )

    y_pred = model.predict(X_val, num_iteration=model.best_iteration)
    rmse = np.sqrt(mean_squared_error(y_val, y_pred))
    mae = mean_absolute_error(y_val, y_pred)
    print(f"RMSE: {rmse:.6f}")
    print(f"MAE : {mae:.6f}")

    importance = pd.DataFrame(
        {
            "feature": train_cols,
            "importance": model.feature_importance(),
        }
    ).sort_values(by="importance", ascending=False)
    print("\nFeature Importance:")
    print(importance.to_string(index=False))

    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)
    booster_path = models_dir / "lightgbm_model.txt"
    model.save_model(str(booster_path))
    print(f"\n✅ LightGBM model saved: {booster_path}")

    metadata = {
        # Cast numpy scalars to built-in types so json.dumps does not fail.
        "rmse": float(rmse),
        "mae": float(mae),
        "feature_importance": [
            {"feature": row.feature, "importance": int(row.importance)}
            for row in importance.itertuples(index=False)
        ],
        "model_path": str(booster_path),
        "artifact_type": "lightgbm",
        "sample_count": len(data),
    }
    metadata_path = models_dir / "training_metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False))
    print(f"ℹ️ Metadata saved: {metadata_path}")
    return metadata
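
# Inference sketch: reload the saved booster and score new rows. Assumes a
# prior training run wrote models/lightgbm_model.txt; new_df is hypothetical
# and must contain all raw columns build_features validates (including the
# fatigue label, so this suits labeled hold-out data):
#
#   booster = lgb.Booster(model_file="models/lightgbm_model.txt")
#   feats = build_features(new_df)
#   preds = booster.predict(feats[FEATURES + EMB_COLS])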


def main(
    data_dir: str = "./data",
    pattern: str = "user*.parquet",
    emb_dim: int = EMB_DIM,
    exclude_sessions: Optional[Iterable[str]] = None,
    repo_id: Optional[str] = None,
    split: Optional[str] = None,
) -> Dict[str, object]:
    # data_dir and pattern are kept for CLI compatibility; the data itself is
    # loaded from the Hugging Face Hub below.
    print("=" * 80)
    print("MuscleCare LightGBM Trainer")
    print("=" * 80)

    resolved_repo = repo_id or os.getenv("HF_DATASET_REPO_ID", DEFAULT_DATASET_ID)
    env_split = os.getenv("HF_DATASET_SPLIT")
    resolved_split = split if split is not None else env_split
    df, session_ids = load_dataset_from_hub(
        repo_id=resolved_repo,
        split=resolved_split,
        emb_dim=emb_dim,
        exclude_sessions=exclude_sessions,
    )
    if df.empty:
        raise ValueError("NO_DATA_AVAILABLE")

    df = build_features(df)
    result = train_lightgbm(df)
    result["session_ids"] = session_ids
    result["session_count"] = len(session_ids)
    result["dataset_repo"] = resolved_repo
    result["dataset_split"] = resolved_split or "ALL"
    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default="./data")
    parser.add_argument("--pattern", default="user*.parquet")
    parser.add_argument("--emb-dim", type=int, default=EMB_DIM)
    parser.add_argument("--repo-id", default=None)
    parser.add_argument("--split", default=None)
    args = parser.parse_args()
    main(
        args.data_dir,
        args.pattern,
        args.emb_dim,
        repo_id=args.repo_id,
        split=args.split,
    )
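
# Usage sketches (script name and repo id below are illustrative):
#   python train_lightgbm.py --repo-id someorg/musclecare-sessions --split train
#   HF_DATASET_REPO_ID=someorg/musclecare-sessions python train_lightgbm.py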