Merry99 committed on
Commit
9a12dde
Β·
1 Parent(s): 3fe8345

Restore: Space paused fixes - memory optimization and error handling

Browse files

- Added proper .gitignore to exclude venv and cache files
- Fixed memory leak issues with model caching
- Improved error handling in startup and training
- Added lazy loading for models to prevent startup failures

Files changed (9) hide show
  1. .gitignore +83 -0
  2. Dockerfile +20 -0
  3. app.py +351 -0
  4. load_dataset.py +196 -0
  5. requirements.txt +10 -0
  6. run_local.sh +39 -0
  7. start.py +9 -0
  8. train_hybrid.py +170 -0
  9. train_scheduler.py +243 -0
.gitignore ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+
7
+ # Virtual environments
8
+ .venv/
9
+ venv/
10
+ ENV/
11
+ env/
12
+
13
+ # IDE
14
+ .vscode/
15
+ .idea/
16
+ *.swp
17
+ *.swo
18
+
19
+ # OS
20
+ .DS_Store
21
+ Thumbs.db
22
+
23
+ # Logs
24
+ logs/
25
+ *.log
26
+
27
+ # Models (keep only in git if needed)
28
+ models/
29
+ !models/.gitkeep
30
+
31
+ # Data cache
32
+ data/
33
+ !data/.gitkeep
34
+
35
+ # Temporary files
36
+ *.tmp
37
+ *.temp
38
+
39
+ # Jupyter Notebook
40
+ .ipynb_checkpoints
41
+
42
+ # Distribution / packaging
43
+ .Python
44
+ build/
45
+ develop-eggs/
46
+ dist/
47
+ downloads/
48
+ eggs/
49
+ .eggs/
50
+ lib/
51
+ lib64/
52
+ parts/
53
+ sdist/
54
+ var/
55
+ wheels/
56
+ pip-wheel-metadata/
57
+ share/python-wheels/
58
+ *.egg-info/
59
+ .installed.cfg
60
+ *.egg
61
+ MANIFEST
62
+
63
+ # Unit test / coverage reports
64
+ htmlcov/
65
+ .tox/
66
+ .nox/
67
+ .coverage
68
+ .coverage.*
69
+ .cache
70
+ nosetests.xml
71
+ coverage.xml
72
+ *.cover
73
+ *.py,cover
74
+ .hypothesis/
75
+ .pytest_cache/
76
+
77
+ # Environment variables
78
+ .env
79
+ .env.local
80
+ .env.production
81
+
82
+ # Hugging Face
83
+ hf_cache/
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ ENV PYTHONUNBUFFERED=1 \
6
+ PIP_NO_CACHE_DIR=1
7
+
8
+ COPY requirements.txt .
9
+ RUN apt-get update && apt-get install -y --no-install-recommends build-essential libgomp1 && \
10
+ pip install --upgrade pip && \
11
+ pip install -r requirements.txt && \
12
+ apt-get purge -y build-essential && \
13
+ apt-get autoremove -y && \
14
+ rm -rf /var/lib/apt/lists/*
15
+
16
+ COPY . .
17
+
18
+ EXPOSE 7860
19
+
20
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI μ•±: μˆ˜λ™ ν•™μŠ΅ 및 λͺ¨λΈ λ‹€μš΄λ‘œλ“œ/μ—…λ‘œλ“œ"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import threading
7
+ import time
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ import schedule
12
+ import lightgbm as lgb
13
+ import numpy as np
14
+ from fastapi import FastAPI, HTTPException
15
+ from fastapi.responses import FileResponse
16
+ from huggingface_hub import HfApi
17
+ from pydantic import BaseModel, field_validator
18
+
19
+ from train_scheduler import TrainingScheduler
20
+
21
+
22
app = FastAPI(
    title="MuscleCare LightGBM Scheduler",
    description="MuscleCare-Train-AI Space와 λ™μΌν•œ APIλ₯Ό LightGBM λͺ¨λΈλ‘œ μ œκ³΅ν•©λ‹ˆλ‹€.",
)

# Single scheduler instance shared by the background job and the endpoints.
_scheduler = TrainingScheduler()

# In-process model cache. All four variables below are guarded by _model_lock.
_model_lock = threading.Lock()
_current_model: Optional[lgb.Booster] = None
_current_model_path: Optional[str] = None
_current_model_version: Optional[int] = None
_model_cache_timestamp: Optional[float] = None  # time.time() of the last load
MODEL_CACHE_TIMEOUT = 3600  # 1 hour; after this the cached booster is dropped
35
+
36
+
37
class TrainResponse(BaseModel):
    """Response payload for the POST /trigger and POST /train endpoints."""

    status: str  # e.g. "trained"; any other value means the run was skipped
    new_data_count: int
    model_path: Optional[str] = None  # on-disk path of the trained artifact
    hub_url: Optional[str] = None  # set only when upload=True and push succeeded
    model_version: Optional[int] = None
    message: str  # human-readable outcome summary
    new_session_count: Optional[int] = None
45
+
46
+
47
class ResetStateResponse(BaseModel):
    """Response payload for POST /state/reset."""

    status: str  # always "reset" (see the reset_state endpoint)
    state: Dict[str, Any]  # scheduler training state after the reset
50
+
51
+
52
class PredictRequest(BaseModel):
    """Input features for POST /predict.

    The six scalar fields are sensor statistics plus per-user baselines;
    user_emb is a fixed 12-dimensional user embedding vector.
    """

    rms_acc: float
    rms_gyro: float
    mean_freq_acc: float
    mean_freq_gyro: float
    rms_base: float  # baseline RMS; guarded against zero in _build_feature_vector
    freq_base: float
    user_emb: List[float]

    @field_validator("user_emb")
    @classmethod
    def validate_user_emb(cls, v: List[float]) -> List[float]:
        # Reject embeddings that are not exactly 12-dimensional.
        if len(v) != 12:
            raise ValueError("user_emb must contain exactly 12 values.")
        return v
67
+
68
+
69
class PredictResponse(BaseModel):
    """Response payload for POST /predict."""

    fatigue: float  # model output for the submitted sample
    model_version: Optional[int]  # version of the booster used, if known
72
+
73
+
74
def _schedule_background_job() -> None:
    """Register the weekly training job and start a daemon polling thread."""
    schedule.clear()
    weekly_job = schedule.every().sunday.at(_scheduler.schedule_time)
    weekly_job.do(_scheduler.run_scheduled_training)

    def _poll_forever() -> None:
        # Check the schedule once a minute, forever.
        while True:
            schedule.run_pending()
            time.sleep(60)

    worker = threading.Thread(target=_poll_forever, daemon=True)
    worker.start()
84
+
85
+
86
def _apply_training_result(result: Dict[str, Any]) -> None:
    """Swap in the freshly trained model, if the run actually produced one."""
    if result.get("status") != "trained":
        return

    trained_path = result.get("model_path")
    if not trained_path:
        print("[Model] ν•™μŠ΅ 결과에 model_pathκ°€ μ—†μ–΄ λͺ¨λΈμ„ λ‘œλ“œν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.")
        return

    try:
        _load_model_from_path(Path(trained_path), result.get("model_version"))
    except Exception as err:  # best-effort: keep serving the previous model
        print(f"[Model] μƒˆ λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨: {err}")
97
+
98
+
99
def _load_model_from_path(path: Path, version: Optional[int] = None) -> None:
    """Load a LightGBM booster from disk and publish it as the active model.

    Raises FileNotFoundError when the file does not exist.
    """
    global _current_model, _current_model_path, _current_model_version, _model_cache_timestamp

    if not path.exists():
        raise FileNotFoundError(f"λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€: {path}")

    # Parse the model file outside the lock so the critical section stays short.
    loaded_booster = lgb.Booster(model_file=str(path))
    with _model_lock:
        _current_model = loaded_booster
        _current_model_path = str(path)
        _current_model_version = version
        _model_cache_timestamp = time.time()
    print(f"[Model] Loaded LightGBM model from {path} (version={version})")
110
+
111
+
112
def _get_cached_model() -> Optional[lgb.Booster]:
    """Return the cached model, or None when nothing is cached or the cache expired.

    Bug fix: the original assigned ``_current_model = None`` without a
    ``global`` declaration, which made ``_current_model`` local to this whole
    function and raised UnboundLocalError on the very first read
    (``if _current_model is None``). Declaring it global restores the
    intended module-level cache behavior.
    """
    global _current_model
    with _model_lock:
        if _current_model is None:
            return None
        if _model_cache_timestamp is None:
            return None
        if time.time() - _model_cache_timestamp > MODEL_CACHE_TIMEOUT:
            print("[Model] λͺ¨λΈ μΊμ‹œ 만료, μž¬λ‘œλ“œ ν•„μš”")
            # Drop the stale booster so it can be garbage-collected (memory fix).
            _current_model = None
            return None
        return _current_model
124
+
125
+
126
def _maybe_load_latest_model() -> None:
    """Try to load the newest model at startup; never raise, only log."""
    try:
        manifest = _scheduler.get_model_versions()
        newest = manifest[-1] if manifest else None

        chosen_path: Optional[Path] = None
        chosen_version: Optional[int] = None
        if newest:
            chosen_path = Path(newest["path"])
            chosen_version = newest.get("version")
        else:
            # No manifest entries — fall back to the default on-disk model.
            fallback = Path("models/lightgbm_model.txt")
            if fallback.exists():
                chosen_path = fallback

        if chosen_path is None or not chosen_path.exists():
            print("[Model] λ‘œλ“œν•  λͺ¨λΈμ΄ 아직 μ—†μŠ΅λ‹ˆλ‹€.")
            return

        try:
            _load_model_from_path(chosen_path, chosen_version)
            print(f"[Model] λͺ¨λΈ λ‘œλ“œ 성곡: {chosen_path}")
        except Exception as exc:
            print(f"[Model] λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨ (계속 μ§„ν–‰): {exc}")
    except Exception as exc:
        print(f"[Model] λͺ¨λΈ λ‘œλ“œ κ³Όμ •μ—μ„œ μ˜ˆμ™Έ λ°œμƒ: {exc}")
151
+
152
+
153
def _get_active_model() -> Tuple[lgb.Booster, Optional[int]]:
    """Return the active booster and its version, lazily loading it if needed.

    Raises HTTPException(503) when no model can be loaded.

    Bug fixes vs. the original:
    - ``version`` was never assigned on the fallback (no manifest entry)
      branch, so ``_load_model_from_path(path, version)`` raised NameError;
      it is now initialised to None.
    - the deliberate 503 for a missing file was caught by the broad
      ``except Exception`` and re-wrapped into a confusing message;
      HTTPException is now re-raised unchanged.
    """
    # Fast path: serve the cached model when it is still valid.
    cached_model = _get_cached_model()
    if cached_model is not None:
        return cached_model, _current_model_version

    # Cache miss: load the newest model from the version manifest,
    # falling back to the default on-disk model file.
    try:
        manifest = _scheduler.get_model_versions()
        target_entry = manifest[-1] if manifest else None

        version: Optional[int] = None
        if target_entry:
            path = Path(target_entry["path"])
            version = target_entry.get("version")
        else:
            path = Path("models/lightgbm_model.txt")

        if not path.exists():
            raise HTTPException(status_code=503, detail="λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")

        _load_model_from_path(path, version)
        return _current_model, _current_model_version
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=503, detail=f"λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨: {exc}")
177
+
178
+
179
+ def _build_feature_vector(payload: PredictRequest) -> np.ndarray:
180
+ rms_base = payload.rms_base if payload.rms_base != 0 else 1e-6
181
+ freq_mean = (payload.mean_freq_acc + payload.mean_freq_gyro) / 2.0
182
+ if freq_mean == 0:
183
+ freq_mean = 1e-6
184
+
185
+ rms_ratio = ((payload.rms_acc + payload.rms_gyro) / 2.0) / rms_base
186
+ freq_ratio = payload.freq_base / freq_mean
187
+
188
+ feature_vector = [rms_ratio, freq_ratio, *payload.user_emb]
189
+ return np.asarray([feature_vector], dtype=np.float32)
190
+
191
+
192
@app.on_event("startup")
def on_startup() -> None:
    """Initialise the background training scheduler when the app boots.

    Model loading is deferred to the first request that needs it, so a bad
    model file cannot prevent the Space from starting.
    """
    print("[Startup] MuscleCare Space μ‹œμž‘ 쀑...")
    try:
        _schedule_background_job()
    except Exception as err:
        print(f"[Startup] μŠ€μΌ€μ€„λŸ¬ μ΄ˆκΈ°ν™” μ‹€νŒ¨ (계속 μ§„ν–‰): {err}")
    else:
        print("[Startup] μŠ€μΌ€μ€„λŸ¬ μ΄ˆκΈ°ν™” μ™„λ£Œ")

    # Lazy loading: the model is loaded on demand, not at startup.
    print("[Startup] λͺ¨λΈμ€ ν•„μš” μ‹œμ μ— λ‘œλ“œλ©λ‹ˆλ‹€ (lazy loading)")
    print("[Startup] MuscleCare Space μ‹œμž‘ μ™„λ£Œ")
204
+
205
+
206
@app.get("/health")
def health_check() -> dict:
    """Liveness probe endpoint."""
    return dict(status="ok")
209
+
210
+
211
@app.get("/")
def root() -> dict:
    """Describe the API surface for anyone hitting the base URL."""
    endpoints = {
        "trigger": "/trigger",
        "model": "/model",
        "state_reset": "/state/reset",
    }
    return {
        "message": "MuscleCare LightGBM Scheduler API",
        "docs": "/docs",
        "endpoints": endpoints,
    }
222
+
223
+
224
def _upload_to_hub(model_path: str) -> Optional[str]:
    """Push the model (and manifest, if present) to the Hugging Face Hub.

    Returns the repo URL, or None when the required env vars are unset.
    Raises HTTPException(404) when the model file is missing.
    """
    token = os.getenv("HF_HYBRID_MODEL_TOKEN")
    repo_id = os.getenv("HF_HYBRID_MODEL_REPO_ID")
    if not (token and repo_id):
        # Upload is opt-in via environment configuration.
        return None

    artifact = Path(model_path)
    if not artifact.exists():
        raise HTTPException(status_code=404, detail=f"λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€: {model_path}")

    hf = HfApi(token=token)
    hf.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
    hf.upload_file(
        path_or_fileobj=artifact,
        path_in_repo=artifact.name,
        repo_id=repo_id,
        repo_type="model",
        commit_message=f"LightGBM model upload ({artifact.name})",
    )

    # Keep the version manifest in the repo alongside the model, if we have one.
    manifest_file = Path("logs/model_versions.json")
    if manifest_file.exists():
        hf.upload_file(
            path_or_fileobj=str(manifest_file),
            path_in_repo="model_versions.json",
            repo_id=repo_id,
            repo_type="model",
            commit_message="Update model manifest",
        )

    return f"https://huggingface.co/{repo_id}"
256
+
257
+
258
def _resolve_model_entry(version: Optional[int] = None) -> Dict[str, Any]:
    """Pick a manifest entry: the latest when version is None, else an exact match.

    Raises HTTPException(404) when no model exists or the version is unknown.
    """
    manifest = _scheduler.get_model_versions()
    if not manifest:
        raise HTTPException(status_code=404, detail="아직 ν•™μŠ΅λœ λͺ¨λΈμ΄ μ—†μŠ΅λ‹ˆλ‹€.")

    if version is None:
        return manifest[-1]

    matched = next((e for e in manifest if e.get("version") == version), None)
    if matched is not None:
        return matched

    raise HTTPException(
        status_code=404,
        detail=f"버전 {version} λͺ¨λΈμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.",
    )
274
+
275
+
276
@app.get("/model")
@app.get("/model/{version:int}")
def download_model(version: Optional[int] = None) -> FileResponse:
    """Stream a saved model file; the latest when no version is given."""
    entry = _resolve_model_entry(version)
    artifact = Path(entry["path"])
    if not artifact.exists():
        raise HTTPException(status_code=404, detail="λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")

    file_response = FileResponse(
        path=artifact,
        filename=entry["filename"],
        media_type="application/octet-stream",
    )
    # Expose the served version so clients know exactly what they fetched.
    file_response.headers["X-Model-Version"] = str(entry["version"])
    return file_response
291
+
292
+
293
@app.get("/download")
def download_latest_alias() -> FileResponse:
    """Alias for GET /model (latest version)."""
    return download_model(version=None)
296
+
297
+
298
@app.post("/state/reset", response_model=ResetStateResponse)
def reset_state() -> ResetStateResponse:
    """Clear the scheduler's persisted training state and return it."""
    cleared = _scheduler.reset_training_state()
    return ResetStateResponse(status="reset", state=cleared)
302
+
303
+
304
@app.post("/trigger", response_model=TrainResponse)
def trigger_training(upload: bool = False) -> TrainResponse:
    """Run one training cycle now; optionally push the artifact to the Hub."""
    try:
        outcome = _scheduler.run_scheduled_training()
    except Exception as exc:  # pragma: no cover
        raise HTTPException(status_code=500, detail=f"ν•™μŠ΅ μ‹€ν–‰ 였λ₯˜: {exc}") from exc

    did_train = outcome["status"] == "trained"
    message = "λͺ¨λΈ ν•™μŠ΅μ΄ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€." if did_train else "ν•™μŠ΅μ΄ κ±΄λ„ˆλ›°μ–΄μ‘ŒμŠ΅λ‹ˆλ‹€."
    hub_url = None
    trained_version = outcome.get("model_version")
    trained_path = outcome.get("model_path")

    if upload and trained_path and did_train:
        try:
            hub_url = _upload_to_hub(trained_path)
            message = "λͺ¨λΈ ν•™μŠ΅ 및 Hugging Face μ—…λ‘œλ“œκ°€ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€."
        except HTTPException:
            # Preserve the specific status raised by _upload_to_hub.
            raise
        except Exception as exc:  # pragma: no cover
            raise HTTPException(status_code=500, detail=f"Hugging Face μ—…λ‘œλ“œ μ‹€νŒ¨: {exc}") from exc

    # Make the freshly trained booster the active in-process model.
    _apply_training_result(outcome)

    return TrainResponse(
        status=outcome["status"],
        new_data_count=outcome.get("new_data_count", 0),
        model_path=trained_path,
        hub_url=hub_url,
        model_version=trained_version,
        message=message,
        new_session_count=outcome.get("new_session_count"),
    )
336
+
337
+
338
@app.post("/train", response_model=TrainResponse)
def trigger_training_alias(upload: bool = False) -> TrainResponse:
    """Alias for POST /trigger."""
    return trigger_training(upload=upload)
341
+
342
+
343
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Predict fatigue for one sample with the active LightGBM model."""
    model, active_version = _get_active_model()
    row = _build_feature_vector(payload)
    fatigue_value = float(model.predict(row)[0])
    return PredictResponse(fatigue=fatigue_value, model_version=active_version)
349
+
350
+
351
+ __all__ = ["app"]
load_dataset.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Iterable, List, Optional, Tuple
5
+
6
+ import pandas as pd
7
+ from datasets import get_dataset_config_names, get_dataset_split_names, load_dataset
8
+ from huggingface_hub import hf_hub_download
9
+
10
DEFAULT_DATASET_ID = "Merry99/MuscleCare-DataSet"
# Fallback split names used when the Hub's split listing cannot be fetched:
# a few known device splits plus user_001..user_050.
DEFAULT_DATASET_SPLITS = [
    "local_user",
    "ios_D7ED673185E248BD9DC1102E881E9111",
    "android_SP1A.210812.016",
] + [f"user_{i:03d}" for i in range(1, 51)]
16
+
17
+
18
def download_parquet_from_hub(
    repo_id: str,
    filenames: Iterable[str],
    local_dir: str = "./data",
    repo_type: str = "dataset",
    token: Optional[str] = None,
) -> List[Path]:
    """(Optional) Fetch parquet files from the Hugging Face Hub into *local_dir*.

    Used to mirror the Space environment locally. Returns the local paths of
    the downloaded files.
    """
    destination = Path(local_dir)
    destination.mkdir(parents=True, exist_ok=True)

    fetched: List[Path] = []
    for filename in filenames:
        resolved = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
            token=token,
            local_dir=destination,
            # NOTE(review): deprecated in newer huggingface_hub releases.
            local_dir_use_symlinks=False,
        )
        fetched.append(Path(resolved))
    return fetched
46
+
47
+
48
def resolve_parquet_files(data_dir: str = "./data", pattern: str = "user*.parquet") -> List[Path]:
    """Return the parquet files under *data_dir* matching *pattern*, sorted.

    Raises FileNotFoundError when the directory is missing or nothing matches.
    """
    root = Path(data_dir)
    if not root.exists():
        raise FileNotFoundError(f"데이터 디렉토리λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€: {data_dir}")

    matches = sorted(root.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"νŒ¨ν„΄({pattern})에 ν•΄λ‹Ήν•˜λŠ” parquet 파일이 μ—†μŠ΅λ‹ˆλ‹€.")
    return matches
60
+
61
+
62
def parse_user_embedding(raw_emb, fallback_dim: int = 12) -> List[float]:
    """Coerce a user embedding (JSON string, list, or tuple) to a fixed-length float list.

    Unparseable, non-sequence, or empty input yields a zero vector; shorter
    inputs are zero-padded and longer ones truncated to *fallback_dim*.
    """
    if isinstance(raw_emb, str):
        try:
            raw_emb = json.loads(raw_emb)
        except json.JSONDecodeError:
            raw_emb = []

    values = list(raw_emb) if isinstance(raw_emb, (list, tuple)) else []
    if not values:
        values = [0.0] * fallback_dim

    # Pad with zeros or truncate so the result is exactly fallback_dim long.
    padding = [0.0] * max(0, fallback_dim - len(values))
    fixed = values[:fallback_dim] + padding
    return [float(v) for v in fixed]
86
+
87
+
88
def normalize_user_embeddings(df: pd.DataFrame, emb_dim: int) -> pd.DataFrame:
    """Return a copy of *df* with 'user_emb' coerced to emb_dim-length float lists.

    Raises KeyError when the column is absent.
    """
    if "user_emb" not in df.columns:
        raise KeyError("데이터셋에 'user_emb' 컬럼이 μ—†μŠ΅λ‹ˆλ‹€.")
    normalized = df.copy()
    normalized["user_emb"] = [
        parse_user_embedding(value, emb_dim) for value in normalized["user_emb"]
    ]
    return normalized
94
+
95
+
96
def _resolve_config_name(repo_id: str) -> Optional[str]:
    """Return the dataset's first config name, or None when it cannot be determined."""
    try:
        available = get_dataset_config_names(repo_id)
    except Exception:
        # Best-effort: a missing/unlistable config simply means "no config".
        return None
    return available[0] if available else None
104
+
105
+
106
def _load_split_dataframe(
    repo_id: str,
    split_name: str,
    cache_dir: str,
    config_name: Optional[str],
) -> Optional[pd.DataFrame]:
    """Load one dataset split from the Hub as a pandas DataFrame.

    Returns None (with a log line) when the split does not exist —
    ``load_dataset`` raises ValueError for unknown splits.

    Fix: the original return was ``ds.to_pandas() if hasattr(ds, "to_pandas")
    else ds.to_pandas()`` — both branches identical, so the hasattr guard was
    dead code; simplified to a single call.
    """
    load_kwargs = {
        "path": repo_id,
        "split": split_name,
        "cache_dir": cache_dir,
    }
    if config_name:
        load_kwargs["name"] = config_name
    try:
        ds = load_dataset(**load_kwargs)
    except ValueError as exc:
        print(f"⚠️ split '{split_name}' κ±΄λ„ˆλœ€: {exc}")
        return None

    return ds.to_pandas()
126
+
127
+
128
def load_dataset_from_hub(
    repo_id: Optional[str] = None,
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    emb_dim: int = 12,
    exclude_sessions: Optional[Iterable[str]] = None,
) -> Tuple[pd.DataFrame, List[str]]:
    """Load one or all splits of the Hub dataset into a single DataFrame.

    Rows whose session_id appears in *exclude_sessions* are dropped.
    Returns the DataFrame (with normalized user_emb) and the sorted list of
    remaining session ids. Raises ValueError("NO_DATA_AVAILABLE") when no
    split yields any rows, and KeyError when 'session_id' is missing.
    """
    repo_id = repo_id or DEFAULT_DATASET_ID
    cache_dir = cache_dir or os.getenv("HF_DATASET_CACHE_DIR", "./data/hf_cache")
    config_name = _resolve_config_name(repo_id)

    if split:
        wanted_splits = [split]
    else:
        try:
            wanted_splits = get_dataset_split_names(repo_id, config_name)
        except Exception:
            # Listing failed — fall back to the well-known split names.
            wanted_splits = DEFAULT_DATASET_SPLITS

    collected: List[pd.DataFrame] = []
    for name in wanted_splits:
        part = _load_split_dataframe(
            repo_id=repo_id,
            split_name=name,
            cache_dir=cache_dir,
            config_name=config_name,
        )
        if part is not None and not part.empty:
            collected.append(part)

    if not collected:
        raise ValueError("NO_DATA_AVAILABLE")

    combined = pd.concat(collected, ignore_index=True)
    if "session_id" not in combined.columns:
        raise KeyError("데이터셋에 'session_id' 컬럼이 μ—†μŠ΅λ‹ˆλ‹€.")

    excluded = {str(s) for s in (exclude_sessions or [])}
    if excluded:
        combined = combined[~combined["session_id"].astype(str).isin(excluded)]

    session_ids = sorted(combined["session_id"].dropna().astype(str).unique().tolist())
    combined = normalize_user_embeddings(combined, emb_dim)
    return combined, session_ids
177
+
178
+
179
def load_parquet_dataset(
    data_dir: str = "./data",
    pattern: str = "user*.parquet",
    emb_dim: int = 12,
) -> pd.DataFrame:
    """Load local parquet files; fall back to the Hub dataset when absent."""
    try:
        local_files = resolve_parquet_files(data_dir, pattern)
        merged = pd.concat(
            [pd.read_parquet(path) for path in local_files], ignore_index=True
        )
        return normalize_user_embeddings(merged, emb_dim)
    except FileNotFoundError:
        # No local data — pull straight from the Hugging Face dataset.
        print("⚠️ 둜컬 데이터가 μ—†μ–΄ Hugging Face Datasetμ—μ„œ λΆˆλŸ¬μ˜΅λ‹ˆλ‹€.")
        hub_df, _ = load_dataset_from_hub(emb_dim=emb_dim)
        return hub_df
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.5
2
+ uvicorn[standard]==0.32.0
3
+ schedule==1.2.2
4
+ huggingface_hub==0.25.2
5
+ datasets==2.19.1
6
+ pandas==2.1.4
7
+ numpy==1.24.4
8
+ pyarrow==14.0.1
9
+ lightgbm==4.3.0
10
+ scikit-learn==1.3.2
run_local.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Local launcher: validates Python version and dependencies, then starts
# the training scheduler.

echo "=== MuscleCare Train Hybrid 둜컬 μ‹€ν–‰ ==="

# Detect the installed Python version.
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "Python 버전: $python_version"

# Require Python >= 3.9 (sort -V puts the smaller version first).
required_version="3.9"
if [[ "$(printf '%s\n' "$required_version" "$python_version" | sort -V | head -n1)" != "$required_version" ]]; then
    echo "❌ Python $required_version 이상이 ν•„μš”ν•©λ‹ˆλ‹€. ν˜„μž¬: $python_version"
    exit 1
fi

echo "βœ… Python 버전 확인 μ™„λ£Œ"

# Warn (but do not abort) when no virtualenv is active.
if [[ -z "$VIRTUAL_ENV" ]]; then
    echo "⚠️ κ°€μƒν™˜κ²½μ΄ ν™œμ„±ν™”λ˜μ–΄ μžˆμ§€ μ•ŠμŠ΅λ‹ˆλ‹€."
    echo "   source .venv/bin/activate λͺ…λ Ήμ–΄λ‘œ ν™œμ„±ν™”ν•˜μ„Έμš”."
fi

# Verify all required packages import cleanly.
echo "μ˜μ‘΄μ„± μ„€μΉ˜ 확인 쀑..."
if ! python3 -c "import fastapi, uvicorn, lightgbm, pandas, datasets; print('βœ… λͺ¨λ“  μ˜μ‘΄μ„±μ΄ μ„€μΉ˜λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€.')" 2>/dev/null; then
    echo "❌ μ˜μ‘΄μ„±μ΄ μ„€μΉ˜λ˜μ–΄ μžˆμ§€ μ•ŠμŠ΅λ‹ˆλ‹€."
    echo "   pip install -r requirements.txt λͺ…λ Ήμ–΄λ‘œ μ„€μΉ˜ν•˜μ„Έμš”."
    exit 1
fi

echo ""
echo "=== μŠ€μΌ€μ€„λŸ¬ μ‹œμž‘ ==="
echo "λ§€μ£Ό μΌμš”μΌ 00:00에 μžλ™ ν•™μŠ΅μ΄ μ‹€ν–‰λ©λ‹ˆλ‹€."
echo "μ’…λ£Œν•˜λ €λ©΄ Ctrl+Cλ₯Ό λˆ„λ₯΄μ„Έμš”."
echo ""

python3 start.py
start.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Local scheduler entry point: delegates to train_scheduler.main().
"""

from train_scheduler import main

if __name__ == "__main__":
    main()
train_hybrid.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LightGBM 기반 κ·Όν”Όλ‘œλ„ μΆ”μ • νŒŒμ΄ν”„λΌμΈ
3
+ - Hugging Face Dataset λ‘œλ“œ
4
+ - νŠΉμ§• 생성 (Ξ±/Ξ² 보정값 + user_emb)
5
+ - LightGBM ν•™μŠ΅ 및 평가
6
+ """
7
+
8
+ import os
9
+ import argparse
10
+ import json
11
+ from pathlib import Path
12
+ from typing import Dict, Iterable, List, Optional
13
+
14
+ import lightgbm as lgb
15
+ import numpy as np
16
+ import pandas as pd
17
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
18
+ from sklearn.model_selection import train_test_split
19
+
20
+ from load_dataset import DEFAULT_DATASET_ID, load_dataset_from_hub
21
+
22
+
23
# Dimensionality of the per-user embedding vector.
EMB_DIM = 12
# Engineered ratio features produced by build_features().
FEATURES = ["rms_ratio", "freq_ratio"]
# Column names the embedding is exploded into (useremb1..useremb12).
EMB_COLS = [f"useremb{i+1}" for i in range(EMB_DIM)]
26
+
27
+
28
def build_features(df: pd.DataFrame) -> pd.DataFrame:
    """Derive rms_ratio / freq_ratio and explode user_emb into EMB_COLS.

    Raises KeyError when any required column (or user_emb) is absent.
    """
    required = [
        "rms_acc",
        "rms_gyro",
        "mean_freq_acc",
        "mean_freq_gyro",
        "rms_base",
        "freq_base",
        "fatigue",
    ]
    missing = set(required) - set(df.columns)
    if missing:
        raise KeyError(f"λˆ„λ½λœ 컬럼: {sorted(missing)}")

    out = df.copy()
    eps = np.finfo(float).eps  # guard against division by zero

    # Current sensor intensity relative to the per-user baseline.
    rms_avg = (out["rms_acc"] + out["rms_gyro"]) / 2.0
    out["rms_ratio"] = rms_avg / out["rms_base"].replace(0, eps)

    # Baseline frequency relative to the current mean frequency.
    freq_avg = (out["mean_freq_acc"] + out["mean_freq_gyro"]) / 2.0
    out["freq_ratio"] = out["freq_base"] / freq_avg.replace(0, eps)

    if "user_emb" not in out.columns:
        raise KeyError("데이터에 user_emb 컬럼이 ν•„μš”ν•©λ‹ˆλ‹€.")
    out[EMB_COLS] = pd.DataFrame(out["user_emb"].tolist(), index=out.index)
    return out
57
+
58
+
59
def train_lightgbm(
    data: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42,
) -> Dict[str, str]:
    """Train a LightGBM regressor on the engineered features and persist it.

    Splits *data* into train/validation, trains with early stopping, prints
    metrics and feature importances, then saves the booster and a JSON
    metadata file under ./models.

    NOTE(review): the return annotation says Dict[str, str] but the dict
    mixes floats, lists, and ints — callers treat it as Dict[str, Any].

    Returns the metadata dict (rmse, mae, feature_importance, model_path, ...).
    """
    train_cols = FEATURES + EMB_COLS
    X = data[train_cols]
    y = data["fatigue"]

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    lgb_train = lgb.Dataset(X_train, label=y_train)
    lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train)

    params = {
        "objective": "regression",
        "metric": "rmse",
        "learning_rate": 0.1,
        "num_leaves": 31,
        "verbose": -1,
    }

    # Stop when validation RMSE has not improved for 10 rounds.
    callbacks = [lgb.early_stopping(stopping_rounds=10, verbose=True)]
    model = lgb.train(
        params,
        lgb_train,
        valid_sets=[lgb_train, lgb_val],
        num_boost_round=100,
        callbacks=callbacks,
    )

    # Evaluate at the best iteration found by early stopping.
    y_pred = model.predict(X_val, num_iteration=model.best_iteration)
    rmse = np.sqrt(mean_squared_error(y_val, y_pred))
    mae = mean_absolute_error(y_val, y_pred)

    print(f"RMSE: {rmse:.6f}")
    print(f"MAE : {mae:.6f}")

    importance = pd.DataFrame(
        {
            "feature": train_cols,
            "importance": model.feature_importance(),
        }
    ).sort_values(by="importance", ascending=False)
    print("\nFeature Importance:")
    print(importance.to_string(index=False))

    # Persist the booster as plain text so the API can reload it lazily.
    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)
    booster_path = models_dir / "lightgbm_model.txt"
    model.save_model(str(booster_path))
    print(f"\nβœ… LightGBM λͺ¨λΈ μ €μž₯: {booster_path}")

    metadata = {
        "rmse": rmse,
        "mae": mae,
        "feature_importance": importance.to_dict(orient="records"),
        "model_path": str(booster_path),
        "artifact_type": "lightgbm",
        "sample_count": len(data),
    }

    metadata_path = models_dir / "training_metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False))
    print(f"ℹ️ 메타데이터 μ €μž₯: {metadata_path}")

    return metadata
128
+
129
+
130
def main(
    data_dir: str = "./data",
    pattern: str = "user*.parquet",
    emb_dim: int = EMB_DIM,
    exclude_sessions: Optional[Iterable[str]] = None,
    repo_id: Optional[str] = None,
    split: Optional[str] = None,
) -> Dict[str, str]:
    """Run the full load -> feature-build -> train pipeline.

    NOTE(review): data_dir/pattern are accepted for interface compatibility
    but data is always pulled from the Hugging Face dataset here.
    Raises ValueError("NO_DATA_AVAILABLE") when the dataset is empty.
    """
    banner = "=" * 80
    print(banner)
    print("MuscleCare LightGBM Trainer")
    print(banner)

    # Explicit arguments win over environment variables, which win over defaults.
    dataset_repo = repo_id or os.getenv("HF_DATASET_REPO_ID", DEFAULT_DATASET_ID)
    dataset_split = split if split is not None else os.getenv("HF_DATASET_SPLIT")

    frame, session_ids = load_dataset_from_hub(
        repo_id=dataset_repo,
        split=dataset_split,
        emb_dim=emb_dim,
        exclude_sessions=exclude_sessions,
    )
    if frame.empty:
        raise ValueError("NO_DATA_AVAILABLE")

    enriched = build_features(frame)
    summary = train_lightgbm(enriched)
    summary["session_ids"] = session_ids
    summary["session_count"] = len(session_ids)
    summary["dataset_repo"] = dataset_repo
    summary["dataset_split"] = dataset_split or "ALL"
    return summary
162
+
163
+
164
if __name__ == "__main__":
    # CLI entry point; flags mirror main()'s leading parameters.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-dir", default="./data")
    cli.add_argument("--pattern", default="user*.parquet")
    cli.add_argument("--emb-dim", type=int, default=EMB_DIM)
    opts = cli.parse_args()
    main(data_dir=opts.data_dir, pattern=opts.pattern, emb_dim=opts.emb_dim)
train_scheduler.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LightGBM λͺ¨λΈ ν•™μŠ΅ μŠ€μΌ€μ€„λŸ¬
3
+ - μ •ν•΄μ§„ 주기둜 train_hybrid.pyλ₯Ό μ‹€ν–‰
4
+ - ν•™μŠ΅ μƒνƒœ 및 버전 메타데이터 관리
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import shutil
10
+ import time
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ import schedule
16
+
17
+ from train_hybrid import main as train_main
18
+
19
+
20
class TrainingScheduler:
    """Coordinates periodic LightGBM retraining.

    Responsibilities:
      * persist training state (last run time, model version, processed sessions)
      * maintain a rotating manifest of versioned model artifacts under ``models/``
      * run a training pass that excludes sessions already seen in earlier runs

    NOTE: state/version dicts hold heterogeneous values (None/int/str/list), so
    they are typed ``Dict[str, Any]`` — the previous ``Dict[str, Optional[str]]``
    annotations were incorrect.
    """

    def __init__(
        self,
        data_dir: str = "./data",
        pattern: str = "user*.parquet",
        schedule_time: str = "00:00",
        state_file: str = "./logs/training_state.json",
        versions_file: str = "./logs/model_versions.json",
    ):
        # data_dir / pattern are forwarded verbatim to train_hybrid.main.
        self.data_dir = data_dir
        self.pattern = pattern
        self.schedule_time = schedule_time
        self.state_path = Path(state_file)
        self.versions_path = Path(versions_file)
        self.logs_dir = self.state_path.parent
        self.logs_dir.mkdir(parents=True, exist_ok=True)
        # Versioned artifacts always live under ./models relative to the CWD.
        self.models_dir = Path("models")
        self.models_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ #
    # State helpers
    # ------------------------------------------------------------------ #
    def _default_state(self) -> Dict[str, Any]:
        """Return a fresh, empty training state."""
        return {
            "last_training": None,
            "model_version": 0,
            "last_model_path": None,
            "processed_sessions": [],
        }

    def load_training_state(self) -> Dict[str, Any]:
        """Load persisted state from disk, or a default state if none exists."""
        if self.state_path.exists():
            state = json.loads(self.state_path.read_text(encoding="utf-8"))
            # Older state files may predate the processed_sessions field.
            state.setdefault("processed_sessions", [])
            return state
        return self._default_state()

    def save_training_state(self, state: Dict[str, Any]) -> None:
        """Persist *state* as pretty-printed UTF-8 JSON."""
        self.state_path.write_text(json.dumps(state, indent=2, ensure_ascii=False), encoding="utf-8")

    def reset_training_state(self) -> Dict[str, Any]:
        """Reset state to defaults, drop the version manifest, and return the new state."""
        state = self._default_state()
        self.save_training_state(state)
        if self.versions_path.exists():
            self.versions_path.unlink()
        return state

    # ------------------------------------------------------------------ #
    # Version helpers
    # ------------------------------------------------------------------ #
    def _load_versions(self) -> List[Dict[str, Any]]:
        """Read the version manifest, or an empty list if it does not exist."""
        if self.versions_path.exists():
            return json.loads(self.versions_path.read_text(encoding="utf-8"))
        return []

    def _save_versions(self, manifest: List[Dict[str, Any]]) -> None:
        """Persist the version manifest as pretty-printed UTF-8 JSON."""
        self.versions_path.write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8")

    def record_version(self, version: int, source_path: str, timestamp: str, metadata: Dict[str, Any]) -> str:
        """Copy the trained artifact to a versioned filename and update the manifest.

        Returns the versioned artifact path. If *source_path* does not exist,
        returns it unchanged and records nothing (best-effort contract).
        """
        source = Path(source_path)
        if not source.exists():
            return source_path

        versioned = self.models_dir / f"{source.stem}_v{version}{source.suffix}"
        shutil.copy2(source, versioned)

        manifest = self._load_versions()
        manifest.append(
            {
                "version": version,
                "filename": versioned.name,
                "path": str(versioned),
                "timestamp": timestamp,
                "metrics": {
                    "rmse": metadata.get("rmse"),
                    "mae": metadata.get("mae"),
                },
                "sample_count": metadata.get("sample_count"),
                "session_count": metadata.get("session_count"),
                "dataset": {
                    "repo_id": metadata.get("dataset_repo"),
                    "split": metadata.get("dataset_split"),
                },
            }
        )

        # Rotate manifest and delete old artifacts to cap disk usage.
        max_versions = int(os.getenv("MAX_MODEL_VERSIONS", "2"))
        to_remove: List[Dict[str, Any]] = []
        if len(manifest) > max_versions:
            to_remove = manifest[:-max_versions]
            manifest = manifest[-max_versions:]
        for old_entry in to_remove:
            old_path = Path(old_entry["path"])
            if old_path.exists():
                old_path.unlink()
        self._save_versions(manifest)

        return str(versioned)

    def get_model_versions(self) -> List[Dict[str, Any]]:
        """Return the current version manifest (newest last)."""
        return self._load_versions()

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def run_scheduled_training(self) -> Dict[str, Any]:
        """Run one training pass and return a status dict.

        ``status`` is one of ``"trained"``, ``"skipped"``, or ``"failed"``.
        Raises ``ValueError("MODEL_ARTIFACT_MISSING")`` if training succeeds
        but reports no model artifact path.
        """
        print("=" * 80)
        print(f"[TrainingScheduler] ν•™μŠ΅ μ‹œμž‘ - {datetime.utcnow().isoformat()}")
        print("=" * 80)

        try:
            state = self.load_training_state()
            processed_sessions = set(state.get("processed_sessions", []))
        except Exception as exc:
            # A corrupt state file should not crash the scheduler loop.
            print(f"[TrainingScheduler] μƒνƒœ λ‘œλ“œ μ‹€νŒ¨: {exc}")
            return {
                "status": "failed",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": 0,
                "message": f"State load failed: {exc}",
            }

        try:
            # Sessions already trained on are excluded from this pass.
            metadata = train_main(
                self.data_dir,
                self.pattern,
                exclude_sessions=processed_sessions,
            )
        except FileNotFoundError as exc:
            print(f"[TrainingScheduler] 데이터 λˆ„λ½: {exc}")
            return {
                "status": "skipped",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": state.get("model_version", 0),
                "message": str(exc),
            }
        except ValueError as exc:
            # train_main signals "nothing new to train" via NO_DATA_AVAILABLE.
            if "NO_DATA_AVAILABLE" in str(exc):
                print("[TrainingScheduler] μƒˆλ‘œμš΄ μ„Έμ…˜μ΄ μ—†μ–΄ ν•™μŠ΅μ„ κ±΄λ„ˆλœλ‹ˆλ‹€.")
                return {
                    "status": "skipped",
                    "new_data_count": 0,
                    "new_session_count": 0,
                    "model_path": None,
                    "model_version": state.get("model_version", 0),
                    "message": "No new sessions to train.",
                }
            print(f"[TrainingScheduler] 데이터 처리 였λ₯˜: {exc}")
            return {
                "status": "failed",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": state.get("model_version", 0),
                "message": f"Data processing error: {exc}",
            }
        except Exception as exc:
            print(f"[TrainingScheduler] ν•™μŠ΅ μ‹€νŒ¨: {exc}")
            return {
                "status": "failed",
                "new_data_count": 0,
                "new_session_count": 0,
                "model_path": None,
                "model_version": state.get("model_version", 0),
                "message": str(exc),
            }

        new_version = state.get("model_version", 0) + 1
        timestamp = datetime.utcnow().isoformat()

        model_artifact = metadata.get("model_path")
        if not model_artifact:
            # Training reported success but produced no artifact path — this is
            # a programming error, so it propagates rather than being swallowed.
            raise ValueError("MODEL_ARTIFACT_MISSING")
        versioned_path = self.record_version(new_version, model_artifact, timestamp, metadata)

        used_sessions = metadata.get("session_ids", [])
        new_sessions = [s for s in used_sessions if s not in processed_sessions]
        processed_sessions.update(new_sessions)

        state.update(
            {
                "last_training": timestamp,
                "model_version": new_version,
                "last_model_path": versioned_path,
                "processed_sessions": sorted(processed_sessions),
            }
        )
        self.save_training_state(state)

        print(f"[TrainingScheduler] βœ… ν•™μŠ΅ μ™„λ£Œ - 버전 {new_version}, μƒ˜ν”Œ {metadata.get('sample_count', 0)}")

        return {
            "status": "trained",
            "new_data_count": metadata.get("sample_count", 0),
            "model_path": versioned_path,
            "model_version": new_version,
            "metadata": metadata,
            "new_session_count": len(new_sessions),
        }

    def trigger_training(self) -> Dict[str, Any]:
        """Run a training pass immediately (manual trigger)."""
        return self.run_scheduled_training()
227
+
228
+
229
def main():
    """Install the weekly training job and block running the schedule loop."""
    runner = TrainingScheduler()
    # Drop any previously registered jobs before installing ours.
    schedule.clear()
    job = schedule.every().sunday.at(runner.schedule_time)
    job.do(runner.run_scheduled_training)
    print(f"[TrainingScheduler] λ§€μ£Ό μΌμš”μΌ {runner.schedule_time} μžλ™ ν•™μŠ΅μ΄ μ˜ˆμ•½λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
    try:
        # Poll once a minute; schedule fires the job when its time arrives.
        while True:
            schedule.run_pending()
            time.sleep(60)
    except KeyboardInterrupt:
        print("[TrainingScheduler] μŠ€μΌ€μ€„λŸ¬ μ’…λ£Œ")


if __name__ == "__main__":
    main()