Spaces: Running
Restore: Space paused fixes - memory optimization and error handling
- Added a proper .gitignore to exclude venv and cache files
- Fixed memory-leak issues with model caching
- Improved error handling in startup and training
- Added lazy loading for models to prevent startup failures (see the sketch after the file list below)
- .gitignore +83 -0
- Dockerfile +20 -0
- app.py +351 -0
- load_dataset.py +196 -0
- requirements.txt +10 -0
- run_local.sh +39 -0
- start.py +9 -0
- train_hybrid.py +170 -0
- train_scheduler.py +243 -0
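
In short, the memory fix keeps a single booster behind a lock, stamps it with a load time, and drops it after a fixed TTL, while startup defers all loading to the first request. A minimal sketch of that pattern, with illustrative names rather than the exact ones used in app.py below:

import time
import threading

_lock = threading.Lock()
_model = None          # nothing is loaded at startup (lazy loading)
_loaded_at = None
CACHE_TTL = 3600       # seconds; matches MODEL_CACHE_TIMEOUT in app.py

def get_model(load_fn):
    """Return the cached model, reloading it lazily once the TTL expires."""
    global _model, _loaded_at
    with _lock:
        expired = _loaded_at is not None and time.time() - _loaded_at > CACHE_TTL
        if _model is None or expired:
            _model = load_fn()   # e.g. lgb.Booster(model_file=...)
            _loaded_at = time.time()
        return _model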
.gitignore
ADDED
@@ -0,0 +1,83 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so

# Virtual environments
.venv/
venv/
ENV/
env/

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Logs
logs/
*.log

# Models (keep only in git if needed)
models/
!models/.gitkeep

# Data cache
data/
!data/.gitkeep

# Temporary files
*.tmp
*.temp

# Jupyter Notebook
.ipynb_checkpoints

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Environment variables
.env
.env.local
.env.production

# Hugging Face
hf_cache/
Dockerfile
ADDED
@@ -0,0 +1,20 @@
FROM python:3.10-slim

WORKDIR /app

ENV PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1

COPY requirements.txt .
RUN apt-get update && apt-get install -y --no-install-recommends build-essential libgomp1 && \
    pip install --upgrade pip && \
    pip install -r requirements.txt && \
    apt-get purge -y build-essential && \
    apt-get autoremove -y && \
    rm -rf /var/lib/apt/lists/*

COPY . .

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py
ADDED
@@ -0,0 +1,351 @@
"""FastAPI app: automatic training and model download/upload."""

from __future__ import annotations

import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import schedule
import lightgbm as lgb
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub import HfApi
from pydantic import BaseModel, field_validator

from train_scheduler import TrainingScheduler


app = FastAPI(
    title="MuscleCare LightGBM Scheduler",
    description="Serves the same API as the MuscleCare-Train-AI Space, backed by a LightGBM model.",
)

_scheduler = TrainingScheduler()

_model_lock = threading.Lock()
_current_model: Optional[lgb.Booster] = None
_current_model_path: Optional[str] = None
_current_model_version: Optional[int] = None
_model_cache_timestamp: Optional[float] = None
MODEL_CACHE_TIMEOUT = 3600  # 1 hour


class TrainResponse(BaseModel):
    status: str
    new_data_count: int
    model_path: Optional[str] = None
    hub_url: Optional[str] = None
    model_version: Optional[int] = None
    message: str
    new_session_count: Optional[int] = None


class ResetStateResponse(BaseModel):
    status: str
    state: Dict[str, Any]


class PredictRequest(BaseModel):
    rms_acc: float
    rms_gyro: float
    mean_freq_acc: float
    mean_freq_gyro: float
    rms_base: float
    freq_base: float
    user_emb: List[float]

    @field_validator("user_emb")
    @classmethod
    def validate_user_emb(cls, v: List[float]) -> List[float]:
        if len(v) != 12:
            raise ValueError("user_emb must contain exactly 12 values.")
        return v


class PredictResponse(BaseModel):
    fatigue: float
    model_version: Optional[int]


def _schedule_background_job() -> None:
    schedule.clear()
    schedule.every().sunday.at(_scheduler.schedule_time).do(_scheduler.run_scheduled_training)

    def _loop() -> None:
        while True:
            schedule.run_pending()
            time.sleep(60)

    threading.Thread(target=_loop, daemon=True).start()


def _apply_training_result(result: Dict[str, Any]) -> None:
    if result.get("status") != "trained":
        return
    model_path = result.get("model_path")
    if not model_path:
        print("[Model] Training result has no model_path; the model was not loaded.")
        return
    try:
        _load_model_from_path(Path(model_path), result.get("model_version"))
    except Exception as exc:
        print(f"[Model] Failed to load the new model: {exc}")


def _load_model_from_path(path: Path, version: Optional[int] = None) -> None:
    if not path.exists():
        raise FileNotFoundError(f"Model file not found: {path}")
    booster = lgb.Booster(model_file=str(path))
    with _model_lock:
        global _current_model, _current_model_path, _current_model_version, _model_cache_timestamp
        _current_model = booster
        _current_model_path = str(path)
        _current_model_version = version
        _model_cache_timestamp = time.time()
    print(f"[Model] Loaded LightGBM model from {path} (version={version})")


def _get_cached_model() -> Optional[lgb.Booster]:
    """Return the cached model, or None once the cache has expired."""
    global _current_model  # required: this function clears the cache on expiry
    with _model_lock:
        if _current_model is None:
            return None
        if _model_cache_timestamp is None:
            return None
        if time.time() - _model_cache_timestamp > MODEL_CACHE_TIMEOUT:
            print("[Model] Model cache expired; a reload is required.")
            _current_model = None
            return None
        return _current_model


def _maybe_load_latest_model() -> None:
    try:
        manifest = _scheduler.get_model_versions()
        target_entry = manifest[-1] if manifest else None
        candidate_path: Optional[Path] = None
        candidate_version: Optional[int] = None

        if target_entry:
            candidate_path = Path(target_entry["path"])
            candidate_version = target_entry.get("version")
        else:
            default_path = Path("models/lightgbm_model.txt")
            if default_path.exists():
                candidate_path = default_path

        if candidate_path and candidate_path.exists():
            try:
                _load_model_from_path(candidate_path, candidate_version)
                print(f"[Model] Model loaded successfully: {candidate_path}")
            except Exception as exc:
                print(f"[Model] Model load failed (continuing): {exc}")
        else:
            print("[Model] No model available to load yet.")
    except Exception as exc:
        print(f"[Model] Exception while loading the model: {exc}")


def _get_active_model() -> Tuple[lgb.Booster, Optional[int]]:
    # Check the cached model first.
    cached_model = _get_cached_model()
    if cached_model is not None:
        return cached_model, _current_model_version

    # No cached model: try to load the latest one.
    try:
        manifest = _scheduler.get_model_versions()
        target_entry = manifest[-1] if manifest else None

        if target_entry:
            path = Path(target_entry["path"])
            version = target_entry.get("version")
        else:
            path = Path("models/lightgbm_model.txt")
            version = None  # no manifest entry, so the version is unknown

        if path.exists():
            _load_model_from_path(path, version)
            return _current_model, _current_model_version
        else:
            raise HTTPException(status_code=503, detail="Model file not found.")
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=503, detail=f"Model load failed: {exc}")


def _build_feature_vector(payload: PredictRequest) -> np.ndarray:
    rms_base = payload.rms_base if payload.rms_base != 0 else 1e-6
    freq_mean = (payload.mean_freq_acc + payload.mean_freq_gyro) / 2.0
    if freq_mean == 0:
        freq_mean = 1e-6

    rms_ratio = ((payload.rms_acc + payload.rms_gyro) / 2.0) / rms_base
    freq_ratio = payload.freq_base / freq_mean

    feature_vector = [rms_ratio, freq_ratio, *payload.user_emb]
    return np.asarray([feature_vector], dtype=np.float32)


@app.on_event("startup")
def on_startup() -> None:
    print("[Startup] Starting the MuscleCare Space...")
    try:
        _schedule_background_job()
        print("[Startup] Scheduler initialized")
    except Exception as exc:
        print(f"[Startup] Scheduler initialization failed (continuing): {exc}")

    # The model is loaded on first use (lazy loading).
    print("[Startup] The model will be loaded on demand (lazy loading)")
    print("[Startup] MuscleCare Space startup complete")


@app.get("/health")
def health_check() -> dict:
    return {"status": "ok"}


@app.get("/")
def root() -> dict:
    return {
        "message": "MuscleCare LightGBM Scheduler API",
        "docs": "/docs",
        "endpoints": {
            "trigger": "/trigger",
            "model": "/model",
            "state_reset": "/state/reset",
        },
    }


def _upload_to_hub(model_path: str) -> Optional[str]:
    token = os.getenv("HF_HYBRID_MODEL_TOKEN")
    repo_id = os.getenv("HF_HYBRID_MODEL_REPO_ID")

    if not token or not repo_id:
        return None

    path = Path(model_path)
    if not path.exists():
        raise HTTPException(status_code=404, detail=f"Model file not found: {model_path}")

    api = HfApi(token=token)
    api.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
    api.upload_file(
        path_or_fileobj=path,
        path_in_repo=path.name,
        repo_id=repo_id,
        repo_type="model",
        commit_message=f"LightGBM model upload ({path.name})",
    )

    manifest_path = Path("logs/model_versions.json")
    if manifest_path.exists():
        api.upload_file(
            path_or_fileobj=str(manifest_path),
            path_in_repo="model_versions.json",
            repo_id=repo_id,
            repo_type="model",
            commit_message="Update model manifest",
        )

    return f"https://huggingface.co/{repo_id}"


def _resolve_model_entry(version: Optional[int] = None) -> Dict[str, Any]:
    manifest = _scheduler.get_model_versions()
    if not manifest:
        raise HTTPException(status_code=404, detail="No trained model exists yet.")

    if version is None:
        return manifest[-1]

    for entry in manifest:
        if entry.get("version") == version:
            return entry

    raise HTTPException(
        status_code=404,
        detail=f"Model version {version} not found.",
    )


@app.get("/model")
@app.get("/model/{version:int}")
def download_model(version: Optional[int] = None) -> FileResponse:
    entry = _resolve_model_entry(version)
    path = Path(entry["path"])
    if not path.exists():
        raise HTTPException(status_code=404, detail="Model file not found.")

    response = FileResponse(
        path=path,
        filename=entry["filename"],
        media_type="application/octet-stream",
    )
    response.headers["X-Model-Version"] = str(entry["version"])
    return response


@app.get("/download")
def download_latest_alias() -> FileResponse:
    return download_model()


@app.post("/state/reset", response_model=ResetStateResponse)
def reset_state() -> ResetStateResponse:
    state = _scheduler.reset_training_state()
    return ResetStateResponse(status="reset", state=state)


@app.post("/trigger", response_model=TrainResponse)
def trigger_training(upload: bool = False) -> TrainResponse:
    try:
        result = _scheduler.run_scheduled_training()
    except Exception as exc:  # pragma: no cover
        raise HTTPException(status_code=500, detail=f"Training run error: {exc}") from exc

    message = "Model training completed." if result["status"] == "trained" else "Training was skipped."
    hub_url = None
    model_version = result.get("model_version")
    model_path = result.get("model_path")

    if upload and model_path and result["status"] == "trained":
        try:
            hub_url = _upload_to_hub(model_path)
            message = "Model training and Hugging Face upload completed."
        except HTTPException:
            raise
        except Exception as exc:  # pragma: no cover
            raise HTTPException(status_code=500, detail=f"Hugging Face upload failed: {exc}") from exc

    _apply_training_result(result)

    return TrainResponse(
        status=result["status"],
        new_data_count=result.get("new_data_count", 0),
        model_path=model_path,
        hub_url=hub_url,
        model_version=model_version,
        message=message,
        new_session_count=result.get("new_session_count"),
    )


@app.post("/train", response_model=TrainResponse)
def trigger_training_alias(upload: bool = False) -> TrainResponse:
    return trigger_training(upload=upload)


@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    booster, version = _get_active_model()
    features = _build_feature_vector(payload)
    prediction = booster.predict(features)[0]
    return PredictResponse(fatigue=float(prediction), model_version=version)


__all__ = ["app"]
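
For reference, a minimal client-side sketch of the /predict contract defined above. It assumes the app is reachable at http://localhost:7860 (the port exposed by the Dockerfile) and that the requests package is installed; requests is not part of requirements.txt.

import requests  # assumed to be installed separately; not pinned in requirements.txt

payload = {
    "rms_acc": 1.2,
    "rms_gyro": 0.9,
    "mean_freq_acc": 45.0,
    "mean_freq_gyro": 40.0,
    "rms_base": 1.0,
    "freq_base": 42.0,
    "user_emb": [0.0] * 12,  # must be exactly 12 values (see PredictRequest)
}
resp = requests.post("http://localhost:7860/predict", json=payload)
resp.raise_for_status()
print(resp.json())  # e.g. {"fatigue": ..., "model_version": ...}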
load_dataset.py
ADDED
@@ -0,0 +1,196 @@
import json
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

import pandas as pd
from datasets import get_dataset_config_names, get_dataset_split_names, load_dataset
from huggingface_hub import hf_hub_download

DEFAULT_DATASET_ID = "Merry99/MuscleCare-DataSet"
DEFAULT_DATASET_SPLITS = [
    "local_user",
    "ios_D7ED673185E248BD9DC1102E881E9111",
    "android_SP1A.210812.016",
] + [f"user_{i:03d}" for i in range(1, 51)]


def download_parquet_from_hub(
    repo_id: str,
    filenames: Iterable[str],
    local_dir: str = "./data",
    repo_type: str = "dataset",
    token: Optional[str] = None,
) -> List[Path]:
    """
    (Optional) Download parquet files from the Hugging Face Hub and save them locally.
    Used when needed to mirror the Space environment.
    """
    target_dir = Path(local_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    downloaded: List[Path] = []
    for name in filenames:
        local_path = Path(
            hf_hub_download(
                repo_id=repo_id,
                filename=name,
                repo_type=repo_type,
                token=token,
                local_dir=target_dir,
                local_dir_use_symlinks=False,
            )
        )
        downloaded.append(local_path)
    return downloaded


def resolve_parquet_files(data_dir: str = "./data", pattern: str = "user*.parquet") -> List[Path]:
    """
    Return the list of parquet files in the data directory, sorted.
    """
    data_path = Path(data_dir)
    if not data_path.exists():
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    parquet_files = sorted(data_path.glob(pattern))
    if not parquet_files:
        raise FileNotFoundError(f"No parquet files match the pattern ({pattern}).")
    return parquet_files


def parse_user_embedding(raw_emb, fallback_dim: int = 12) -> List[float]:
    """
    Convert a user_emb given as a string or list into a fixed-length list.
    """
    if isinstance(raw_emb, str):
        try:
            raw_emb = json.loads(raw_emb)
        except json.JSONDecodeError:
            raw_emb = []

    if isinstance(raw_emb, (list, tuple)):
        values = list(raw_emb)
    else:
        values = []

    if not values:
        values = [0.0] * fallback_dim

    if len(values) < fallback_dim:
        values = values + [0.0] * (fallback_dim - len(values))
    else:
        values = values[:fallback_dim]

    return [float(v) for v in values]


def normalize_user_embeddings(df: pd.DataFrame, emb_dim: int) -> pd.DataFrame:
    if "user_emb" not in df.columns:
        raise KeyError("The dataset has no 'user_emb' column.")
    df = df.copy()
    df["user_emb"] = df["user_emb"].apply(lambda v: parse_user_embedding(v, emb_dim))
    return df


def _resolve_config_name(repo_id: str) -> Optional[str]:
    try:
        configs = get_dataset_config_names(repo_id)
        if configs:
            return configs[0]
    except Exception:
        pass
    return None


def _load_split_dataframe(
    repo_id: str,
    split_name: str,
    cache_dir: str,
    config_name: Optional[str],
) -> Optional[pd.DataFrame]:
    load_kwargs = {
        "path": repo_id,
        "split": split_name,
        "cache_dir": cache_dir,
    }
    if config_name:
        load_kwargs["name"] = config_name
    try:
        ds = load_dataset(**load_kwargs)
    except ValueError as exc:
        print(f"⚠️ Skipping split '{split_name}': {exc}")
        return None

    return ds.to_pandas()


def load_dataset_from_hub(
    repo_id: Optional[str] = None,
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    emb_dim: int = 12,
    exclude_sessions: Optional[Iterable[str]] = None,
) -> Tuple[pd.DataFrame, List[str]]:
    """
    Load data from a Hugging Face Dataset into a DataFrame.
    Rows whose session_id appears in exclude_sessions are dropped.
    """
    repo_id = repo_id or DEFAULT_DATASET_ID
    cache_dir = cache_dir or os.getenv("HF_DATASET_CACHE_DIR", "./data/hf_cache")

    config_name = _resolve_config_name(repo_id)

    if split:
        split_names = [split]
    else:
        try:
            split_names = get_dataset_split_names(repo_id, config_name)
        except Exception:
            split_names = DEFAULT_DATASET_SPLITS

    frames: List[pd.DataFrame] = []
    for split_name in split_names:
        df_part = _load_split_dataframe(
            repo_id=repo_id,
            split_name=split_name,
            cache_dir=cache_dir,
            config_name=config_name,
        )
        if df_part is not None and not df_part.empty:
            frames.append(df_part)

    if not frames:
        raise ValueError("NO_DATA_AVAILABLE")

    df = pd.concat(frames, ignore_index=True)
    if "session_id" not in df.columns:
        raise KeyError("The dataset has no 'session_id' column.")

    exclude_set = set(str(s) for s in (exclude_sessions or []))
    if exclude_set:
        df = df[~df["session_id"].astype(str).isin(exclude_set)]

    session_ids = sorted(df["session_id"].dropna().astype(str).unique().tolist())
    df = normalize_user_embeddings(df, emb_dim)
    return df, session_ids


def load_parquet_dataset(
    data_dir: str = "./data",
    pattern: str = "user*.parquet",
    emb_dim: int = 12,
) -> pd.DataFrame:
    """
    If no data exists locally, automatically load it from the Hugging Face Dataset.
    """
    try:
        parquet_files = resolve_parquet_files(data_dir, pattern)
        frames = [pd.read_parquet(path) for path in parquet_files]
        data = pd.concat(frames, ignore_index=True)
        return normalize_user_embeddings(data, emb_dim)
    except FileNotFoundError:
        # No local data: load directly from the HF Dataset.
        print("⚠️ No local data; loading from the Hugging Face Dataset instead.")
        df, _ = load_dataset_from_hub(emb_dim=emb_dim)
        return df
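
A quick illustration of how parse_user_embedding normalizes inputs, following the logic above (the inputs are hypothetical):

from load_dataset import parse_user_embedding

# A JSON string shorter than 12 values is parsed and zero-padded.
print(parse_user_embedding("[0.1, 0.2, 0.3]"))
# -> [0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# A list longer than 12 values is truncated; malformed input falls back to zeros.
print(parse_user_embedding(list(range(15)))[:3])   # -> [0.0, 1.0, 2.0]
print(parse_user_embedding("not json")[:3])        # -> [0.0, 0.0, 0.0]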
requirements.txt
ADDED
@@ -0,0 +1,10 @@
fastapi==0.115.5
uvicorn[standard]==0.32.0
schedule==1.2.2
huggingface_hub==0.25.2
datasets==2.19.1
pandas==2.1.4
numpy==1.24.4
pyarrow==14.0.1
lightgbm==4.3.0
scikit-learn==1.3.2
run_local.sh
ADDED
@@ -0,0 +1,39 @@
#!/bin/bash

echo "=== MuscleCare Train Hybrid local run ==="

# Check the Python version
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "Python version: $python_version"

# Check the required version
required_version="3.9"
if [[ "$(printf '%s\n' "$required_version" "$python_version" | sort -V | head -n1)" != "$required_version" ]]; then
    echo "❌ Python $required_version or newer is required. Current: $python_version"
    exit 1
fi

echo "✅ Python version check passed"

# Check the virtual environment
if [[ -z "$VIRTUAL_ENV" ]]; then
    echo "⚠️ No virtual environment is active."
    echo "   Activate one with: source .venv/bin/activate"
fi

# Check that dependencies are installed
echo "Checking installed dependencies..."
python3 -c "import fastapi, uvicorn, lightgbm, pandas, datasets; print('✅ All dependencies are installed.')" 2>/dev/null
if [[ $? -ne 0 ]]; then
    echo "❌ Dependencies are not installed."
    echo "   Install them with: pip install -r requirements.txt"
    exit 1
fi

echo ""
echo "=== Starting the scheduler ==="
echo "Automatic training runs every Sunday at 00:00."
echo "Press Ctrl+C to stop."
echo ""

python3 start.py
start.py
ADDED
@@ -0,0 +1,9 @@
#!/usr/bin/env python3
"""
Local scheduler entry point.
"""

from train_scheduler import main

if __name__ == "__main__":
    main()
train_hybrid.py
ADDED
@@ -0,0 +1,170 @@
"""
LightGBM-based muscle fatigue estimation pipeline
- Loads the Hugging Face Dataset
- Builds features (α/β correction values + user_emb)
- Trains and evaluates a LightGBM model
"""

import os
import argparse
import json
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split

from load_dataset import DEFAULT_DATASET_ID, load_dataset_from_hub


EMB_DIM = 12
FEATURES = ["rms_ratio", "freq_ratio"]
EMB_COLS = [f"useremb{i+1}" for i in range(EMB_DIM)]


def build_features(df: pd.DataFrame) -> pd.DataFrame:
    required = [
        "rms_acc",
        "rms_gyro",
        "mean_freq_acc",
        "mean_freq_gyro",
        "rms_base",
        "freq_base",
        "fatigue",
    ]
    missing = set(required) - set(df.columns)
    if missing:
        raise KeyError(f"Missing columns: {sorted(missing)}")

    data = df.copy()
    data["rms_ratio"] = (
        (data["rms_acc"] + data["rms_gyro"]) / 2.0
    ) / data["rms_base"].replace(0, np.finfo(float).eps)
    freq_mean = (data["mean_freq_acc"] + data["mean_freq_gyro"]) / 2.0
    data["freq_ratio"] = data["freq_base"] / freq_mean.replace(
        0, np.finfo(float).eps
    )

    if "user_emb" not in data.columns:
        raise KeyError("The data must contain a user_emb column.")
    data[EMB_COLS] = pd.DataFrame(
        data["user_emb"].tolist(), index=data.index
    )
    return data


def train_lightgbm(
    data: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42,
) -> Dict[str, Any]:
    train_cols = FEATURES + EMB_COLS
    X = data[train_cols]
    y = data["fatigue"]

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    lgb_train = lgb.Dataset(X_train, label=y_train)
    lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train)

    params = {
        "objective": "regression",
        "metric": "rmse",
        "learning_rate": 0.1,
        "num_leaves": 31,
        "verbose": -1,
    }

    callbacks = [lgb.early_stopping(stopping_rounds=10, verbose=True)]
    model = lgb.train(
        params,
        lgb_train,
        valid_sets=[lgb_train, lgb_val],
        num_boost_round=100,
        callbacks=callbacks,
    )

    y_pred = model.predict(X_val, num_iteration=model.best_iteration)
    rmse = np.sqrt(mean_squared_error(y_val, y_pred))
    mae = mean_absolute_error(y_val, y_pred)

    print(f"RMSE: {rmse:.6f}")
    print(f"MAE : {mae:.6f}")

    importance = pd.DataFrame(
        {
            "feature": train_cols,
            "importance": model.feature_importance(),
        }
    ).sort_values(by="importance", ascending=False)
    print("\nFeature Importance:")
    print(importance.to_string(index=False))

    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)
    booster_path = models_dir / "lightgbm_model.txt"
    model.save_model(str(booster_path))
    print(f"\n✅ LightGBM model saved: {booster_path}")

    metadata = {
        # Cast from numpy scalars so json.dumps can serialize them.
        "rmse": float(rmse),
        "mae": float(mae),
        "feature_importance": importance.to_dict(orient="records"),
        "model_path": str(booster_path),
        "artifact_type": "lightgbm",
        "sample_count": len(data),
    }

    metadata_path = models_dir / "training_metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False))
    print(f"ℹ️ Metadata saved: {metadata_path}")

    return metadata


def main(
    data_dir: str = "./data",
    pattern: str = "user*.parquet",
    emb_dim: int = EMB_DIM,
    exclude_sessions: Optional[Iterable[str]] = None,
    repo_id: Optional[str] = None,
    split: Optional[str] = None,
) -> Dict[str, Any]:
    print("=" * 80)
    print("MuscleCare LightGBM Trainer")
    print("=" * 80)

    resolved_repo = repo_id or os.getenv("HF_DATASET_REPO_ID", DEFAULT_DATASET_ID)
    env_split = os.getenv("HF_DATASET_SPLIT")
    resolved_split = split if split is not None else env_split

    df, session_ids = load_dataset_from_hub(
        repo_id=resolved_repo,
        split=resolved_split,
        emb_dim=emb_dim,
        exclude_sessions=exclude_sessions,
    )

    if df.empty:
        raise ValueError("NO_DATA_AVAILABLE")
    df = build_features(df)
    result = train_lightgbm(df)
    result["session_ids"] = session_ids
    result["session_count"] = len(session_ids)
    result["dataset_repo"] = resolved_repo
    result["dataset_split"] = resolved_split or "ALL"
    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default="./data")
    parser.add_argument("--pattern", default="user*.parquet")
    parser.add_argument("--emb-dim", type=int, default=EMB_DIM)
    args = parser.parse_args()
    main(args.data_dir, args.pattern, args.emb_dim)
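
After a training run, the saved booster can be reloaded for a quick offline check; a small sketch, assuming models/lightgbm_model.txt exists and using the column order FEATURES + EMB_COLS defined above:

import numpy as np
import lightgbm as lgb

booster = lgb.Booster(model_file="models/lightgbm_model.txt")

# One row: [rms_ratio, freq_ratio, useremb1..useremb12] - hypothetical values.
row = np.asarray([[1.05, 0.98, *([0.0] * 12)]], dtype=np.float32)
print(float(booster.predict(row)[0]))  # predicted fatigue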
train_scheduler.py
ADDED
@@ -0,0 +1,243 @@
"""
LightGBM model training scheduler
- Runs train_hybrid.py on a fixed schedule
- Manages training state and version metadata
"""

import json
import os
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import schedule

from train_hybrid import main as train_main


class TrainingScheduler:
    def __init__(
        self,
        data_dir: str = "./data",
        pattern: str = "user*.parquet",
        schedule_time: str = "00:00",
        state_file: str = "./logs/training_state.json",
        versions_file: str = "./logs/model_versions.json",
    ):
        self.data_dir = data_dir
        self.pattern = pattern
        self.schedule_time = schedule_time
        self.state_path = Path(state_file)
        self.versions_path = Path(versions_file)
        self.logs_dir = self.state_path.parent
        self.logs_dir.mkdir(parents=True, exist_ok=True)
        self.models_dir = Path("models")
        self.models_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ #
    # State helpers
    # ------------------------------------------------------------------ #
    def _default_state(self) -> Dict[str, Any]:
        return {
            "last_training": None,
            "model_version": 0,
            "last_model_path": None,
            "processed_sessions": [],
        }

    def load_training_state(self) -> Dict[str, Any]:
        if self.state_path.exists():
            state = json.loads(self.state_path.read_text(encoding="utf-8"))
            state.setdefault("processed_sessions", [])
            return state
        return self._default_state()

    def save_training_state(self, state: Dict) -> None:
        self.state_path.write_text(json.dumps(state, indent=2, ensure_ascii=False), encoding="utf-8")

    def reset_training_state(self) -> Dict:
        state = self._default_state()
        self.save_training_state(state)
        if self.versions_path.exists():
            self.versions_path.unlink()
        return state

    # ------------------------------------------------------------------ #
    # Version helpers
    # ------------------------------------------------------------------ #
    def _load_versions(self) -> List[Dict]:
        if self.versions_path.exists():
            return json.loads(self.versions_path.read_text(encoding="utf-8"))
        return []

    def _save_versions(self, manifest: List[Dict]) -> None:
        self.versions_path.write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8")

    def record_version(self, version: int, source_path: str, timestamp: str, metadata: Dict[str, Any]) -> str:
        source = Path(source_path)
        if not source.exists():
            return source_path

        versioned = self.models_dir / f"{source.stem}_v{version}{source.suffix}"
        shutil.copy2(source, versioned)

        manifest = self._load_versions()
        manifest.append(
            {
                "version": version,
                "filename": versioned.name,
                "path": str(versioned),
                "timestamp": timestamp,
                "metrics": {
                    "rmse": metadata.get("rmse"),
                    "mae": metadata.get("mae"),
                },
                "sample_count": metadata.get("sample_count"),
                "session_count": metadata.get("session_count"),
                "dataset": {
                    "repo_id": metadata.get("dataset_repo"),
                    "split": metadata.get("dataset_split"),
                },
            }
        )

        # Rotate manifest and delete old artifacts
        max_versions = int(os.getenv("MAX_MODEL_VERSIONS", "2"))
        to_remove = []
        if len(manifest) > max_versions:
            to_remove = manifest[:-max_versions]
            manifest = manifest[-max_versions:]
        for old_entry in to_remove:
            old_path = Path(old_entry["path"])
            if old_path.exists():
                old_path.unlink()
        self._save_versions(manifest)

        return str(versioned)

    def get_model_versions(self) -> List[Dict]:
        return self._load_versions()

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def run_scheduled_training(self) -> Dict[str, Any]:
        print("=" * 80)
        print(f"[TrainingScheduler] Training started - {datetime.utcnow().isoformat()}")
        print("=" * 80)

        try:
            state = self.load_training_state()
            processed_sessions = set(state.get("processed_sessions", []))
        except Exception as exc:
            print(f"[TrainingScheduler] Failed to load state: {exc}")
            return {
                "status": "failed",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": 0,
                "message": f"State load failed: {exc}",
            }

        try:
            metadata = train_main(
                self.data_dir,
                self.pattern,
                exclude_sessions=processed_sessions,
            )
        except FileNotFoundError as exc:
            print(f"[TrainingScheduler] Missing data: {exc}")
            return {
                "status": "skipped",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": state.get("model_version", 0),
                "message": str(exc),
            }
        except ValueError as exc:
            if "NO_DATA_AVAILABLE" in str(exc):
                print("[TrainingScheduler] No new sessions; skipping training.")
                return {
                    "status": "skipped",
                    "new_data_count": 0,
                    "new_session_count": 0,
                    "model_path": None,
                    "model_version": state.get("model_version", 0),
                    "message": "No new sessions to train.",
                }
            print(f"[TrainingScheduler] Data processing error: {exc}")
            return {
                "status": "failed",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": state.get("model_version", 0),
                "message": f"Data processing error: {exc}",
            }
        except Exception as exc:
            print(f"[TrainingScheduler] Training failed: {exc}")
            return {
                "status": "failed",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": state.get("model_version", 0),
                "message": str(exc),
            }

        new_version = state.get("model_version", 0) + 1
        timestamp = datetime.utcnow().isoformat()

        model_artifact = metadata.get("model_path")
        if not model_artifact:
            raise ValueError("MODEL_ARTIFACT_MISSING")
        versioned_path = self.record_version(new_version, model_artifact, timestamp, metadata)

        used_sessions = metadata.get("session_ids", [])
        new_sessions = [s for s in used_sessions if s not in processed_sessions]
        processed_sessions.update(new_sessions)

        state.update(
            {
                "last_training": timestamp,
                "model_version": new_version,
                "last_model_path": versioned_path,
                "processed_sessions": sorted(processed_sessions),
            }
        )
        self.save_training_state(state)

        print(f"[TrainingScheduler] ✅ Training complete - version {new_version}, samples {metadata.get('sample_count', 0)}")

        return {
            "status": "trained",
            "new_data_count": metadata.get("sample_count", 0),
            "model_path": versioned_path,
            "model_version": new_version,
            "metadata": metadata,
            "new_session_count": len(new_sessions),
        }

    def trigger_training(self) -> Dict[str, Any]:
        return self.run_scheduled_training()


def main():
    scheduler = TrainingScheduler()
    schedule.clear()
    schedule.every().sunday.at(scheduler.schedule_time).do(scheduler.run_scheduled_training)
    print(f"[TrainingScheduler] Automatic training is scheduled every Sunday at {scheduler.schedule_time}.")
    try:
        while True:
            schedule.run_pending()
            time.sleep(60)
    except KeyboardInterrupt:
        print("[TrainingScheduler] Scheduler stopped")


if __name__ == "__main__":
    main()
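
The scheduler can also be driven manually instead of waiting for the Sunday job; for example:

from train_scheduler import TrainingScheduler

scheduler = TrainingScheduler()
result = scheduler.trigger_training()  # runs the same code path as the weekly job
print(result["status"], result.get("model_version"))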