# MuscleCare-Train-Hybrid / train_hybrid.py
# Last commit: 9a12dde (Merry99) — "Restore: Space paused fixes - memory
# optimization and error handling"
"""
LightGBM ๊ธฐ๋ฐ˜ ๊ทผํ”ผ๋กœ๋„ ์ถ”์ • ํŒŒ์ดํ”„๋ผ์ธ
- Hugging Face Dataset ๋กœ๋“œ
- ํŠน์ง• ์ƒ์„ฑ (ฮฑ/ฮฒ ๋ณด์ •๊ฐ’ + user_emb)
- LightGBM ํ•™์Šต ๋ฐ ํ‰๊ฐ€
"""
import os
import argparse
import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
from load_dataset import DEFAULT_DATASET_ID, load_dataset_from_hub
# Width of the per-user embedding vector expected in the `user_emb` column.
EMB_DIM = 12
# Engineered ratio features used by the model, alongside the embedding columns.
FEATURES = ["rms_ratio", "freq_ratio"]
# Flattened embedding column names: useremb1 .. useremb12.
EMB_COLS = [f"useremb{i+1}" for i in range(EMB_DIM)]


def build_features(df: pd.DataFrame) -> pd.DataFrame:
    """Build the model's feature columns from raw sensor aggregates.

    Adds `rms_ratio`, `freq_ratio`, and the flattened `useremb*` columns
    to a copy of *df* (the input frame is not mutated).

    Args:
        df: Frame with columns rms_acc, rms_gyro, mean_freq_acc,
            mean_freq_gyro, rms_base, freq_base, fatigue, and user_emb
            (a list of EMB_DIM floats per row).

    Returns:
        A new DataFrame with the original columns plus the features.

    Raises:
        KeyError: if any required column (or user_emb) is missing.
        ValueError: if user_emb vectors do not have EMB_DIM entries.
    """
    required = [
        "rms_acc",
        "rms_gyro",
        "mean_freq_acc",
        "mean_freq_gyro",
        "rms_base",
        "freq_base",
        "fatigue",
    ]
    missing = set(required) - set(df.columns)
    if missing:
        raise KeyError(f"Missing columns: {sorted(missing)}")
    data = df.copy()
    # Guard divisions: swap exact zeros in the denominator for machine epsilon.
    eps = np.finfo(float).eps
    # Mean measured RMS relative to the user's baseline RMS.
    data["rms_ratio"] = (
        (data["rms_acc"] + data["rms_gyro"]) / 2.0
    ) / data["rms_base"].replace(0, eps)
    # Baseline frequency relative to the mean measured frequency
    # (inverted vs. rms_ratio: fatigue typically lowers mean frequency).
    freq_mean = (data["mean_freq_acc"] + data["mean_freq_gyro"]) / 2.0
    data["freq_ratio"] = data["freq_base"] / freq_mean.replace(0, eps)
    if "user_emb" not in data.columns:
        raise KeyError("Dataset must contain a user_emb column.")
    # Flatten the list-valued embedding column into EMB_DIM scalar columns.
    emb = pd.DataFrame(data["user_emb"].tolist(), index=data.index)
    if emb.shape[1] != EMB_DIM:
        raise ValueError(
            f"user_emb must have {EMB_DIM} dimensions, got {emb.shape[1]}"
        )
    data[EMB_COLS] = emb
    return data
def train_lightgbm(
    data: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42,
) -> Dict[str, object]:
    """Train a LightGBM regressor on the engineered features.

    Splits *data* into train/validation, trains with early stopping,
    prints metrics and feature importance, and saves the booster plus a
    JSON metadata file under ./models.

    Args:
        data: Frame containing FEATURES + EMB_COLS and a `fatigue` target.
        test_size: Fraction held out for validation.
        random_state: Seed for the train/validation split.

    Returns:
        The metadata dict (rmse, mae, feature_importance, model_path,
        artifact_type, sample_count) that was written to disk.
    """
    train_cols = FEATURES + EMB_COLS
    X = data[train_cols]
    y = data["fatigue"]
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    lgb_train = lgb.Dataset(X_train, label=y_train)
    lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train)
    params = {
        "objective": "regression",
        "metric": "rmse",
        "learning_rate": 0.1,
        "num_leaves": 31,
        "verbose": -1,
    }
    model = lgb.train(
        params,
        lgb_train,
        valid_sets=[lgb_train, lgb_val],
        num_boost_round=100,
        # Stop when the validation RMSE has not improved for 10 rounds.
        callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=True)],
    )

    y_pred = model.predict(X_val, num_iteration=model.best_iteration)
    # Cast metrics to plain Python floats so the metadata dict below is
    # guaranteed JSON-serializable.
    rmse = float(np.sqrt(mean_squared_error(y_val, y_pred)))
    mae = float(mean_absolute_error(y_val, y_pred))
    print(f"RMSE: {rmse:.6f}")
    print(f"MAE : {mae:.6f}")

    importance = pd.DataFrame(
        {
            "feature": train_cols,
            # BUG FIX: feature_importance() returns numpy integers, which
            # json.dumps cannot serialize; .tolist() yields Python ints.
            "importance": model.feature_importance().tolist(),
        }
    ).sort_values(by="importance", ascending=False)
    print("\nFeature Importance:")
    print(importance.to_string(index=False))

    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)
    booster_path = models_dir / "lightgbm_model.txt"
    model.save_model(str(booster_path))
    print(f"\n✅ Saved LightGBM model: {booster_path}")

    metadata = {
        "rmse": rmse,
        "mae": mae,
        "feature_importance": importance.to_dict(orient="records"),
        "model_path": str(booster_path),
        "artifact_type": "lightgbm",
        "sample_count": len(data),
    }
    metadata_path = models_dir / "training_metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False))
    print(f"ℹ️ Saved training metadata: {metadata_path}")
    return metadata
def main(
    data_dir: str = "./data",
    pattern: str = "user*.parquet",
    emb_dim: int = EMB_DIM,
    exclude_sessions: Optional[Iterable[str]] = None,
    repo_id: Optional[str] = None,
    split: Optional[str] = None,
) -> Dict[str, str]:
    """Run the full pipeline: load from the Hub, build features, train.

    NOTE(review): `data_dir` and `pattern` are accepted but never used —
    loading goes through the Hub; presumably kept for backward
    compatibility with an older local-file loader. TODO confirm.

    Raises:
        ValueError: "NO_DATA_AVAILABLE" when the loaded frame is empty.
    """
    banner = "=" * 80
    print(banner)
    print("MuscleCare LightGBM Trainer")
    print(banner)

    # Explicit arguments win; otherwise fall back to env vars / defaults.
    resolved_repo = repo_id or os.getenv("HF_DATASET_REPO_ID", DEFAULT_DATASET_ID)
    resolved_split = os.getenv("HF_DATASET_SPLIT") if split is None else split

    df, session_ids = load_dataset_from_hub(
        repo_id=resolved_repo,
        split=resolved_split,
        emb_dim=emb_dim,
        exclude_sessions=exclude_sessions,
    )
    if df.empty:
        raise ValueError("NO_DATA_AVAILABLE")

    result = train_lightgbm(build_features(df))
    result.update(
        session_ids=session_ids,
        session_count=len(session_ids),
        dataset_repo=resolved_repo,
        dataset_split=resolved_split or "ALL",
    )
    return result
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", default="./data")
parser.add_argument("--pattern", default="user*.parquet")
parser.add_argument("--emb-dim", type=int, default=EMB_DIM)
args = parser.parse_args()
main(args.data_dir, args.pattern, args.emb_dim)