Merry99 commited on
Commit
2b83ee8
·
0 Parent(s):

Spaces용 코드만 포함 (모델 파일 제외)

Browse files
Files changed (14) hide show
  1. .dockerignore +22 -0
  2. .gitattributes +35 -0
  3. .gitignore +62 -0
  4. Dockerfile +19 -0
  5. README.md +92 -0
  6. app.py +262 -0
  7. convert_tflite.py +336 -0
  8. load_dataset.py +38 -0
  9. model.md +127 -0
  10. requirements.txt +22 -0
  11. run_local.sh +39 -0
  12. start.py +10 -0
  13. train_e2e.py +319 -0
  14. train_scheduler.py +265 -0
.dockerignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .Python
6
+ *.so
7
+ *.egg
8
+ *.egg-info
9
+ dist
10
+ build
11
+ .git
12
+ .gitignore
13
+ .env
14
+ .venv
15
+ venv/
16
+ ENV/
17
+ env/
18
+ .vscode
19
+ .idea
20
+ *.md
21
+ !README.md
22
+
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .Python
6
+ build/
7
+ develop-eggs/
8
+ dist/
9
+ downloads/
10
+ eggs/
11
+ .eggs/
12
+ lib/
13
+ lib64/
14
+ parts/
15
+ sdist/
16
+ var/
17
+ wheels/
18
+ *.egg-info/
19
+ .installed.cfg
20
+ *.egg
21
+ MANIFEST
22
+
23
+ # Virtual environments
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+ ENV/
29
+ env.bak/
30
+ venv.bak/
31
+
32
+ # IDE
33
+ .vscode/
34
+ .idea/
35
+ *.swp
36
+ *.swo
37
+ *~
38
+
39
+ # OS
40
+ .DS_Store
41
+ Thumbs.db
42
+
43
+ # Logs
44
+ *.log
45
+
46
+ # Model files (큰 파일이므로 Git에서 제외)
47
+ *.pth
48
+ *.pt
49
+ *.ckpt
50
+ *.bin
51
+ *.safetensors
52
+ *.tflite
53
+ *.keras
54
+ *.h5
55
+ *.pb
56
+ *.onnx
57
+ *.pkl
58
+ *.pickle
59
+
60
+ # Model directory
61
+ model/
62
+
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ ENV PYTHONUNBUFFERED=1 \
6
+ PIP_NO_CACHE_DIR=1
7
+
8
+ COPY requirements.txt .
9
+ RUN apt-get update && apt-get install -y --no-install-recommends build-essential && \
10
+ pip install --upgrade pip && \
11
+ pip install -r requirements.txt && \
12
+ apt-get purge -y build-essential && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
13
+
14
+ COPY . .
15
+
16
+ EXPOSE 7860
17
+
18
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
19
+
README.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MuscleCare Train AI
3
+ emoji: 🔥
4
+ colorFrom: green
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # MuscleCare Train AI
12
+
13
+ CNN + GRU 기반 근육 피로도 예측 모델 자동 학습 시스템
14
+
15
+ ## 🚀 주요 기능
16
+
17
+ - **자동 데이터 로딩**: Hugging Face `Merry99/MuscleCare-DataSet` 데이터셋 자동 로드
18
+ - **CNN + GRU 모델**: 시퀀스 데이터에서 피로도 예측
19
+ - **자동 학습 스케줄링**: 매주 일요일 자정 자동 모델 업데이트
20
+ - **중복 방지**: 이미 학습된 세션 데이터 자동 제외
21
+ - **TFLite 변환**: 모바일 배포를 위한 TFLite 모델 자동 생성 (필수)
22
+
23
+ ## 📦 실행 방법
24
+
25
+ ### Docker 사용 (권장)
26
+
27
+ ```bash
28
+ # 이미지 빌드
29
+ docker build -t musclecare-train-ai .
30
+
31
+ # 실행
32
+ docker run musclecare-train-ai
33
+ ```
34
+
35
+ ### 로컬 실행 (Python 3.10 필요)
36
+
37
+ ```bash
38
+ # Python 3.10 확인
39
+ python3.10 --version
40
+
41
+ # 패키지 설치
42
+ python3.10 -m pip install -r requirements.txt
43
+
44
+ # 실행
45
+ python3.10 start.py
46
+ ```
47
+
48
+ 또는 스크립트 사용:
49
+ ```bash
50
+ ./run_local.sh
51
+ ```
52
+
53
+ ## 🔄 전체 플로우
54
+
55
+ 1. **데이터 로드**: `load_dataset.py`로 Hugging Face 데이터셋 로드
56
+ 2. **모델 학습**: `train_e2e.py`로 CNN + GRU 모델 학습
57
+ 3. **모델 저장**: 학습된 모델을 `./model/fatigue_net_v2.pt`에 저장 (PyTorch state_dict 형식)
58
+ 4. **TFLite 변환**: `convert_tflite.py`로 TFLite 모델 생성 → `./model/fatigue_net_v2.tflite`
59
+
60
+ ## 📁 파일 구조
61
+
62
+ - `load_dataset.py`: Hugging Face 데이터셋 로드
63
+ - `train_e2e.py`: CNN + GRU 모델 학습 (PyTorch state_dict 형식으로 저장)
64
+ - `convert_tflite.py`: PyTorch → TFLite 변환
65
+ - `train_scheduler.py`: 자동 학습 스케줄러
66
+ - `start.py`: 자동 학습 스케줄러 시작 스크립트
67
+ - `app.py`: FastAPI 애플리케이션 (나중에 구현 예정)
68
+
69
+ ## 🔧 요구사항
70
+
71
+ - Python 3.10 (TFLite 변환 필수)
72
+ - PyTorch 2.0+
73
+ - ONNX, ONNX-TF, TensorFlow (TFLite 변환용)
74
+
75
+ ## 📝 모델 저장 위치
76
+
77
+ - PyTorch 모델: `./model/fatigue_net_v2.pt` (state_dict 형식)
78
+ - **TFLite 모델: `./model/fatigue_net_v2.tflite`** (모바일 배포용, 필수)
79
+ - 학습 상태: `./model/training_state.json`
80
+
81
+ ## ⚠️ 중요 사항
82
+
83
+ - **TFLite 변환은 필수입니다** (모바일 디바이스에서 실행 필요)
84
+ - 모델은 반드시 PyTorch state_dict 형식으로 저장되어야 합니다 (TorchScript 형식 불가)
85
+ - Python 3.10 이상이 필요합니다 (TFLite 변환 패키지 호환성)
86
+
87
+ ## 🔄 자동 학습 스케줄
88
+
89
+ - 실행 시간: 매주 일요일 자정 (00:00)
90
+ - 중복 방지: `training_state.json`에 저장된 세션 ID는 자동 제외
91
+ - 모델 버전: 자동 증가
92
+ - TFLite 변환: 학습 후 자동 수행
app.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI 앱: 수동 학습 및 Hugging Face 업로드 트리거"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import threading
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Optional
11
+
12
+ import schedule
13
+ from fastapi import FastAPI, HTTPException
14
+ from fastapi.responses import FileResponse
15
+ from huggingface_hub import HfApi, hf_hub_download
16
+ try:
17
+ from huggingface_hub.utils import HfHubHTTPError
18
+ except ImportError: # pragma: no cover
19
+ HfHubHTTPError = Exception # type: ignore
20
+ from pydantic import BaseModel
21
+
22
+ from train_scheduler import TrainingScheduler
23
+
24
+
25
+ app = FastAPI(
26
+ title="MuscleCare Train Scheduler API",
27
+ description="수동으로 모델 학습 및 Hugging Face 업로드를 트리거합니다.",
28
+ )
29
+
30
+ _scheduler = TrainingScheduler()
31
+
32
+
33
+ class TrainResponse(BaseModel):
34
+ status: str
35
+ new_data_count: int
36
+ model_path: Optional[str] = None
37
+ hub_url: Optional[str] = None
38
+ model_version: Optional[int] = None
39
+ message: str
40
+
41
+
42
+ @app.on_event("startup")
43
+ def startup_training() -> None:
44
+ """서버 시작 시 자동으로 모델 학습을 실행합니다."""
45
+ try:
46
+ print("🚀 서버 시작: 자동 모델 학습을 시작합니다...")
47
+ result = _scheduler.run_scheduled_training()
48
+ if result["status"] == "trained":
49
+ print(f"✅ 서버 시작 시 학습 완료: {result['new_data_count']}개 데이터로 학습됨")
50
+ else:
51
+ print(f"ℹ️ 서버 시작 시 학습 건너뜀: {result.get('message', '새로운 데이터 없음')}")
52
+ except Exception as exc:
53
+ print(f"⚠️ 서버 시작 시 학습 실패: {exc}")
54
+
55
+ # 기존 스케줄링 설정
56
+ schedule.clear()
57
+ schedule.every().sunday.at("00:00").do(_scheduler.run_scheduled_training)
58
+
59
+ def _run_schedule() -> None:
60
+ while True:
61
+ schedule.run_pending()
62
+ time.sleep(60)
63
+
64
+ threading.Thread(target=_run_schedule, daemon=True).start()
65
+
66
+
67
+ @app.get("/health")
68
+ def health_check() -> dict:
69
+ return {"status": "ok"}
70
+
71
+
72
+ @app.get("/")
73
+ def root() -> dict:
74
+ return {
75
+ "message": "MuscleCare Train Scheduler API가 실행 중입니다.",
76
+ "endpoints": {
77
+ "health": "/health",
78
+ "trigger": "/trigger",
79
+ },
80
+ "docs": "/docs",
81
+ }
82
+
83
+
84
+ def _upload_to_hub(model_path: str) -> Optional[str]:
85
+ token = os.getenv("HF_E2E_MODEL_TOKEN")
86
+ repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
87
+
88
+ if not token or not repo_id:
89
+ raise HTTPException(
90
+ status_code=400,
91
+ detail="환경 변수 HF_E2E_MODEL_TOKEN / HF_E2E_MODEL_REPO_ID가 설정되어 있지 않습니다.",
92
+ )
93
+
94
+ path = Path(model_path)
95
+ if not path.exists():
96
+ raise HTTPException(status_code=404, detail=f"모델 파일을 찾을 수 없습니다: {model_path}")
97
+
98
+ api = HfApi(token=token)
99
+ api.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
100
+
101
+ api.upload_file(
102
+ path_or_fileobj=path,
103
+ path_in_repo=path.name,
104
+ repo_id=repo_id,
105
+ repo_type="model",
106
+ commit_message="Manual scheduler trigger upload",
107
+ )
108
+
109
+ return f"https://huggingface.co/{repo_id}"
110
+
111
+
112
+ # TODO: include version info in response body
113
+ @app.get("/model")
114
+ @app.get("/model/{version:int}")
115
+ def download_model(
116
+ version: Optional[int] = None,
117
+ filename: Optional[str] = None
118
+ ) -> FileResponse:
119
+ repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
120
+ token = os.getenv("HF_E2E_MODEL_TOKEN")
121
+ default_filename = os.getenv("HF_E2E_MODEL_FILE", "cnn_gru_fatigue.tflite")
122
+
123
+ if not repo_id:
124
+ raise HTTPException(
125
+ status_code=400,
126
+ detail="환경 변수 HF_E2E_MODEL_REPO_ID가 설정되어 있지 않습니다."
127
+ )
128
+
129
+ current_state = _scheduler.load_training_state()
130
+ current_version = int(current_state.get("model_version", 0) or 0)
131
+
132
+ try:
133
+ if not version:
134
+ target_filename = filename or default_filename
135
+ local_path = hf_hub_download(
136
+ repo_id=repo_id,
137
+ filename=target_filename,
138
+ repo_type="model",
139
+ token=token,
140
+ local_dir="./model_cache",
141
+ local_dir_use_symlinks=False,
142
+ )
143
+ actual_version = current_version
144
+ else:
145
+ if version > current_version:
146
+ raise HTTPException(
147
+ status_code=404,
148
+ detail=f"현재 모델 버전은 {current_version}입니다. 버전 {version}은 존재하지 않습니다."
149
+ )
150
+ manifest_path = hf_hub_download(
151
+ repo_id=repo_id,
152
+ filename="model_versions.json",
153
+ repo_type="model",
154
+ token=token,
155
+ local_dir="./model_cache",
156
+ local_dir_use_symlinks=False,
157
+ )
158
+ with open(manifest_path, "r", encoding="utf-8") as f:
159
+ manifest = json.load(f)
160
+
161
+ version_entry = next(
162
+ (entry for entry in manifest if entry.get("version") == version),
163
+ None
164
+ )
165
+
166
+ if version_entry is None:
167
+ raise HTTPException(
168
+ status_code=404,
169
+ detail=f"버전 {version}에 해당하는 모델을 찾을 수 없습니다."
170
+ )
171
+
172
+ target_filename = filename or version_entry.get("filename")
173
+ target_revision = version_entry.get("commit")
174
+
175
+ if not target_filename or not target_revision:
176
+ raise HTTPException(
177
+ status_code=500,
178
+ detail=f"버전 {version} 메타데이터가 올바르지 않습니다."
179
+ )
180
+
181
+ local_path = hf_hub_download(
182
+ repo_id=repo_id,
183
+ filename=target_filename,
184
+ repo_type="model",
185
+ token=token,
186
+ local_dir="./model_cache",
187
+ local_dir_use_symlinks=False,
188
+ revision=target_revision,
189
+ )
190
+ actual_version = version
191
+ except Exception as exc:
192
+ status = getattr(getattr(exc, "response", None), "status_code", None)
193
+ if status == 404:
194
+ raise HTTPException(
195
+ status_code=404,
196
+ detail="허깅페이스에서 지정한 모델 파일을 찾을 수 없습니다."
197
+ ) from exc
198
+ raise HTTPException(
199
+ status_code=500,
200
+ detail=f"Hugging Face Hub 다운로드 실패: {exc}"
201
+ ) from exc
202
+
203
+ response = FileResponse(
204
+ path=local_path,
205
+ filename=Path(target_filename).name,
206
+ media_type="application/octet-stream"
207
+ )
208
+ response.headers["X-Model-Version"] = str(actual_version)
209
+ response.headers["X-Model-Filename"] = Path(target_filename).name
210
+ return response
211
+
212
+
213
+ class ResetStateResponse(BaseModel):
214
+ status: str
215
+ state: Dict[str, Any]
216
+
217
+
218
+ @app.post("/state/reset", response_model=ResetStateResponse)
219
+ def reset_training_state() -> ResetStateResponse:
220
+ try:
221
+ state = _scheduler.reset_training_state()
222
+ return ResetStateResponse(
223
+ status="reset",
224
+ state=state,
225
+ )
226
+ except Exception as exc: # pylint: disable=broad-except
227
+ raise HTTPException(status_code=500, detail=f"학습 상태 초기화에 실패했습니다: {exc}") from exc
228
+
229
+
230
+ @app.post("/trigger", response_model=TrainResponse)
231
+ def trigger_training(upload: bool = True) -> TrainResponse:
232
+ try:
233
+ result = _scheduler.run_scheduled_training()
234
+ except Exception as exc: # pylint: disable=broad-except
235
+ raise HTTPException(status_code=500, detail=f"학습 실행 중 오류가 발생했습니다: {exc}") from exc
236
+
237
+ message = "새로운 데이터가 없어 학습을 건너뜁니다."
238
+ hub_url = None
239
+
240
+ if result["status"] == "trained":
241
+ message = "모델 학습이 완료되었습니다."
242
+ model_path = result.get("model_path")
243
+
244
+ if upload and model_path:
245
+ try:
246
+ hub_url = _upload_to_hub(model_path)
247
+ message = "모델 학습 및 업로드가 완료되었습니다."
248
+ except HTTPException:
249
+ raise
250
+ except Exception as exc: # pylint: disable=broad-except
251
+ raise HTTPException(status_code=500, detail=f"Hugging Face 업로드 실패: {exc}") from exc
252
+
253
+ return TrainResponse(
254
+ status=result["status"],
255
+ new_data_count=result["new_data_count"],
256
+ model_path=result.get("model_path"),
257
+ hub_url=hub_url,
258
+ message=message,
259
+ )
260
+
261
+
262
+ __all__ = ["app"]
convert_tflite.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch 모델을 TensorFlow Lite 형식으로 변환하는 스크립트
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ import os
9
+
10
+ # 선택적 임포트
11
+ ONNX_AVAILABLE = False
12
+ TF_AVAILABLE = False
13
+ ONNX_TF_AVAILABLE = False
14
+
15
+ try:
16
+ import onnx
17
+ ONNX_AVAILABLE = True
18
+ except (ImportError, SyntaxError, Exception) as e:
19
+ ONNX_AVAILABLE = False
20
+ if not isinstance(e, ImportError):
21
+ print(f"⚠️ onnx 패키지 로드 중 오류 발생: {type(e).__name__}")
22
+
23
+ try:
24
+ import tensorflow as tf
25
+ TF_AVAILABLE = True
26
+ except (ImportError, SyntaxError, Exception) as e:
27
+ TF_AVAILABLE = False
28
+ if not isinstance(e, ImportError):
29
+ print(f"⚠️ tensorflow 패키지 로드 중 오류 발생: {type(e).__name__}")
30
+
31
+ try:
32
+ # onnx-tf는 실제로 사용할 때 임포트하도록 변경
33
+ # from onnx_tf.backend import prepare
34
+ ONNX_TF_AVAILABLE = True
35
+ except (ImportError, SyntaxError, Exception) as e:
36
+ ONNX_TF_AVAILABLE = False
37
+ if not isinstance(e, ImportError):
38
+ print(f"⚠️ onnx-tf 패키지 로드 중 오류 발생: {type(e).__name__}")
39
+
40
+
41
+ class FatigueNet(nn.Module):
42
+ """CNN + GRU 기반 피로도 예측 모델 (PyTorch 버전)"""
43
+
44
+ def __init__(self, input_dim=2, hidden_dim=64, num_layers=2, output_dim=1):
45
+ super(FatigueNet, self).__init__()
46
+
47
+ # CNN 부분
48
+ self.conv1 = nn.Conv1d(
49
+ in_channels=input_dim,
50
+ out_channels=32,
51
+ kernel_size=1,
52
+ padding=0
53
+ )
54
+ self.conv2 = nn.Conv1d(
55
+ in_channels=32,
56
+ out_channels=64,
57
+ kernel_size=1,
58
+ padding=0
59
+ )
60
+ self.relu = nn.ReLU()
61
+
62
+ # GRU 부분 (TFLite 호환성을 위해 linear_before_reset=False)
63
+ self.gru = nn.GRU(
64
+ input_size=64,
65
+ hidden_size=hidden_dim,
66
+ num_layers=num_layers,
67
+ batch_first=True,
68
+ dropout=0.2 if num_layers > 1 else 0
69
+ )
70
+
71
+ # Fully Connected 레이어
72
+ self.fc = nn.Linear(hidden_dim, output_dim)
73
+ self.dropout = nn.Dropout(0.3)
74
+
75
+ def forward(self, x):
76
+ if x.dim() == 2:
77
+ x = x.unsqueeze(1)
78
+
79
+ x = x.permute(0, 2, 1)
80
+
81
+ x = self.conv1(x)
82
+ x = self.relu(x)
83
+
84
+ x = self.conv2(x)
85
+ x = self.relu(x)
86
+
87
+ x = x.permute(0, 2, 1)
88
+
89
+ gru_out, _ = self.gru(x)
90
+
91
+ last_output = gru_out[:, -1, :]
92
+
93
+ last_output = self.dropout(last_output)
94
+ output = self.fc(last_output)
95
+
96
+ return output
97
+
98
+
99
+ def convert_pytorch_to_tflite(
100
+ pytorch_model_path='./model/fatigue_net_v2.pt',
101
+ tflite_model_path='./model/fatigue_net_v2.tflite',
102
+ input_shape=(1, 1, 2) # (batch, seq_len, features)
103
+ ):
104
+ """
105
+ PyTorch 모델을 TensorFlow Lite로 변환
106
+
107
+ Args:
108
+ pytorch_model_path: PyTorch 모델 파일 경로
109
+ tflite_model_path: 저장할 TFLite 모델 파일 경로
110
+ input_shape: 입력 텐서 형태 (batch, seq_len, features)
111
+ """
112
+ print("=" * 80)
113
+ print("PyTorch 모델을 TensorFlow Lite로 변환")
114
+ print("=" * 80)
115
+
116
+ # 필수 패키지 확인
117
+ if not ONNX_AVAILABLE or not TF_AVAILABLE or not ONNX_TF_AVAILABLE:
118
+ print("\n❌ 필수 패키지가 설치되지 않았거나 호환성 문제가 있습니다.")
119
+ print("\n📋 Python 버전 확인:")
120
+ import sys
121
+ print(f" 현재 Python 버전: {sys.version}")
122
+ print(f" 권장 Python 버전: 3.10 이상")
123
+
124
+ if sys.version_info < (3, 10):
125
+ print("\n⚠️ Python 3.9에서는 일부 패키지 호환성 문제가 있을 수 있습니다.")
126
+ print(" Python 3.10 이상으로 업그레이드하거나, 다음을 시도하세요:")
127
+ print(" - 가상환경에서 Python 3.10+ 사용")
128
+ print(" - 또는 호환되는 패키지 버전 설치")
129
+
130
+ print("\n📦 설치 명령어:")
131
+ print(" 권장 버전 (Python 3.10 이상):")
132
+ print(" pip install onnx==1.15.0 onnx-tf==1.10.0 tensorflow==2.15.0")
133
+ print("\n⚠️ 참고: Python 3.9에서는 일부 패키지 설치 중 에러가 발생할 수 있습니다.")
134
+ print(" Python 3.10 이상 사용을 강력히 권장합니다.")
135
+
136
+ print("\n❌ TFLite 변환은 필수입니다. 모바일 디바이스에서 실행하기 위해 필요합니다.")
137
+ print(" 필수 패키지를 설치하고 다시 시도하세요.")
138
+ return False
139
+
140
+ # 1️⃣ PyTorch 모델 로드
141
+ print("\n1️⃣ PyTorch 모델 로드 중...")
142
+ if not os.path.exists(pytorch_model_path):
143
+ raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {pytorch_model_path}")
144
+
145
+ # TorchScript 파일인지 먼저 확인
146
+ try:
147
+ checkpoint = torch.jit.load(pytorch_model_path, map_location='cpu')
148
+ if isinstance(checkpoint, torch.jit.ScriptModule):
149
+ raise ValueError(
150
+ f"❌ {pytorch_model_path}는 TorchScript 형식입니다.\n"
151
+ "TFLite 변환을 위해서는 PyTorch state_dict 형식 모델이 필요합니다.\n"
152
+ "모델을 다시 학습하거나 올바른 형식의 모델 파일을 사용하세요."
153
+ )
154
+ except:
155
+ pass
156
+
157
+ # 일반 PyTorch 모델 로드
158
+ checkpoint = torch.load(pytorch_model_path, map_location='cpu')
159
+
160
+ # 일반 PyTorch 모델인지 확인
161
+ if not isinstance(checkpoint, dict) or 'model_state_dict' not in checkpoint:
162
+ raise ValueError(
163
+ f"❌ 올바른 PyTorch 모델 형식이 아닙니다.\n"
164
+ f"'{pytorch_model_path}' 파일에 'model_state_dict' 키가 필요합니다.\n"
165
+ "모델을 다시 학습하거나 올바른 형식의 모델 파일을 사용하세요."
166
+ )
167
+
168
+ model_config = checkpoint.get('model_config', {
169
+ 'input_dim': 2,
170
+ 'hidden_dim': 64,
171
+ 'num_layers': 2,
172
+ 'output_dim': 1
173
+ })
174
+
175
+ model = FatigueNet(**model_config)
176
+ model.load_state_dict(checkpoint['model_state_dict'])
177
+ model.eval()
178
+ print(f"✅ 모델 로드 완료: {pytorch_model_path}")
179
+ print(f" 모델 설정: {model_config}\n")
180
+
181
+ # 2️⃣ ONNX로 변환
182
+ print("2️⃣ ONNX 형식으로 변환 중...")
183
+ onnx_model_path = './model/fatigue_net_v2.onnx'
184
+ os.makedirs('./model', exist_ok=True)
185
+
186
+ # 더미 입력 생성 (고정 batch_size=1로 TFLite 호환성 향상)
187
+ dummy_input = torch.randn(1, 1, 2) # (batch=1, seq_len=1, features=2)
188
+
189
+ try:
190
+ # GRU를 RNN으로 변환하거나 TFLite 호환 옵션 사용
191
+ torch.onnx.export(
192
+ model,
193
+ dummy_input,
194
+ onnx_model_path,
195
+ export_params=True,
196
+ opset_version=11, # onnx-tf 호환성을 위해 11로 낮춤
197
+ do_constant_folding=True,
198
+ input_names=['input'],
199
+ output_names=['output'],
200
+ dynamic_axes={
201
+ 'input': {0: 'batch_size', 1: 'sequence_length'},
202
+ 'output': {0: 'batch_size'}
203
+ },
204
+ # GRU 관련 호환성 옵션
205
+ custom_opsets=None,
206
+ verbose=False
207
+ )
208
+ print(f"✅ ONNX 변환 완료: {onnx_model_path}\n")
209
+ except Exception as e:
210
+ print(f"⚠️ ONNX 변환 중 경고 (계속 진행): {e}\n")
211
+
212
+ # 3️⃣ ONNX를 TensorFlow로 변환
213
+ print("3️⃣ TensorFlow 형식으로 변환 중...")
214
+ try:
215
+ from onnx_tf.backend import prepare
216
+
217
+ # ONNX 모델 로드 및 GRU 속성 수정
218
+ onnx_model = onnx.load(onnx_model_path)
219
+
220
+ # GRU 노드의 linear_before_reset 속성을 0으로 설정 (TensorFlow 호환)
221
+ for node in onnx_model.graph.node:
222
+ if node.op_type == 'GRU':
223
+ # linear_before_reset 속성을 찾아서 0으로 설정
224
+ for attr in node.attribute:
225
+ if attr.name == 'linear_before_reset':
226
+ attr.i = 0
227
+ break
228
+ else:
229
+ # linear_before_reset 속성이 없으면 추가
230
+ attr = onnx.helper.make_attribute('linear_before_reset', 0)
231
+ node.attribute.append(attr)
232
+
233
+ tf_rep = prepare(onnx_model)
234
+
235
+ # TensorFlow SavedModel로 저장
236
+ tf_model_path = './model/tf_model'
237
+ tf_rep.export_graph(tf_model_path)
238
+ print(f"✅ TensorFlow 변환 완료: {tf_model_path}\n")
239
+ except Exception as e:
240
+ print(f"❌ TensorFlow 변환 실패: {e}")
241
+ print("⚠️ ONNX-TF 변환이 실패했습니다.\n")
242
+ print("❌ TFLite 변환은 필수입니다. 모바일 디바이스에서 실행하기 위해 필요합니다.")
243
+ print(" 에러를 해결하고 다시 시도하세요.")
244
+ return False
245
+
246
+ # 4️⃣ TensorFlow Lite로 변환
247
+ print("4️⃣ TensorFlow Lite 형식으로 변환 중...")
248
+
249
+ # TensorFlow Lite 변환기 생성
250
+ converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
251
+
252
+ # GRU 등 복잡한 연산을 위한 설정
253
+ converter.target_spec.supported_ops = [
254
+ tf.lite.OpsSet.TFLITE_BUILTINS,
255
+ tf.lite.OpsSet.SELECT_TF_OPS
256
+ ]
257
+ converter._experimental_lower_tensor_list_ops = False
258
+
259
+ # 최적화 옵션 설정 (선택사항)
260
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
261
+
262
+ # 변환 실행
263
+ tflite_model = converter.convert()
264
+
265
+ # TFLite 모델 저장
266
+ with open(tflite_model_path, 'wb') as f:
267
+ f.write(tflite_model)
268
+
269
+ print(f"✅ TensorFlow Lite 변환 완료: {tflite_model_path}")
270
+
271
+ # 모델 크기 확인
272
+ model_size = os.path.getsize(tflite_model_path) / (1024 * 1024) # MB
273
+ print(f" 모델 크기: {model_size:.2f} MB\n")
274
+
275
+ # 5️⃣ 변환된 모델 테스트
276
+ print("5️⃣ 변환된 모델 테스트 중...")
277
+ try:
278
+ interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
279
+ interpreter.allocate_tensors()
280
+
281
+ input_details = interpreter.get_input_details()
282
+ output_details = interpreter.get_output_details()
283
+
284
+ print(f" 입력 형태: {input_details[0]['shape']}")
285
+ print(f" 출력 형태: {output_details[0]['shape']}")
286
+
287
+ # 테스트 입력 (고정 크기)
288
+ test_input = np.random.randn(1, 1, 2).astype(np.float32)
289
+ interpreter.set_tensor(input_details[0]['index'], test_input)
290
+ interpreter.invoke()
291
+ test_output = interpreter.get_tensor(output_details[0]['index'])
292
+
293
+ print(f" 테스트 출력: {test_output[0][0]:.4f}")
294
+ print(" ✅ 모델 테스트 성공\n")
295
+ except Exception as e:
296
+ print(f" ⚠️ 모델 테스트 중 경고: {e}")
297
+ print(" (모델은 생성되었지만 테스트는 실패했습니다. 모바일 디바이스에서 Flex ops가 필요할 수 있습니다.)\n")
298
+
299
+ # 중간 파일 정리 (선택사항)
300
+ print("6️⃣ 중간 파일 정리 중...")
301
+ try:
302
+ os.remove(onnx_model_path)
303
+ import shutil
304
+ shutil.rmtree(tf_model_path)
305
+ print("✅ 중간 파일 정리 완료\n")
306
+ except Exception as e:
307
+ print(f"⚠️ 중간 파일 정리 실패 (무시 가능): {e}\n")
308
+
309
+ print("=" * 80)
310
+ print(f"✅ 변환 완료!")
311
+ print(f" TFLite 모델: {tflite_model_path}")
312
+ print("=" * 80)
313
+ return True
314
+
315
+
316
+ def main():
317
+ """메인 함수"""
318
+ try:
319
+ success = convert_pytorch_to_tflite(
320
+ pytorch_model_path='./model/fatigue_net_v2.pt',
321
+ tflite_model_path='./model/fatigue_net_v2.tflite'
322
+ )
323
+ if not success:
324
+ return 1
325
+ except Exception as e:
326
+ print(f"\n❌ 변환 실패: {e}")
327
+ import traceback
328
+ traceback.print_exc()
329
+ return 1
330
+
331
+ return 0
332
+
333
+
334
+ if __name__ == "__main__":
335
+ exit(main())
336
+
load_dataset.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face 데이터셋 로드 유틸리티
3
+ MuscleCare-DataSet 데이터셋을 로드하는 함수들을 제공합니다.
4
+ """
5
+
6
+ from datasets import load_dataset
7
+ from typing import Optional
8
+
9
+
10
+ def load_musclecare_dataset(
11
+ split: Optional[str] = None,
12
+ cache_dir: Optional[str] = None
13
+ ):
14
+ """
15
+ MuscleCare-DataSet 데이터셋을 로드합니다.
16
+
17
+ Args:
18
+ split: 데이터셋 split 이름 (None이면 모든 split 로드)
19
+ cache_dir: 캐시 디렉토리 경로
20
+
21
+ Returns:
22
+ Dataset 또는 DatasetDict 객체
23
+ """
24
+ dataset = load_dataset(
25
+ "Merry99/MuscleCare-DataSet",
26
+ split=split,
27
+ cache_dir=cache_dir
28
+ )
29
+ return dataset
30
+
31
+ if __name__ == "__main__":
32
+ print("데이터셋 로딩 중...")
33
+ dataset = load_musclecare_dataset()
34
+ print("✅ 데이터셋 로드 완료")
35
+ if hasattr(dataset, 'keys'):
36
+ print(f"총 {len(dataset.keys())}개의 split이 있습니다.")
37
+
38
+
model.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## `/model` API
2
+
3
+ 모델 다운로드 엔드포인트는 최신 모델과 특정 버전의 모델을 모두 제공하며, 응답 헤더를 통해 실제 버전 정보를 확인할 수 있습니다.
4
+
5
+ ### 요청 형식
6
+
7
+ ```
8
+ GET /model
9
+ GET /model?version={번호}
10
+ GET /model?version={번호}&filename={파일명}
11
+ ```
12
+
13
+ | 파라미터 | 타입 | 설명 |
14
+ | --- | --- | --- |
15
+ | `version` (선택) | int | 생략하거나 빈 값이면 최신 모델. 지정하면 해당 버전 확인 후 다운로드. |
16
+ | `filename` (선택) | string | 내려받을 파일명. 기본값은 환경 변수 `HF_E2E_MODEL_FILE` (기본 `cnn_gru_fatigue.tflite`). |
17
+
18
+ ### 응답
19
+
20
+ - 본문: 요청한 모델 바이너리 (예: `.tflite`, `.keras`, 메타데이터 등)
21
+ - 헤더:
22
+ - `X-Model-Version`: 실제 다운로드된 모델 버전
23
+ - `X-Model-Filename`: 반환된 파일명
24
+ - 에러:
25
+ - `404` – 요청한 버전이 현재 `model_version`보다 크거나 manifest에 존재하지 않을 때
26
+ - `500` – Hugging Face Hub 다운로드 실패 등 내부 오류
27
+
28
+ ### 동작 규칙
29
+
30
+ 1. 서버는 `training_state.json`의 `model_version` 값을 읽어 현재 허용 가능한 최대 버전을 확인합니다.
31
+ 2. `version`을 지정하지 않으면 최신 모델(현재 버전)을 다운로드합니다.
32
+ 3. `version`을 지정하면 서버가 현재 `model_version` 이하인지 확인한 뒤, 동일한 파일명을 내려줍니다(버전별로 파일명을 구분하지 않습니다).
33
+ 4. 요청한 버전이 현재 버전보다 크거나 파일이 존재하지 않으면 `404`를 반환합니다.
34
+
35
+ ### 사용 예시
36
+
37
+ #### 최신 모델 다운로드
38
+ ```bash
39
+ curl -L -o cnn_gru_fatigue_latest.tflite \
40
+ "https://merry99-musclecare-train-ai.hf.space/model"
41
+ ```
42
+
43
+ #### 버전 3 모델 다운로드
44
+ ```bash
45
+ curl -L -o cnn_gru_fatigue_v3.tflite \
46
+ "https://merry99-musclecare-train-ai.hf.space/model?version=3"
47
+ ```
48
+
49
+ #### 버전 3 메타데이터 다운로드
50
+ ```bash
51
+ curl -L -o metadata_v3.json \
52
+ "https://merry99-musclecare-train-ai.hf.space/model?version=3&filename=cnn_gru_fatigue_metadata.json"
53
+ ```
54
+
55
+ #### 헤더 확인
56
+ ```bash
57
+ curl -I "https://merry99-musclecare-train-ai.hf.space/model?version=3"
58
+ ```
59
+ 응답 헤더 예시:
60
+ ```
61
+ X-Model-Version: 3
62
+ X-Model-Filename: cnn_gru_fatigue.tflite
63
+ ```
64
+
65
+ ### 주의 사항
66
+
67
+ - `training_state.json`의 `model_version` 값이 기준이 되며, 그보다 높은 버전을 요청하면 404가 반환됩니다.
68
+ - 버전별로 다른 파일을 유지하지 않고, 같은 파일명을 내려주되 헤더(`X-Model-Version`)로 실제 버전을 확인합니다.
69
+ - 실패(예: 404) 시 JSON 응답이 내려오므로, 클라이언트는 상태 코드를 먼저 확인한 뒤 **200일 때만** `body`를 파일로 저장하세요.
70
+
71
+ Flutter 예시 (Dio):
72
+ ```dart
73
+ final response = await dio.get<List<int>>(
74
+ 'https://merry99-musclecare-train-ai.hf.space/model',
75
+ options: Options(responseType: ResponseType.bytes),
76
+ );
77
+
78
+ if (response.statusCode == 200) {
79
+ final version = response.headers.value('X-Model-Version');
80
+ final filename = response.headers.value('X-Model-Filename') ?? 'model.tflite';
81
+ await File('/path/$filename').writeAsBytes(response.data!);
82
+ } else {
83
+ final errorText = utf8.decode(response.data ?? []);
84
+ // 에러 처리
85
+ }
86
+ ```
87
+ - Space 환경 변수 `HF_E2E_MODEL_TOKEN`, `HF_E2E_MODEL_REPO_ID`가 올바르게 설정돼 있어야 `/model` 및 `/trigger`가 정상 동작합니다.
88
+
89
+
90
+ ## 모델 입력 사양 (Flutter 참고)
91
+
92
+ - 입력 형상: `(batch_size, input_dim)`이며 기본 `input_dim = 10 (FEATURE_COLUMNS) + embedding_dim`.
93
+ - `FEATURE_COLUMNS`: `rms_acc`, `rms_gyro`, `mean_freq_acc`, `mean_freq_gyro`, `entropy_acc`, `entropy_gyro`, `jerk_mean`, `jerk_std`, `stability_index`, `fatigue_prev`.
94
+ - `user_emb`: 메타데이터의 `embedding_dim`과 동일한 길이. 부족하면 뒤를 `0.0f`로 패딩.
95
+ - 메타데이터(`cnn_gru_fatigue_metadata.json`)의 `scaler.mean`, `scaler.scale`로 표준화한 뒤 모델에 전달.
96
+
97
+ ### Flutter에서 실행 순서
98
+ - **메타데이터 로드**: JSON에서 `feature_columns`, `scaler.mean`, `scaler.scale`, `embedding_dim`, `input_dim`을 읽는다.
99
+ - **특징 추출**: 측정 버튼을 눌러 얻은 윈도우에서 10개 피처 값을 계산한다.
100
+ - **표준화**: `(value - mean) / scale`을 수행하되 `scale`이 0이면 0으로 대체.
101
+ - **입력 벡터 구성**: `[정규화된 10개 피처, user_emb(패딩 포함)]`을 이어 붙여 `Float32List`로 만든다.
102
+ - **TFLite 실행**: 입력을 `[1, input_dim]`으로 reshape 후 `interpreter.run(input, output)`을 호출한다.
103
+
104
+ ```dart
105
+ final meta = await loadMetadata(); // JSON 파싱: scaler, embedding_dim 등
106
+ final features = computeFeatureVector(); // 길이 10, float
107
+ final userEmb = ensureEmbeddingLength(rawEmb, meta.embeddingDim); // 패딩
108
+
109
+ final normalized = List<double>.generate(features.length, (i) {
110
+ final scale = meta.scalerScale[i] == 0 ? 1.0 : meta.scalerScale[i];
111
+ return (features[i] - meta.scalerMean[i]) / scale;
112
+ });
113
+
114
+ final inputVector = Float32List.fromList([
115
+ ...normalized,
116
+ ...userEmb.map((e) => e.toDouble()),
117
+ ]);
118
+
119
+ final outputBuffer = Float32List(1);
120
+ interpreter.run(inputVector.reshape([1, inputVector.length]), outputBuffer);
121
+ final fatigueScore = outputBuffer[0];
122
+ ```
123
+
124
+ ### 주의
125
+ - 최초 측정부터 바로 예측 가능하며, 더 이상 5개 윈도우 누적이 필요하지 않습니다.
126
+ - `fatigue_prev`는 직전 측정의 피로도 지표로, 값이 없다면 `0` 또는 직전 예측치로 초기화해 주세요.
127
+ - 피처 추출 로직과 임베딩 차원은 백엔드 학습 파이프라인과 동일해야 합니다.
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python 3.10+ 환경을 가정합니다.
2
+ typing_extensions>=4.8.0,<5.0.0
3
+ numpy>=1.23.0,<1.27.0
4
+ torch>=2.0.0
5
+ transformers>=4.30.0
6
+ datasets>=2.14.0
7
+ pandas>=2.0.0
8
+ scikit-learn>=1.3.0
9
+ tqdm>=4.65.0
10
+ schedule>=1.2.0
11
+ huggingface-hub>=0.24.0
12
+ python-dotenv>=1.0.0
13
+ fastapi>=0.110.0
14
+ uvicorn[standard]>=0.23.0
15
+
16
+ # TFLite 변환용 패키지 (호환 버전)
17
+ # PyTorch와 TensorFlow가 함께 설치될 때 충돌 방지를 위해 순서 중요
18
+ onnx==1.15.0
19
+ onnx-tf==1.10.0
20
+ tensorflow==2.15.0
21
+ protobuf<4.0.0
22
+ tensorflow-probability>=0.23.0
run_local.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # 로컬 실행 스크립트 (venv 없이)
3
+
4
+ echo "🚀 MuscleCare Train AI - 로컬 실행"
5
+ echo "=================================="
6
+
7
+ # Python 3.10 확인
8
+ PYTHON_CMD=""
9
+ if command -v python3.10 &> /dev/null; then
10
+ PYTHON_CMD="python3.10"
11
+ elif [ -f /usr/local/bin/python3.10 ]; then
12
+ PYTHON_CMD="/usr/local/bin/python3.10"
13
+ else
14
+ echo "❌ Python 3.10이 필요합니다."
15
+ echo " 설치: brew install python@3.10"
16
+ exit 1
17
+ fi
18
+
19
+ echo "✅ Python 버전: $($PYTHON_CMD --version)"
20
+ echo ""
21
+
22
+ # 패키지 설치 확인
23
+ echo "📦 필수 패키지 확인 중..."
24
+ $PYTHON_CMD -c "import torch; import onnx; import tensorflow" 2>/dev/null
25
+ if [ $? -ne 0 ]; then
26
+ echo "⚠️ 일부 패키지가 설치되지 않았습니다."
27
+ echo " 설치: $PYTHON_CMD -m pip install --user -r requirements.txt"
28
+ echo ""
29
+ read -p "지금 설치하시겠습니까? [y/N]: " -n 1 -r
30
+ echo
31
+ if [[ $REPLY =~ ^[Yy]$ ]]; then
32
+ $PYTHON_CMD -m pip install --user -r requirements.txt
33
+ fi
34
+ fi
35
+
36
+ echo ""
37
+ echo "▶️ 자동 학습 스케줄러 실행 중..."
38
+ $PYTHON_CMD start.py
39
+
start.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MuscleCare Train AI - 자동 학습 스케줄러 시작 스크립트
3
+ 매주 일요일 자정에 모델을 자동으로 학습합니다.
4
+ """
5
+
6
+ from train_scheduler import main
7
+
8
+ if __name__ == "__main__":
9
+ main()
10
+
train_e2e.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ End-to-End 모델 학습 스크립트 (TensorFlow)
3
+ 단일 윈도우(센서 특징 + user_emb)를 입력으로 받아 피로도를 예측하는
4
+ MLP 기반 회귀 모델을 학습하고 SavedModel/TFLite 형식으로 저장합니다.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Dict, Iterable, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import tensorflow as tf
14
+ from sklearn.preprocessing import StandardScaler
15
+ from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, Input
16
+ from tensorflow.keras.models import Model
17
+
18
+ from load_dataset import load_musclecare_dataset
19
+
20
+
21
+ FEATURE_COLUMNS = [
22
+ 'rms_acc',
23
+ 'rms_gyro',
24
+ 'mean_freq_acc',
25
+ 'mean_freq_gyro',
26
+ 'entropy_acc',
27
+ 'entropy_gyro',
28
+ 'jerk_mean',
29
+ 'jerk_std',
30
+ 'stability_index',
31
+ 'fatigue_prev',
32
+ ]
33
+
34
+ DEFAULT_EPOCHS = 30
35
+ DEFAULT_EMBED_DIM = 12
36
+ DEFAULT_BATCH_SIZE = 64
37
+
38
+
39
+ def parse_user_emb(emb: Union[str, Iterable[float], np.ndarray]) -> np.ndarray:
40
+ """사용자 임베딩을 numpy 배열로 변환"""
41
+ arr: Optional[np.ndarray] = None
42
+
43
+ if isinstance(emb, np.ndarray):
44
+ arr = emb.astype(np.float32)
45
+ elif isinstance(emb, str):
46
+ try:
47
+ arr = np.array(json.loads(emb), dtype=np.float32)
48
+ except (json.JSONDecodeError, TypeError):
49
+ arr = None
50
+ elif isinstance(emb, Iterable):
51
+ arr = np.array(list(emb), dtype=np.float32)
52
+
53
+ if arr is None or arr.ndim == 0:
54
+ arr = np.zeros(DEFAULT_EMBED_DIM, dtype=np.float32)
55
+
56
+ return arr
57
+
58
+
59
+ def pad_embedding(embedding: np.ndarray, target_dim: int) -> np.ndarray:
60
+ """임베딩 길이를 target_dim에 맞춰 패딩"""
61
+ padded = np.zeros(target_dim, dtype=np.float32)
62
+ length = min(target_dim, embedding.size)
63
+ padded[:length] = embedding[:length]
64
+ return padded
65
+
66
+
67
+ def dataset_split_to_dataframe(dataset_split) -> pd.DataFrame:
68
+ """HuggingFace Dataset split을 pandas DataFrame으로 변환"""
69
+ if hasattr(dataset_split, "to_pandas"):
70
+ return dataset_split.to_pandas()
71
+ return pd.DataFrame(dataset_split)
72
+
73
+
74
+ def build_dataframe_from_source(
75
+ dataset_source,
76
+ exclude_sessions: Optional[Iterable[str]] = None
77
+ ) -> pd.DataFrame:
78
+ """데이터 소스를 단일 DataFrame으로 통합"""
79
+ frames = []
80
+ exclude_sessions = set(exclude_sessions or [])
81
+
82
+ if hasattr(dataset_source, "items"):
83
+ iterator = dataset_source.items()
84
+ else:
85
+ iterator = [("all", dataset_source)]
86
+
87
+ for split_name, split_dataset in iterator:
88
+ df_split = dataset_split_to_dataframe(split_dataset)
89
+ if df_split.empty:
90
+ continue
91
+
92
+ if exclude_sessions:
93
+ if 'session_id' not in df_split.columns:
94
+ raise KeyError("데이터셋에 'session_id' 컬럼이 없습니다.")
95
+ df_split = df_split[~df_split['session_id'].isin(exclude_sessions)]
96
+
97
+ if not df_split.empty:
98
+ frames.append(df_split)
99
+ print(f" - {split_name}: {len(df_split)}개 샘플 (필터링 후)")
100
+
101
+ if not frames:
102
+ return pd.DataFrame()
103
+
104
+ return pd.concat(frames, ignore_index=True)
105
+
106
+
107
+ def prepare_training_arrays(
108
+ df: pd.DataFrame,
109
+ feature_cols: Iterable[str]
110
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
111
+ """단일 윈도우 입력을 위한 학습 데이터를 생성"""
112
+ required_columns = set(feature_cols) | {'fatigue', 'user_emb'}
113
+ missing_columns = required_columns - set(df.columns)
114
+ if missing_columns:
115
+ raise KeyError(f"데이터셋에 누락된 컬럼이 있습니다: {sorted(missing_columns)}")
116
+
117
+ feature_values = (
118
+ df[list(feature_cols)]
119
+ .astype(np.float32)
120
+ .replace([np.inf, -np.inf], np.nan)
121
+ .fillna(0.0)
122
+ )
123
+
124
+ scaler = StandardScaler()
125
+ features_scaled = scaler.fit_transform(feature_values).astype(np.float32)
126
+
127
+ user_embeddings = np.stack([
128
+ emb.astype(np.float32) if isinstance(emb, np.ndarray) else np.zeros(DEFAULT_EMBED_DIM, dtype=np.float32)
129
+ for emb in df['user_emb']
130
+ ])
131
+
132
+ X = np.concatenate([features_scaled, user_embeddings], axis=1).astype(np.float32)
133
+ y = df['fatigue'].astype(np.float32).to_numpy()
134
+
135
+ return X, y, scaler.mean_.astype(np.float32), scaler.scale_.astype(np.float32)
136
+
137
+
138
+ def build_dense_regression_model(
139
+ input_dim: int,
140
+ learning_rate: float = 0.001
141
+ ) -> Model:
142
+ """단일 윈도우 입력용 MLP 회귀 모델"""
143
+ inputs = Input(shape=(input_dim,), name="features")
144
+
145
+ x = Dense(128, activation='relu')(inputs)
146
+ x = BatchNormalization()(x)
147
+ x = Dropout(0.3)(x)
148
+
149
+ x = Dense(64, activation='relu')(x)
150
+ x = BatchNormalization()(x)
151
+ x = Dropout(0.2)(x)
152
+
153
+ outputs = Dense(1, activation='linear', name='fatigue')(x)
154
+
155
+ model = Model(inputs=inputs, outputs=outputs)
156
+ model.compile(
157
+ optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
158
+ loss='mse',
159
+ metrics=['mae']
160
+ )
161
+ return model
162
+
163
+
164
+ def ensure_embeddings(df: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
165
+ """user_emb 컬럼을 numpy 배열로 정규화하고 통일된 차원으로 패딩"""
166
+ if 'user_emb' not in df.columns:
167
+ raise KeyError("데이터셋에 'user_emb' 컬럼이 없습니다.")
168
+
169
+ df = df.copy()
170
+ df['user_emb'] = df['user_emb'].apply(parse_user_emb)
171
+
172
+ dims = [
173
+ emb.size for emb in df['user_emb']
174
+ if isinstance(emb, np.ndarray) and emb.size > 0
175
+ ]
176
+ target_dim = max(dims) if dims else DEFAULT_EMBED_DIM
177
+
178
+ df['user_emb'] = df['user_emb'].apply(lambda emb: pad_embedding(emb, target_dim))
179
+
180
+ return df, target_dim
181
+
182
+
183
+ def main(
184
+ data_list: Optional[Iterable[Dict]] = None,
185
+ exclude_sessions: Optional[Iterable[str]] = None,
186
+ epochs: int = DEFAULT_EPOCHS
187
+ ) -> Optional[Dict[str, str]]:
188
+ """
189
+ 메인 학습 함수
190
+
191
+ Args:
192
+ data_list: 사용할 데이터 리스트 (None이면 전체 데이터 사용)
193
+ exclude_sessions: 제외할 session_id 집합 (중복 방지용)
194
+ epochs: 학습 에포크 수
195
+ """
196
+ print("=" * 80)
197
+ print("MuscleCare Train AI - TensorFlow Single-Window Training")
198
+ print("=" * 80)
199
+
200
+ tf.keras.utils.set_random_seed(42)
201
+
202
+ # 1️⃣ 데이터 로드
203
+ print("1️⃣ 데이터셋 로딩 중...")
204
+ if data_list is None:
205
+ dataset_source = load_musclecare_dataset()
206
+ df = build_dataframe_from_source(dataset_source, exclude_sessions)
207
+ else:
208
+ df = pd.DataFrame(data_list)
209
+ if exclude_sessions:
210
+ df = df[~df['session_id'].isin(set(exclude_sessions))]
211
+
212
+ if df.empty:
213
+ print("⚠️ 학습 가능한 데이터가 없습니다. 학습을 종료합니다.")
214
+ print("=" * 80)
215
+ return None
216
+
217
+ print(f"✅ 데이터 로드 완료: {len(df)}개 행")
218
+
219
+ # 2️⃣ 사용자 임베딩 정규화
220
+ print("2️⃣ 사용자 임베딩 정규화 중...")
221
+ df, emb_dim = ensure_embeddings(df)
222
+ print(f"✅ 임베딩 차원: {emb_dim}")
223
+
224
+ # 3️⃣ 학습 데이터 생성
225
+ print("3️⃣ 학습 데이터 생성 중...")
226
+ X, y, scaler_mean, scaler_scale = prepare_training_arrays(df, FEATURE_COLUMNS)
227
+ if X.size == 0:
228
+ print("⚠️ 학습할 입력 데이터가 없습니다. 학습을 종료합니다.")
229
+ print("=" * 80)
230
+ return None
231
+ num_samples, input_dim = X.shape
232
+ print(f"✅ 학습 데이터 생성 완료: {num_samples}개 샘플, 입력 차원 {input_dim}")
233
+
234
+ # 4️⃣ 모델 생성
235
+ print("4️⃣ 모델 생성 중...")
236
+ model = build_dense_regression_model(input_dim)
237
+ model.summary(print_fn=lambda x: print(" " + x))
238
+ print("✅ 모델 생성 완료")
239
+
240
+ # 5️⃣ 모델 학습
241
+ print("5️⃣ 모델 학습 시작...")
242
+ callbacks = [
243
+ tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
244
+ tf.keras.callbacks.ReduceLROnPlateau(patience=3, factor=0.5),
245
+ ]
246
+ validation_split = 0.1 if num_samples >= 20 else 0.0
247
+ history = model.fit(
248
+ X,
249
+ y,
250
+ epochs=epochs,
251
+ batch_size=min(DEFAULT_BATCH_SIZE, num_samples),
252
+ shuffle=True,
253
+ validation_split=validation_split,
254
+ callbacks=callbacks,
255
+ verbose=1,
256
+ )
257
+ print("✅ 모델 학습 완료")
258
+
259
+ # 6️⃣ 모델 및 메타데이터 저장
260
+ print("6️⃣ 모델 저장 중...")
261
+ model_dir = './model'
262
+ os.makedirs(model_dir, exist_ok=True)
263
+
264
+ keras_model_path = os.path.join(model_dir, 'cnn_gru_fatigue.keras')
265
+ model.save(keras_model_path)
266
+
267
+ metadata = {
268
+ "feature_columns": list(FEATURE_COLUMNS),
269
+ "embedding_dim": emb_dim,
270
+ "input_dim": input_dim,
271
+ "epochs": epochs,
272
+ "num_samples": int(num_samples),
273
+ "scaler": {
274
+ "mean": scaler_mean.tolist(),
275
+ "scale": scaler_scale.tolist(),
276
+ },
277
+ "history": {
278
+ "loss": history.history.get('loss', []),
279
+ "mae": history.history.get('mae', []),
280
+ "val_loss": history.history.get('val_loss', []),
281
+ "val_mae": history.history.get('val_mae', []),
282
+ },
283
+ }
284
+
285
+ metadata_path = os.path.join(model_dir, 'cnn_gru_fatigue_metadata.json')
286
+ with open(metadata_path, 'w', encoding='utf-8') as f:
287
+ json.dump(metadata, f, ensure_ascii=False, indent=2)
288
+
289
+ print(f"✅ 모델 저장 완료: {keras_model_path}")
290
+ print(f" 메타데이터 저장: {metadata_path}")
291
+
292
+ # 7️⃣ TFLite 변환
293
+ print("7️⃣ TFLite 변환 중...")
294
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
295
+ converter.target_spec.supported_ops = [
296
+ tf.lite.OpsSet.TFLITE_BUILTINS,
297
+ tf.lite.OpsSet.SELECT_TF_OPS,
298
+ ]
299
+ converter._experimental_lower_tensor_list_ops = False
300
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
301
+ tflite_model = converter.convert()
302
+
303
+ tflite_model_path = os.path.join(model_dir, 'cnn_gru_fatigue.tflite')
304
+ with open(tflite_model_path, 'wb') as f:
305
+ f.write(tflite_model)
306
+
307
+ print(f"✅ TFLite 모델 저장 완료: {tflite_model_path}")
308
+ print("=" * 80)
309
+
310
+ return {
311
+ "keras": os.path.abspath(keras_model_path),
312
+ "tflite": os.path.abspath(tflite_model_path),
313
+ "metadata": os.path.abspath(metadata_path),
314
+ }
315
+
316
+
317
+ if __name__ == "__main__":
318
+ main()
319
+
train_scheduler.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 주 1회 자동 모델 학습 스케줄러
3
+ 매주 일요일 자정에 실행되어 모델을 자동으로 업데이트합니다.
4
+ """
5
+
6
+ import schedule
7
+ import time
8
+ import os
9
+ import json
10
+ import shutil
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Dict, Optional
14
+
15
+ from huggingface_hub import HfApi, hf_hub_download
16
+ try:
17
+ from huggingface_hub.utils import HfHubHTTPError
18
+ except ImportError: # fallback for older versions
19
+ HfHubHTTPError = Exception # type: ignore
20
+
21
+ from train_e2e import main as train_main
22
+ from load_dataset import load_musclecare_dataset
23
+
24
+
25
+ class TrainingScheduler:
26
+ """모델 학습 스케줄러 클래스"""
27
+
28
+ def __init__(self, state_file: str = './model/training_state.json'):
29
+ """
30
+ Args:
31
+ state_file: 학습 상태를 저장할 파일 경로
32
+ """
33
+ self.state_file = state_file
34
+ self.state_dir = os.path.dirname(state_file)
35
+ os.makedirs(self.state_dir, exist_ok=True)
36
+ self._hf_token = os.getenv("HF_E2E_MODEL_TOKEN")
37
+ self._hf_repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
38
+ self._hf_state_filename = os.getenv("HF_E2E_MODEL_STATE_FILE", Path(state_file).name)
39
+
40
+ if not os.path.exists(self.state_file):
41
+ self._download_state_from_hub()
42
+
43
+ def load_training_state(self):
44
+ """학습 상태 로드"""
45
+ if os.path.exists(self.state_file):
46
+ try:
47
+ with open(self.state_file, 'r', encoding='utf-8') as f:
48
+ state = json.load(f)
49
+ return state
50
+ except Exception as e:
51
+ print(f"⚠️ 학습 상태 로드 실패: {e}")
52
+ return self._get_default_state()
53
+ if self._download_state_from_hub():
54
+ return self.load_training_state()
55
+ return self._get_default_state()
56
+
57
+ def save_training_state(self, state):
58
+ """학습 상태 저장"""
59
+ try:
60
+ with open(self.state_file, 'w', encoding='utf-8') as f:
61
+ json.dump(state, f, indent=2, ensure_ascii=False)
62
+ self._upload_state_to_hub()
63
+ except Exception as e:
64
+ print(f"⚠️ 학습 상태 저장 실패: {e}")
65
+
66
+ def _get_default_state(self):
67
+ """기본 학습 상태"""
68
+ return {
69
+ 'processed_sessions': [],
70
+ 'last_training_date': None,
71
+ 'model_version': 0,
72
+ 'total_data_count': 0
73
+ }
74
+
75
+ def reset_training_state(self):
76
+ """학습 상태 초기화"""
77
+ state = self._get_default_state()
78
+ self.save_training_state(state)
79
+ return state
80
+
81
+ def get_new_data(self, processed_sessions):
82
+ """
83
+ 새로운 데이터만 수집 (중복 방지)
84
+
85
+ Args:
86
+ processed_sessions: 이미 처리된 session_id 집합
87
+
88
+ Returns:
89
+ list: 새로운 데이터 리스트
90
+ """
91
+ print("📊 새로운 데이터 수집 중...")
92
+ dataset_dict = load_musclecare_dataset()
93
+
94
+ new_data = []
95
+ new_sessions = set()
96
+
97
+ for split_name in dataset_dict.keys():
98
+ for item in dataset_dict[split_name]:
99
+ session_id = item.get('session_id', '')
100
+
101
+ # 중복 체크
102
+ if session_id not in processed_sessions:
103
+ new_data.append(item)
104
+ new_sessions.add(session_id)
105
+
106
+ print(f"✅ 새로운 데이터: {len(new_data)}개 (새로운 세션: {len(new_sessions)}개)")
107
+ return new_data, new_sessions
108
+
109
+ def train_incremental_model(self, new_data, processed_sessions):
110
+ """
111
+ 증분 학습 수행 (전체 데이터로 재학습하되 중복 제외)
112
+
113
+ Args:
114
+ new_data: 새로운 데이터 리스트
115
+ processed_sessions: 이미 처리된 session_id 집합
116
+ """
117
+ if not new_data:
118
+ print("⚠️ 새로운 데이터가 없어 학습을 건너뜁니다.")
119
+ return None
120
+
121
+ print(f"\n🔄 모델 학습 시작 (새로운 데이터: {len(new_data)}개 포함)...")
122
+
123
+ # 전체 데이터를 가져오되, 중복된 세션은 제외
124
+ # train_e2e.py의 main 함수에 exclude_sessions 파라미터 전달
125
+ from train_e2e import main as train_main
126
+ training_outputs = train_main(data_list=None, exclude_sessions=processed_sessions)
127
+
128
+ if isinstance(training_outputs, dict):
129
+ return (
130
+ training_outputs.get('tflite')
131
+ or training_outputs.get('keras')
132
+ or training_outputs.get('metadata')
133
+ )
134
+
135
+ return training_outputs
136
+
137
+ def run_scheduled_training(self) -> Dict[str, Optional[str]]:
138
+ """스케줄된 학습 실행"""
139
+ print("=" * 80)
140
+ print(f"🕛 자동 학습 시작 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
141
+ print("=" * 80)
142
+
143
+ # 학습 상태 로드
144
+ state = self.load_training_state()
145
+ processed_sessions = set(state.get('processed_sessions', []))
146
+
147
+ print(f"📋 현재 상태:")
148
+ print(f" - 처리된 세션 수: {len(processed_sessions)}")
149
+ print(f" - 마지막 학습일: {state.get('last_training_date', '없음')}")
150
+ print(f" - 모델 버전: {state.get('model_version', 0)}")
151
+
152
+ # 새로운 데이터 수집
153
+ new_data, new_sessions = self.get_new_data(processed_sessions)
154
+
155
+ result: Dict[str, Optional[str]] = {
156
+ "status": "skipped",
157
+ "model_path": None,
158
+ "new_data_count": len(new_data),
159
+ }
160
+
161
+ if new_data:
162
+ # 증분 학습 수행 (전체 데이터로 재학습하되 중복 제외)
163
+ model_path = self.train_incremental_model(
164
+ new_data,
165
+ processed_sessions
166
+ )
167
+
168
+ if model_path:
169
+ # 학습 상태 업데이트
170
+ processed_sessions.update(new_sessions)
171
+ state['processed_sessions'] = list(processed_sessions)
172
+ state['last_training_date'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
173
+ new_version = state.get('model_version', 0) + 1
174
+ state['model_version'] = new_version
175
+ state['total_data_count'] = state.get('total_data_count', 0) + len(new_data)
176
+
177
+ self.save_training_state(state)
178
+
179
+ print("\n✅ 자동 학습 완료!")
180
+ print(f" - 모델 경로: {model_path}")
181
+ print(f" - 새로운 모델 버전: {state['model_version']}")
182
+ print(f" - 총 처리된 데이터: {state['total_data_count']}개")
183
+ result.update({
184
+ "status": "trained",
185
+ "model_path": model_path,
186
+ "new_data_count": len(new_data),
187
+ "model_version": str(state['model_version']),
188
+ })
189
+
190
+ else:
191
+ print("\n⚠️ 새로운 데이터가 없어 학습을 건너뜁니다.")
192
+
193
+ print("=" * 80)
194
+ return result
195
+
196
+ def _get_hf_api(self) -> Optional[HfApi]:
197
+ if not self._hf_repo_id or not self._hf_token:
198
+ return None
199
+ return HfApi(token=self._hf_token)
200
+
201
+ def _download_state_from_hub(self) -> bool:
202
+ api = self._get_hf_api()
203
+ if api is None:
204
+ return False
205
+ try:
206
+ downloaded_path = hf_hub_download(
207
+ repo_id=self._hf_repo_id,
208
+ filename=self._hf_state_filename,
209
+ repo_type="model",
210
+ token=self._hf_token,
211
+ local_dir=self.state_dir,
212
+ local_dir_use_symlinks=False,
213
+ )
214
+ shutil.move(downloaded_path, self.state_file)
215
+ print(f"✅ Hugging Face Hub에서 학습 상태를 다운로드했습니다: {self._hf_state_filename}")
216
+ return True
217
+ except Exception as e:
218
+ status = getattr(getattr(e, "response", None), "status_code", None)
219
+ if status == 404:
220
+ print("ℹ️ Hugging Face Hub에 학습 상태 파일이 없어 새로 생성합니다.")
221
+ else:
222
+ print(f"⚠️ 학습 상태 다운로드 중 오류가 발생했습니다: {e}")
223
+ return False
224
+
225
+ def _upload_state_to_hub(self) -> None:
226
+ api = self._get_hf_api()
227
+ if api is None:
228
+ return
229
+ try:
230
+ api.create_repo(repo_id=self._hf_repo_id, repo_type="model", private=False, exist_ok=True)
231
+ api.upload_file(
232
+ path_or_fileobj=self.state_file,
233
+ path_in_repo=self._hf_state_filename,
234
+ repo_id=self._hf_repo_id,
235
+ repo_type="model",
236
+ commit_message="Update training state",
237
+ )
238
+ print("✅ 학습 상태를 Hugging Face Hub에 업로드했습니다.")
239
+ except Exception as e:
240
+ print(f"⚠️ 학습 상태 업로드 실패: {e}")
241
+
242
+
243
+ def main():
244
+ """메인 함수"""
245
+ scheduler = TrainingScheduler()
246
+
247
+ # 매주 일요일 자정에 실행
248
+ schedule.every().day.at("00:00").do(scheduler.run_scheduled_training)
249
+
250
+ print("📅 자동 학습 스케줄러 시작")
251
+ print(" - 실행 시간: 매일 00:00")
252
+ print(" - 종료하려면 Ctrl+C를 누르세요\n")
253
+
254
+ # 스케줄러 실행
255
+ try:
256
+ while True:
257
+ schedule.run_pending()
258
+ time.sleep(60) # 1분마다 체크
259
+ except KeyboardInterrupt:
260
+ print("\n\n⏹️ 스케줄러 종료")
261
+
262
+
263
+ if __name__ == "__main__":
264
+ main()
265
+