Spaces:
Sleeping
Sleeping
Commit ·
2b83ee8
0
Parent(s):
Spaces용 코드만 포함 (모델 파일 제외)
Browse files- .dockerignore +22 -0
- .gitattributes +35 -0
- .gitignore +62 -0
- Dockerfile +19 -0
- README.md +92 -0
- app.py +262 -0
- convert_tflite.py +336 -0
- load_dataset.py +38 -0
- model.md +127 -0
- requirements.txt +22 -0
- run_local.sh +39 -0
- start.py +10 -0
- train_e2e.py +319 -0
- train_scheduler.py +265 -0
.dockerignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
.Python
|
| 6 |
+
*.so
|
| 7 |
+
*.egg
|
| 8 |
+
*.egg-info
|
| 9 |
+
dist
|
| 10 |
+
build
|
| 11 |
+
.git
|
| 12 |
+
.gitignore
|
| 13 |
+
.env
|
| 14 |
+
.venv
|
| 15 |
+
venv/
|
| 16 |
+
ENV/
|
| 17 |
+
env/
|
| 18 |
+
.vscode
|
| 19 |
+
.idea
|
| 20 |
+
*.md
|
| 21 |
+
!README.md
|
| 22 |
+
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
*.so
|
| 5 |
+
.Python
|
| 6 |
+
build/
|
| 7 |
+
develop-eggs/
|
| 8 |
+
dist/
|
| 9 |
+
downloads/
|
| 10 |
+
eggs/
|
| 11 |
+
.eggs/
|
| 12 |
+
lib/
|
| 13 |
+
lib64/
|
| 14 |
+
parts/
|
| 15 |
+
sdist/
|
| 16 |
+
var/
|
| 17 |
+
wheels/
|
| 18 |
+
*.egg-info/
|
| 19 |
+
.installed.cfg
|
| 20 |
+
*.egg
|
| 21 |
+
MANIFEST
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.env
|
| 25 |
+
.venv
|
| 26 |
+
env/
|
| 27 |
+
venv/
|
| 28 |
+
ENV/
|
| 29 |
+
env.bak/
|
| 30 |
+
venv.bak/
|
| 31 |
+
|
| 32 |
+
# IDE
|
| 33 |
+
.vscode/
|
| 34 |
+
.idea/
|
| 35 |
+
*.swp
|
| 36 |
+
*.swo
|
| 37 |
+
*~
|
| 38 |
+
|
| 39 |
+
# OS
|
| 40 |
+
.DS_Store
|
| 41 |
+
Thumbs.db
|
| 42 |
+
|
| 43 |
+
# Logs
|
| 44 |
+
*.log
|
| 45 |
+
|
| 46 |
+
# Model files (큰 파일이므로 Git에서 제외)
|
| 47 |
+
*.pth
|
| 48 |
+
*.pt
|
| 49 |
+
*.ckpt
|
| 50 |
+
*.bin
|
| 51 |
+
*.safetensors
|
| 52 |
+
*.tflite
|
| 53 |
+
*.keras
|
| 54 |
+
*.h5
|
| 55 |
+
*.pb
|
| 56 |
+
*.onnx
|
| 57 |
+
*.pkl
|
| 58 |
+
*.pickle
|
| 59 |
+
|
| 60 |
+
# Model directory
|
| 61 |
+
model/
|
| 62 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1
|
| 7 |
+
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends build-essential && \
|
| 10 |
+
pip install --upgrade pip && \
|
| 11 |
+
pip install -r requirements.txt && \
|
| 12 |
+
apt-get purge -y build-essential && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
|
| 18 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 19 |
+
|
README.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MuscleCare Train AI
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# MuscleCare Train AI
|
| 12 |
+
|
| 13 |
+
CNN + GRU 기반 근육 피로도 예측 모델 자동 학습 시스템
|
| 14 |
+
|
| 15 |
+
## 🚀 주요 기능
|
| 16 |
+
|
| 17 |
+
- **자동 데이터 로딩**: Hugging Face `Merry99/MuscleCare-DataSet` 데이터셋 자동 로드
|
| 18 |
+
- **CNN + GRU 모델**: 시퀀스 데이터에서 피로도 예측
|
| 19 |
+
- **자동 학습 스케줄링**: 매주 일요일 자정 자동 모델 업데이트
|
| 20 |
+
- **중복 방지**: 이미 학습된 세션 데이터 자동 제외
|
| 21 |
+
- **TFLite 변환**: 모바일 배포를 위한 TFLite 모델 자동 생성 (필수)
|
| 22 |
+
|
| 23 |
+
## 📦 실행 방법
|
| 24 |
+
|
| 25 |
+
### Docker 사용 (권장)
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
# 이미지 빌드
|
| 29 |
+
docker build -t musclecare-train-ai .
|
| 30 |
+
|
| 31 |
+
# 실행
|
| 32 |
+
docker run musclecare-train-ai
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### 로컬 실행 (Python 3.10 필요)
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
# Python 3.10 확인
|
| 39 |
+
python3.10 --version
|
| 40 |
+
|
| 41 |
+
# 패키지 설치
|
| 42 |
+
python3.10 -m pip install -r requirements.txt
|
| 43 |
+
|
| 44 |
+
# 실행
|
| 45 |
+
python3.10 start.py
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
또는 스크립트 사용:
|
| 49 |
+
```bash
|
| 50 |
+
./run_local.sh
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## 🔄 전체 플로우
|
| 54 |
+
|
| 55 |
+
1. **데이터 로드**: `load_dataset.py`로 Hugging Face 데이터셋 로드
|
| 56 |
+
2. **모델 학습**: `train_e2e.py`로 CNN + GRU 모델 학습
|
| 57 |
+
3. **모델 저장**: 학습된 모델을 `./model/fatigue_net_v2.pt`에 저장 (PyTorch state_dict 형식)
|
| 58 |
+
4. **TFLite 변환**: `convert_tflite.py`로 TFLite 모델 생성 → `./model/fatigue_net_v2.tflite`
|
| 59 |
+
|
| 60 |
+
## 📁 파일 구조
|
| 61 |
+
|
| 62 |
+
- `load_dataset.py`: Hugging Face 데이터셋 로드
|
| 63 |
+
- `train_e2e.py`: CNN + GRU 모델 학습 (PyTorch state_dict 형식으로 저장)
|
| 64 |
+
- `convert_tflite.py`: PyTorch → TFLite 변환
|
| 65 |
+
- `train_scheduler.py`: 자동 학습 스케줄러
|
| 66 |
+
- `start.py`: 자동 학습 스케줄러 시작 스크립트
|
| 67 |
+
- `app.py`: FastAPI 애플리케이션 (나중에 구현 예정)
|
| 68 |
+
|
| 69 |
+
## 🔧 요구사항
|
| 70 |
+
|
| 71 |
+
- Python 3.10 (TFLite 변환 필수)
|
| 72 |
+
- PyTorch 2.0+
|
| 73 |
+
- ONNX, ONNX-TF, TensorFlow (TFLite 변환용)
|
| 74 |
+
|
| 75 |
+
## 📝 모델 저장 위치
|
| 76 |
+
|
| 77 |
+
- PyTorch 모델: `./model/fatigue_net_v2.pt` (state_dict 형식)
|
| 78 |
+
- **TFLite 모델: `./model/fatigue_net_v2.tflite`** (모바일 배포용, 필수)
|
| 79 |
+
- 학습 상태: `./model/training_state.json`
|
| 80 |
+
|
| 81 |
+
## ⚠️ 중요 사항
|
| 82 |
+
|
| 83 |
+
- **TFLite 변환은 필수입니다** (모바일 디바이스에서 실행 필요)
|
| 84 |
+
- 모델은 반드시 PyTorch state_dict 형식으로 저장되어야 합니다 (TorchScript 형식 불가)
|
| 85 |
+
- Python 3.10 이상이 필요합니다 (TFLite 변환 패키지 호환성)
|
| 86 |
+
|
| 87 |
+
## 🔄 자동 학습 스케줄
|
| 88 |
+
|
| 89 |
+
- 실행 시간: 매주 일요일 자정 (00:00)
|
| 90 |
+
- 중복 방지: `training_state.json`에 저장된 세션 ID는 자동 제외
|
| 91 |
+
- 모델 버전: 자동 증가
|
| 92 |
+
- TFLite 변환: 학습 후 자동 수행
|
app.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI 앱: 수동 학습 및 Hugging Face 업로드 트리거"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, Optional
|
| 11 |
+
|
| 12 |
+
import schedule
|
| 13 |
+
from fastapi import FastAPI, HTTPException
|
| 14 |
+
from fastapi.responses import FileResponse
|
| 15 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 16 |
+
try:
|
| 17 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 18 |
+
except ImportError: # pragma: no cover
|
| 19 |
+
HfHubHTTPError = Exception # type: ignore
|
| 20 |
+
from pydantic import BaseModel
|
| 21 |
+
|
| 22 |
+
from train_scheduler import TrainingScheduler
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
app = FastAPI(
|
| 26 |
+
title="MuscleCare Train Scheduler API",
|
| 27 |
+
description="수동으로 모델 학습 및 Hugging Face 업로드를 트리거합니다.",
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
_scheduler = TrainingScheduler()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TrainResponse(BaseModel):
|
| 34 |
+
status: str
|
| 35 |
+
new_data_count: int
|
| 36 |
+
model_path: Optional[str] = None
|
| 37 |
+
hub_url: Optional[str] = None
|
| 38 |
+
model_version: Optional[int] = None
|
| 39 |
+
message: str
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@app.on_event("startup")
|
| 43 |
+
def startup_training() -> None:
|
| 44 |
+
"""서버 시작 시 자동으로 모델 학습을 실행합니다."""
|
| 45 |
+
try:
|
| 46 |
+
print("🚀 서버 시작: 자동 모델 학습을 시작합니다...")
|
| 47 |
+
result = _scheduler.run_scheduled_training()
|
| 48 |
+
if result["status"] == "trained":
|
| 49 |
+
print(f"✅ 서버 시작 시 학습 완료: {result['new_data_count']}개 데이터로 학습됨")
|
| 50 |
+
else:
|
| 51 |
+
print(f"ℹ️ 서버 시작 시 학습 건너뜀: {result.get('message', '새로운 데이터 없음')}")
|
| 52 |
+
except Exception as exc:
|
| 53 |
+
print(f"⚠️ 서버 시작 시 학습 실패: {exc}")
|
| 54 |
+
|
| 55 |
+
# 기존 스케줄링 설정
|
| 56 |
+
schedule.clear()
|
| 57 |
+
schedule.every().sunday.at("00:00").do(_scheduler.run_scheduled_training)
|
| 58 |
+
|
| 59 |
+
def _run_schedule() -> None:
|
| 60 |
+
while True:
|
| 61 |
+
schedule.run_pending()
|
| 62 |
+
time.sleep(60)
|
| 63 |
+
|
| 64 |
+
threading.Thread(target=_run_schedule, daemon=True).start()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@app.get("/health")
|
| 68 |
+
def health_check() -> dict:
|
| 69 |
+
return {"status": "ok"}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@app.get("/")
|
| 73 |
+
def root() -> dict:
|
| 74 |
+
return {
|
| 75 |
+
"message": "MuscleCare Train Scheduler API가 실행 중입니다.",
|
| 76 |
+
"endpoints": {
|
| 77 |
+
"health": "/health",
|
| 78 |
+
"trigger": "/trigger",
|
| 79 |
+
},
|
| 80 |
+
"docs": "/docs",
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _upload_to_hub(model_path: str) -> Optional[str]:
|
| 85 |
+
token = os.getenv("HF_E2E_MODEL_TOKEN")
|
| 86 |
+
repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
|
| 87 |
+
|
| 88 |
+
if not token or not repo_id:
|
| 89 |
+
raise HTTPException(
|
| 90 |
+
status_code=400,
|
| 91 |
+
detail="환경 변수 HF_E2E_MODEL_TOKEN / HF_E2E_MODEL_REPO_ID가 설정되어 있지 않습니다.",
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
path = Path(model_path)
|
| 95 |
+
if not path.exists():
|
| 96 |
+
raise HTTPException(status_code=404, detail=f"모델 파일을 찾을 수 없습니다: {model_path}")
|
| 97 |
+
|
| 98 |
+
api = HfApi(token=token)
|
| 99 |
+
api.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
|
| 100 |
+
|
| 101 |
+
api.upload_file(
|
| 102 |
+
path_or_fileobj=path,
|
| 103 |
+
path_in_repo=path.name,
|
| 104 |
+
repo_id=repo_id,
|
| 105 |
+
repo_type="model",
|
| 106 |
+
commit_message="Manual scheduler trigger upload",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return f"https://huggingface.co/{repo_id}"
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# TODO: include version info in response body
|
| 113 |
+
@app.get("/model")
|
| 114 |
+
@app.get("/model/{version:int}")
|
| 115 |
+
def download_model(
|
| 116 |
+
version: Optional[int] = None,
|
| 117 |
+
filename: Optional[str] = None
|
| 118 |
+
) -> FileResponse:
|
| 119 |
+
repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
|
| 120 |
+
token = os.getenv("HF_E2E_MODEL_TOKEN")
|
| 121 |
+
default_filename = os.getenv("HF_E2E_MODEL_FILE", "cnn_gru_fatigue.tflite")
|
| 122 |
+
|
| 123 |
+
if not repo_id:
|
| 124 |
+
raise HTTPException(
|
| 125 |
+
status_code=400,
|
| 126 |
+
detail="환경 변수 HF_E2E_MODEL_REPO_ID가 설정되어 있지 않습니다."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
current_state = _scheduler.load_training_state()
|
| 130 |
+
current_version = int(current_state.get("model_version", 0) or 0)
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
if not version:
|
| 134 |
+
target_filename = filename or default_filename
|
| 135 |
+
local_path = hf_hub_download(
|
| 136 |
+
repo_id=repo_id,
|
| 137 |
+
filename=target_filename,
|
| 138 |
+
repo_type="model",
|
| 139 |
+
token=token,
|
| 140 |
+
local_dir="./model_cache",
|
| 141 |
+
local_dir_use_symlinks=False,
|
| 142 |
+
)
|
| 143 |
+
actual_version = current_version
|
| 144 |
+
else:
|
| 145 |
+
if version > current_version:
|
| 146 |
+
raise HTTPException(
|
| 147 |
+
status_code=404,
|
| 148 |
+
detail=f"현재 모델 버전은 {current_version}입니다. 버전 {version}은 존재하지 않습니다."
|
| 149 |
+
)
|
| 150 |
+
manifest_path = hf_hub_download(
|
| 151 |
+
repo_id=repo_id,
|
| 152 |
+
filename="model_versions.json",
|
| 153 |
+
repo_type="model",
|
| 154 |
+
token=token,
|
| 155 |
+
local_dir="./model_cache",
|
| 156 |
+
local_dir_use_symlinks=False,
|
| 157 |
+
)
|
| 158 |
+
with open(manifest_path, "r", encoding="utf-8") as f:
|
| 159 |
+
manifest = json.load(f)
|
| 160 |
+
|
| 161 |
+
version_entry = next(
|
| 162 |
+
(entry for entry in manifest if entry.get("version") == version),
|
| 163 |
+
None
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
if version_entry is None:
|
| 167 |
+
raise HTTPException(
|
| 168 |
+
status_code=404,
|
| 169 |
+
detail=f"버전 {version}에 해당하는 모델을 찾을 수 없습니다."
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
target_filename = filename or version_entry.get("filename")
|
| 173 |
+
target_revision = version_entry.get("commit")
|
| 174 |
+
|
| 175 |
+
if not target_filename or not target_revision:
|
| 176 |
+
raise HTTPException(
|
| 177 |
+
status_code=500,
|
| 178 |
+
detail=f"버전 {version} 메타데이터가 올바르지 않습니다."
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
local_path = hf_hub_download(
|
| 182 |
+
repo_id=repo_id,
|
| 183 |
+
filename=target_filename,
|
| 184 |
+
repo_type="model",
|
| 185 |
+
token=token,
|
| 186 |
+
local_dir="./model_cache",
|
| 187 |
+
local_dir_use_symlinks=False,
|
| 188 |
+
revision=target_revision,
|
| 189 |
+
)
|
| 190 |
+
actual_version = version
|
| 191 |
+
except Exception as exc:
|
| 192 |
+
status = getattr(getattr(exc, "response", None), "status_code", None)
|
| 193 |
+
if status == 404:
|
| 194 |
+
raise HTTPException(
|
| 195 |
+
status_code=404,
|
| 196 |
+
detail="허깅페이스에서 지정한 모델 파일을 찾을 수 없습니다."
|
| 197 |
+
) from exc
|
| 198 |
+
raise HTTPException(
|
| 199 |
+
status_code=500,
|
| 200 |
+
detail=f"Hugging Face Hub 다운로드 실패: {exc}"
|
| 201 |
+
) from exc
|
| 202 |
+
|
| 203 |
+
response = FileResponse(
|
| 204 |
+
path=local_path,
|
| 205 |
+
filename=Path(target_filename).name,
|
| 206 |
+
media_type="application/octet-stream"
|
| 207 |
+
)
|
| 208 |
+
response.headers["X-Model-Version"] = str(actual_version)
|
| 209 |
+
response.headers["X-Model-Filename"] = Path(target_filename).name
|
| 210 |
+
return response
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class ResetStateResponse(BaseModel):
|
| 214 |
+
status: str
|
| 215 |
+
state: Dict[str, Any]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@app.post("/state/reset", response_model=ResetStateResponse)
|
| 219 |
+
def reset_training_state() -> ResetStateResponse:
|
| 220 |
+
try:
|
| 221 |
+
state = _scheduler.reset_training_state()
|
| 222 |
+
return ResetStateResponse(
|
| 223 |
+
status="reset",
|
| 224 |
+
state=state,
|
| 225 |
+
)
|
| 226 |
+
except Exception as exc: # pylint: disable=broad-except
|
| 227 |
+
raise HTTPException(status_code=500, detail=f"학습 상태 초기화에 실패했습니다: {exc}") from exc
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@app.post("/trigger", response_model=TrainResponse)
|
| 231 |
+
def trigger_training(upload: bool = True) -> TrainResponse:
|
| 232 |
+
try:
|
| 233 |
+
result = _scheduler.run_scheduled_training()
|
| 234 |
+
except Exception as exc: # pylint: disable=broad-except
|
| 235 |
+
raise HTTPException(status_code=500, detail=f"학습 실행 중 오류가 발생했습니다: {exc}") from exc
|
| 236 |
+
|
| 237 |
+
message = "새로운 데이터가 없어 학습을 건너뜁니다."
|
| 238 |
+
hub_url = None
|
| 239 |
+
|
| 240 |
+
if result["status"] == "trained":
|
| 241 |
+
message = "모델 학습이 완료되었습니다."
|
| 242 |
+
model_path = result.get("model_path")
|
| 243 |
+
|
| 244 |
+
if upload and model_path:
|
| 245 |
+
try:
|
| 246 |
+
hub_url = _upload_to_hub(model_path)
|
| 247 |
+
message = "모델 학습 및 업로드가 완료되었습니다."
|
| 248 |
+
except HTTPException:
|
| 249 |
+
raise
|
| 250 |
+
except Exception as exc: # pylint: disable=broad-except
|
| 251 |
+
raise HTTPException(status_code=500, detail=f"Hugging Face 업로드 실패: {exc}") from exc
|
| 252 |
+
|
| 253 |
+
return TrainResponse(
|
| 254 |
+
status=result["status"],
|
| 255 |
+
new_data_count=result["new_data_count"],
|
| 256 |
+
model_path=result.get("model_path"),
|
| 257 |
+
hub_url=hub_url,
|
| 258 |
+
message=message,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
__all__ = ["app"]
|
convert_tflite.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch 모델을 TensorFlow Lite 형식으로 변환하는 스크립트
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import numpy as np
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# 선택적 임포트
|
| 11 |
+
ONNX_AVAILABLE = False
|
| 12 |
+
TF_AVAILABLE = False
|
| 13 |
+
ONNX_TF_AVAILABLE = False
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import onnx
|
| 17 |
+
ONNX_AVAILABLE = True
|
| 18 |
+
except (ImportError, SyntaxError, Exception) as e:
|
| 19 |
+
ONNX_AVAILABLE = False
|
| 20 |
+
if not isinstance(e, ImportError):
|
| 21 |
+
print(f"⚠️ onnx 패키지 로드 중 오류 발생: {type(e).__name__}")
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
TF_AVAILABLE = True
|
| 26 |
+
except (ImportError, SyntaxError, Exception) as e:
|
| 27 |
+
TF_AVAILABLE = False
|
| 28 |
+
if not isinstance(e, ImportError):
|
| 29 |
+
print(f"⚠️ tensorflow 패키지 로드 중 오류 발생: {type(e).__name__}")
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
# onnx-tf는 실제로 사용할 때 임포트하도록 변경
|
| 33 |
+
# from onnx_tf.backend import prepare
|
| 34 |
+
ONNX_TF_AVAILABLE = True
|
| 35 |
+
except (ImportError, SyntaxError, Exception) as e:
|
| 36 |
+
ONNX_TF_AVAILABLE = False
|
| 37 |
+
if not isinstance(e, ImportError):
|
| 38 |
+
print(f"⚠️ onnx-tf 패키지 로드 중 오류 발생: {type(e).__name__}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class FatigueNet(nn.Module):
|
| 42 |
+
"""CNN + GRU 기반 피로도 예측 모델 (PyTorch 버전)"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, input_dim=2, hidden_dim=64, num_layers=2, output_dim=1):
|
| 45 |
+
super(FatigueNet, self).__init__()
|
| 46 |
+
|
| 47 |
+
# CNN 부분
|
| 48 |
+
self.conv1 = nn.Conv1d(
|
| 49 |
+
in_channels=input_dim,
|
| 50 |
+
out_channels=32,
|
| 51 |
+
kernel_size=1,
|
| 52 |
+
padding=0
|
| 53 |
+
)
|
| 54 |
+
self.conv2 = nn.Conv1d(
|
| 55 |
+
in_channels=32,
|
| 56 |
+
out_channels=64,
|
| 57 |
+
kernel_size=1,
|
| 58 |
+
padding=0
|
| 59 |
+
)
|
| 60 |
+
self.relu = nn.ReLU()
|
| 61 |
+
|
| 62 |
+
# GRU 부분 (TFLite 호환성을 위해 linear_before_reset=False)
|
| 63 |
+
self.gru = nn.GRU(
|
| 64 |
+
input_size=64,
|
| 65 |
+
hidden_size=hidden_dim,
|
| 66 |
+
num_layers=num_layers,
|
| 67 |
+
batch_first=True,
|
| 68 |
+
dropout=0.2 if num_layers > 1 else 0
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Fully Connected 레이어
|
| 72 |
+
self.fc = nn.Linear(hidden_dim, output_dim)
|
| 73 |
+
self.dropout = nn.Dropout(0.3)
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
if x.dim() == 2:
|
| 77 |
+
x = x.unsqueeze(1)
|
| 78 |
+
|
| 79 |
+
x = x.permute(0, 2, 1)
|
| 80 |
+
|
| 81 |
+
x = self.conv1(x)
|
| 82 |
+
x = self.relu(x)
|
| 83 |
+
|
| 84 |
+
x = self.conv2(x)
|
| 85 |
+
x = self.relu(x)
|
| 86 |
+
|
| 87 |
+
x = x.permute(0, 2, 1)
|
| 88 |
+
|
| 89 |
+
gru_out, _ = self.gru(x)
|
| 90 |
+
|
| 91 |
+
last_output = gru_out[:, -1, :]
|
| 92 |
+
|
| 93 |
+
last_output = self.dropout(last_output)
|
| 94 |
+
output = self.fc(last_output)
|
| 95 |
+
|
| 96 |
+
return output
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def convert_pytorch_to_tflite(
|
| 100 |
+
pytorch_model_path='./model/fatigue_net_v2.pt',
|
| 101 |
+
tflite_model_path='./model/fatigue_net_v2.tflite',
|
| 102 |
+
input_shape=(1, 1, 2) # (batch, seq_len, features)
|
| 103 |
+
):
|
| 104 |
+
"""
|
| 105 |
+
PyTorch 모델을 TensorFlow Lite로 변환
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
pytorch_model_path: PyTorch 모델 파일 경로
|
| 109 |
+
tflite_model_path: 저장할 TFLite 모델 파일 경로
|
| 110 |
+
input_shape: 입력 텐서 형태 (batch, seq_len, features)
|
| 111 |
+
"""
|
| 112 |
+
print("=" * 80)
|
| 113 |
+
print("PyTorch 모델을 TensorFlow Lite로 변환")
|
| 114 |
+
print("=" * 80)
|
| 115 |
+
|
| 116 |
+
# 필수 패키지 확인
|
| 117 |
+
if not ONNX_AVAILABLE or not TF_AVAILABLE or not ONNX_TF_AVAILABLE:
|
| 118 |
+
print("\n❌ 필수 패키지가 설치되지 않았거나 호환성 문제가 있습니다.")
|
| 119 |
+
print("\n📋 Python 버전 확인:")
|
| 120 |
+
import sys
|
| 121 |
+
print(f" 현재 Python 버전: {sys.version}")
|
| 122 |
+
print(f" 권장 Python 버전: 3.10 이상")
|
| 123 |
+
|
| 124 |
+
if sys.version_info < (3, 10):
|
| 125 |
+
print("\n⚠️ Python 3.9에서는 일부 패키지 호환성 문제가 있을 수 있습니다.")
|
| 126 |
+
print(" Python 3.10 이상으로 업그레이드하거나, 다음을 시도하세요:")
|
| 127 |
+
print(" - 가상환경에서 Python 3.10+ 사용")
|
| 128 |
+
print(" - 또는 호환되는 패키지 버전 설치")
|
| 129 |
+
|
| 130 |
+
print("\n📦 설치 명령어:")
|
| 131 |
+
print(" 권장 버전 (Python 3.10 이상):")
|
| 132 |
+
print(" pip install onnx==1.15.0 onnx-tf==1.10.0 tensorflow==2.15.0")
|
| 133 |
+
print("\n⚠️ 참고: Python 3.9에서는 일부 패키지 설치 중 에러가 발생할 수 있습니다.")
|
| 134 |
+
print(" Python 3.10 이상 사용을 강력히 권장합니다.")
|
| 135 |
+
|
| 136 |
+
print("\n❌ TFLite 변환은 필수입니다. 모바일 디바이스에서 실행하기 위해 필요합니다.")
|
| 137 |
+
print(" 필수 패키지를 설치하고 다시 시도하세요.")
|
| 138 |
+
return False
|
| 139 |
+
|
| 140 |
+
# 1️⃣ PyTorch 모델 로드
|
| 141 |
+
print("\n1️⃣ PyTorch 모델 로드 중...")
|
| 142 |
+
if not os.path.exists(pytorch_model_path):
|
| 143 |
+
raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {pytorch_model_path}")
|
| 144 |
+
|
| 145 |
+
# TorchScript 파일인지 먼저 확인
|
| 146 |
+
try:
|
| 147 |
+
checkpoint = torch.jit.load(pytorch_model_path, map_location='cpu')
|
| 148 |
+
if isinstance(checkpoint, torch.jit.ScriptModule):
|
| 149 |
+
raise ValueError(
|
| 150 |
+
f"❌ {pytorch_model_path}는 TorchScript 형식입니다.\n"
|
| 151 |
+
"TFLite 변환을 위해서는 PyTorch state_dict 형식 모델이 필요합니다.\n"
|
| 152 |
+
"모델을 다시 학습하거나 올바른 형식의 모델 파일을 사용하세요."
|
| 153 |
+
)
|
| 154 |
+
except:
|
| 155 |
+
pass
|
| 156 |
+
|
| 157 |
+
# 일반 PyTorch 모델 로드
|
| 158 |
+
checkpoint = torch.load(pytorch_model_path, map_location='cpu')
|
| 159 |
+
|
| 160 |
+
# 일반 PyTorch 모델인지 확인
|
| 161 |
+
if not isinstance(checkpoint, dict) or 'model_state_dict' not in checkpoint:
|
| 162 |
+
raise ValueError(
|
| 163 |
+
f"❌ 올바른 PyTorch 모델 형식이 아닙니다.\n"
|
| 164 |
+
f"'{pytorch_model_path}' 파일에 'model_state_dict' 키가 필요합니다.\n"
|
| 165 |
+
"모델을 다시 학습하거나 올바른 형식의 모델 파일을 사용하세요."
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
model_config = checkpoint.get('model_config', {
|
| 169 |
+
'input_dim': 2,
|
| 170 |
+
'hidden_dim': 64,
|
| 171 |
+
'num_layers': 2,
|
| 172 |
+
'output_dim': 1
|
| 173 |
+
})
|
| 174 |
+
|
| 175 |
+
model = FatigueNet(**model_config)
|
| 176 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 177 |
+
model.eval()
|
| 178 |
+
print(f"✅ 모델 로드 완료: {pytorch_model_path}")
|
| 179 |
+
print(f" 모델 설정: {model_config}\n")
|
| 180 |
+
|
| 181 |
+
# 2️⃣ ONNX로 변환
|
| 182 |
+
print("2️⃣ ONNX 형식으로 변환 중...")
|
| 183 |
+
onnx_model_path = './model/fatigue_net_v2.onnx'
|
| 184 |
+
os.makedirs('./model', exist_ok=True)
|
| 185 |
+
|
| 186 |
+
# 더미 입력 생성 (고정 batch_size=1로 TFLite 호환성 향상)
|
| 187 |
+
dummy_input = torch.randn(1, 1, 2) # (batch=1, seq_len=1, features=2)
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
# GRU를 RNN으로 변환하거나 TFLite 호환 옵션 사용
|
| 191 |
+
torch.onnx.export(
|
| 192 |
+
model,
|
| 193 |
+
dummy_input,
|
| 194 |
+
onnx_model_path,
|
| 195 |
+
export_params=True,
|
| 196 |
+
opset_version=11, # onnx-tf 호환성을 위해 11로 낮춤
|
| 197 |
+
do_constant_folding=True,
|
| 198 |
+
input_names=['input'],
|
| 199 |
+
output_names=['output'],
|
| 200 |
+
dynamic_axes={
|
| 201 |
+
'input': {0: 'batch_size', 1: 'sequence_length'},
|
| 202 |
+
'output': {0: 'batch_size'}
|
| 203 |
+
},
|
| 204 |
+
# GRU 관련 호환성 옵션
|
| 205 |
+
custom_opsets=None,
|
| 206 |
+
verbose=False
|
| 207 |
+
)
|
| 208 |
+
print(f"✅ ONNX 변환 완료: {onnx_model_path}\n")
|
| 209 |
+
except Exception as e:
|
| 210 |
+
print(f"⚠️ ONNX 변환 중 경고 (계속 진행): {e}\n")
|
| 211 |
+
|
| 212 |
+
# 3️⃣ ONNX를 TensorFlow로 변환
|
| 213 |
+
print("3️⃣ TensorFlow 형식으로 변환 중...")
|
| 214 |
+
try:
|
| 215 |
+
from onnx_tf.backend import prepare
|
| 216 |
+
|
| 217 |
+
# ONNX 모델 로드 및 GRU 속성 수정
|
| 218 |
+
onnx_model = onnx.load(onnx_model_path)
|
| 219 |
+
|
| 220 |
+
# GRU 노드의 linear_before_reset 속성을 0으로 설정 (TensorFlow 호환)
|
| 221 |
+
for node in onnx_model.graph.node:
|
| 222 |
+
if node.op_type == 'GRU':
|
| 223 |
+
# linear_before_reset 속성을 찾아서 0으로 설정
|
| 224 |
+
for attr in node.attribute:
|
| 225 |
+
if attr.name == 'linear_before_reset':
|
| 226 |
+
attr.i = 0
|
| 227 |
+
break
|
| 228 |
+
else:
|
| 229 |
+
# linear_before_reset 속성이 없으면 추가
|
| 230 |
+
attr = onnx.helper.make_attribute('linear_before_reset', 0)
|
| 231 |
+
node.attribute.append(attr)
|
| 232 |
+
|
| 233 |
+
tf_rep = prepare(onnx_model)
|
| 234 |
+
|
| 235 |
+
# TensorFlow SavedModel로 저장
|
| 236 |
+
tf_model_path = './model/tf_model'
|
| 237 |
+
tf_rep.export_graph(tf_model_path)
|
| 238 |
+
print(f"✅ TensorFlow 변환 완료: {tf_model_path}\n")
|
| 239 |
+
except Exception as e:
|
| 240 |
+
print(f"❌ TensorFlow 변환 실패: {e}")
|
| 241 |
+
print("⚠️ ONNX-TF 변환이 실패했습니다.\n")
|
| 242 |
+
print("❌ TFLite 변환은 필수입니다. 모바일 디바이스에서 실행하기 위해 필요합니다.")
|
| 243 |
+
print(" 에러를 해결하고 다시 시도하세요.")
|
| 244 |
+
return False
|
| 245 |
+
|
| 246 |
+
# 4️⃣ TensorFlow Lite로 변환
|
| 247 |
+
print("4️⃣ TensorFlow Lite 형식으로 변환 중...")
|
| 248 |
+
|
| 249 |
+
# TensorFlow Lite 변환기 생성
|
| 250 |
+
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
|
| 251 |
+
|
| 252 |
+
# GRU 등 복잡한 연산을 위한 설정
|
| 253 |
+
converter.target_spec.supported_ops = [
|
| 254 |
+
tf.lite.OpsSet.TFLITE_BUILTINS,
|
| 255 |
+
tf.lite.OpsSet.SELECT_TF_OPS
|
| 256 |
+
]
|
| 257 |
+
converter._experimental_lower_tensor_list_ops = False
|
| 258 |
+
|
| 259 |
+
# 최적화 옵션 설정 (선택사항)
|
| 260 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 261 |
+
|
| 262 |
+
# 변환 실행
|
| 263 |
+
tflite_model = converter.convert()
|
| 264 |
+
|
| 265 |
+
# TFLite 모델 저장
|
| 266 |
+
with open(tflite_model_path, 'wb') as f:
|
| 267 |
+
f.write(tflite_model)
|
| 268 |
+
|
| 269 |
+
print(f"✅ TensorFlow Lite 변환 완료: {tflite_model_path}")
|
| 270 |
+
|
| 271 |
+
# 모델 크기 확인
|
| 272 |
+
model_size = os.path.getsize(tflite_model_path) / (1024 * 1024) # MB
|
| 273 |
+
print(f" 모델 크기: {model_size:.2f} MB\n")
|
| 274 |
+
|
| 275 |
+
# 5️⃣ 변환된 모델 테스트
|
| 276 |
+
print("5️⃣ 변환된 모델 테스트 중...")
|
| 277 |
+
try:
|
| 278 |
+
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
|
| 279 |
+
interpreter.allocate_tensors()
|
| 280 |
+
|
| 281 |
+
input_details = interpreter.get_input_details()
|
| 282 |
+
output_details = interpreter.get_output_details()
|
| 283 |
+
|
| 284 |
+
print(f" 입력 형태: {input_details[0]['shape']}")
|
| 285 |
+
print(f" 출력 형태: {output_details[0]['shape']}")
|
| 286 |
+
|
| 287 |
+
# 테스트 입력 (고정 크기)
|
| 288 |
+
test_input = np.random.randn(1, 1, 2).astype(np.float32)
|
| 289 |
+
interpreter.set_tensor(input_details[0]['index'], test_input)
|
| 290 |
+
interpreter.invoke()
|
| 291 |
+
test_output = interpreter.get_tensor(output_details[0]['index'])
|
| 292 |
+
|
| 293 |
+
print(f" 테스트 출력: {test_output[0][0]:.4f}")
|
| 294 |
+
print(" ✅ 모델 테스트 성공\n")
|
| 295 |
+
except Exception as e:
|
| 296 |
+
print(f" ⚠️ 모델 테스트 중 경고: {e}")
|
| 297 |
+
print(" (모델은 생성되었지만 테스트는 실패했습니다. 모바일 디바이스에서 Flex ops가 필요할 수 있습니다.)\n")
|
| 298 |
+
|
| 299 |
+
# 중간 파일 정리 (선택사항)
|
| 300 |
+
print("6️⃣ 중간 파일 정리 중...")
|
| 301 |
+
try:
|
| 302 |
+
os.remove(onnx_model_path)
|
| 303 |
+
import shutil
|
| 304 |
+
shutil.rmtree(tf_model_path)
|
| 305 |
+
print("✅ 중간 파일 정리 완료\n")
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"⚠️ 중간 파일 정리 실패 (무시 가능): {e}\n")
|
| 308 |
+
|
| 309 |
+
print("=" * 80)
|
| 310 |
+
print(f"✅ 변환 완료!")
|
| 311 |
+
print(f" TFLite 모델: {tflite_model_path}")
|
| 312 |
+
print("=" * 80)
|
| 313 |
+
return True
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def main():
|
| 317 |
+
"""메인 함수"""
|
| 318 |
+
try:
|
| 319 |
+
success = convert_pytorch_to_tflite(
|
| 320 |
+
pytorch_model_path='./model/fatigue_net_v2.pt',
|
| 321 |
+
tflite_model_path='./model/fatigue_net_v2.tflite'
|
| 322 |
+
)
|
| 323 |
+
if not success:
|
| 324 |
+
return 1
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"\n❌ 변환 실패: {e}")
|
| 327 |
+
import traceback
|
| 328 |
+
traceback.print_exc()
|
| 329 |
+
return 1
|
| 330 |
+
|
| 331 |
+
return 0
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
if __name__ == "__main__":
|
| 335 |
+
exit(main())
|
| 336 |
+
|
load_dataset.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face 데이터셋 로드 유틸리티
|
| 3 |
+
MuscleCare-DataSet 데이터셋을 로드하는 함수들을 제공합니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_musclecare_dataset(
|
| 11 |
+
split: Optional[str] = None,
|
| 12 |
+
cache_dir: Optional[str] = None
|
| 13 |
+
):
|
| 14 |
+
"""
|
| 15 |
+
MuscleCare-DataSet 데이터셋을 로드합니다.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
split: 데이터셋 split 이름 (None이면 모든 split 로드)
|
| 19 |
+
cache_dir: 캐시 디렉토리 경로
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Dataset 또는 DatasetDict 객체
|
| 23 |
+
"""
|
| 24 |
+
dataset = load_dataset(
|
| 25 |
+
"Merry99/MuscleCare-DataSet",
|
| 26 |
+
split=split,
|
| 27 |
+
cache_dir=cache_dir
|
| 28 |
+
)
|
| 29 |
+
return dataset
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
print("데이터셋 로딩 중...")
|
| 33 |
+
dataset = load_musclecare_dataset()
|
| 34 |
+
print("✅ 데이터셋 로드 완료")
|
| 35 |
+
if hasattr(dataset, 'keys'):
|
| 36 |
+
print(f"총 {len(dataset.keys())}개의 split이 있습니다.")
|
| 37 |
+
|
| 38 |
+
|
model.md
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## `/model` API
|
| 2 |
+
|
| 3 |
+
모델 다운로드 엔드포인트는 최신 모델과 특정 버전의 모델을 모두 제공하며, 응답 헤더를 통해 실제 버전 정보를 확인할 수 있습니다.
|
| 4 |
+
|
| 5 |
+
### 요청 형식
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
GET /model
|
| 9 |
+
GET /model?version={번호}
|
| 10 |
+
GET /model?version={번호}&filename={파일명}
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
| 파라미터 | 타입 | 설명 |
|
| 14 |
+
| --- | --- | --- |
|
| 15 |
+
| `version` (선택) | int | 생략하거나 빈 값이면 최신 모델. 지정하면 해당 버전 확인 후 다운로드. |
|
| 16 |
+
| `filename` (선택) | string | 내려받을 파일명. 기본값은 환경 변수 `HF_E2E_MODEL_FILE` (기본 `cnn_gru_fatigue.tflite`). |
|
| 17 |
+
|
| 18 |
+
### 응답
|
| 19 |
+
|
| 20 |
+
- 본문: 요청한 모델 바이너리 (예: `.tflite`, `.keras`, 메타데이터 등)
|
| 21 |
+
- 헤더:
|
| 22 |
+
- `X-Model-Version`: 실제 다운로드된 모델 버전
|
| 23 |
+
- `X-Model-Filename`: 반환된 파일명
|
| 24 |
+
- 에러:
|
| 25 |
+
- `404` – 요청한 버전이 현재 `model_version`보다 크거나 manifest에 존재하지 않을 때
|
| 26 |
+
- `500` – Hugging Face Hub 다운로드 실패 등 내부 오류
|
| 27 |
+
|
| 28 |
+
### 동작 규칙
|
| 29 |
+
|
| 30 |
+
1. 서버는 `training_state.json`의 `model_version` 값을 읽어 현재 허용 가능한 최대 버전을 확인합니다.
|
| 31 |
+
2. `version`을 지정하지 않으면 최신 모델(현재 버전)을 다운로드합니다.
|
| 32 |
+
3. `version`을 지정하면 서버가 현재 `model_version` 이하인지 확인한 뒤, 동일한 파일명을 내려줍니다(버전별로 파일명을 구분하지 않습니다).
|
| 33 |
+
4. 요청한 버전이 현재 버전보다 크거나 파일이 존재하지 않으면 `404`를 반환합니다.
|
| 34 |
+
|
| 35 |
+
### 사용 예시
|
| 36 |
+
|
| 37 |
+
#### 최신 모델 다운로드
|
| 38 |
+
```bash
|
| 39 |
+
curl -L -o cnn_gru_fatigue_latest.tflite \
|
| 40 |
+
"https://merry99-musclecare-train-ai.hf.space/model"
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
#### 버전 3 모델 다운로드
|
| 44 |
+
```bash
|
| 45 |
+
curl -L -o cnn_gru_fatigue_v3.tflite \
|
| 46 |
+
"https://merry99-musclecare-train-ai.hf.space/model?version=3"
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
#### 버전 3 메타데이터 다운로드
|
| 50 |
+
```bash
|
| 51 |
+
curl -L -o metadata_v3.json \
|
| 52 |
+
"https://merry99-musclecare-train-ai.hf.space/model?version=3&filename=cnn_gru_fatigue_metadata.json"
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
#### 헤더 확인
|
| 56 |
+
```bash
|
| 57 |
+
curl -I "https://merry99-musclecare-train-ai.hf.space/model?version=3"
|
| 58 |
+
```
|
| 59 |
+
응답 헤더 예시:
|
| 60 |
+
```
|
| 61 |
+
X-Model-Version: 3
|
| 62 |
+
X-Model-Filename: cnn_gru_fatigue.tflite
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### 주의 사항
|
| 66 |
+
|
| 67 |
+
- `training_state.json`의 `model_version` 값이 기준이 되며, 그보다 높은 버전을 요청하면 404가 반환됩니다.
|
| 68 |
+
- 버전별로 다른 파일을 유지하지 않고, 같은 파일명을 내려주되 헤더(`X-Model-Version`)로 실제 버전을 확인합니다.
|
| 69 |
+
- 실패(예: 404) 시 JSON 응답이 내려오므로, 클라이언트는 상태 코드를 먼저 확인한 뒤 **200일 때만** `body`를 파일로 저장하세요.
|
| 70 |
+
|
| 71 |
+
Flutter 예시 (Dio):
|
| 72 |
+
```dart
|
| 73 |
+
final response = await dio.get<List<int>>(
|
| 74 |
+
'https://merry99-musclecare-train-ai.hf.space/model',
|
| 75 |
+
options: Options(responseType: ResponseType.bytes),
|
| 76 |
+
);
|
| 77 |
+
|
| 78 |
+
if (response.statusCode == 200) {
|
| 79 |
+
final version = response.headers.value('X-Model-Version');
|
| 80 |
+
final filename = response.headers.value('X-Model-Filename') ?? 'model.tflite';
|
| 81 |
+
await File('/path/$filename').writeAsBytes(response.data!);
|
| 82 |
+
} else {
|
| 83 |
+
final errorText = utf8.decode(response.data ?? []);
|
| 84 |
+
// 에러 처리
|
| 85 |
+
}
|
| 86 |
+
```
|
| 87 |
+
- Space 환경 변수 `HF_E2E_MODEL_TOKEN`, `HF_E2E_MODEL_REPO_ID`가 올바르게 설정돼 있어야 `/model` 및 `/trigger`가 정상 동작합니다.
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
## 모델 입력 사양 (Flutter 참고)
|
| 91 |
+
|
| 92 |
+
- 입력 형상: `(batch_size, input_dim)`이며 기본 `input_dim = 10 (FEATURE_COLUMNS) + embedding_dim`.
|
| 93 |
+
- `FEATURE_COLUMNS`: `rms_acc`, `rms_gyro`, `mean_freq_acc`, `mean_freq_gyro`, `entropy_acc`, `entropy_gyro`, `jerk_mean`, `jerk_std`, `stability_index`, `fatigue_prev`.
|
| 94 |
+
- `user_emb`: 메타데이터의 `embedding_dim`과 동일한 길이. 부족하면 뒤를 `0.0f`로 패딩.
|
| 95 |
+
- 메타데이터(`cnn_gru_fatigue_metadata.json`)의 `scaler.mean`, `scaler.scale`로 표준화한 뒤 모델에 전달.
|
| 96 |
+
|
| 97 |
+
### Flutter에서 실행 순서
|
| 98 |
+
- **메타데이터 로드**: JSON에서 `feature_columns`, `scaler.mean`, `scaler.scale`, `embedding_dim`, `input_dim`을 읽는다.
|
| 99 |
+
- **특징 추출**: 측정 버튼을 눌러 얻은 윈도우에서 10개 피처 값을 계산한다.
|
| 100 |
+
- **표준화**: `(value - mean) / scale`을 수행하되 `scale`이 0이면 0으로 대체.
|
| 101 |
+
- **입력 벡터 구성**: `[정규화된 10개 피처, user_emb(패딩 포함)]`을 이어 붙여 `Float32List`로 만든다.
|
| 102 |
+
- **TFLite 실행**: 입력을 `[1, input_dim]`으로 reshape 후 `interpreter.run(input, output)`을 호출한다.
|
| 103 |
+
|
| 104 |
+
```dart
|
| 105 |
+
final meta = await loadMetadata(); // JSON 파싱: scaler, embedding_dim 등
|
| 106 |
+
final features = computeFeatureVector(); // 길이 10, float
|
| 107 |
+
final userEmb = ensureEmbeddingLength(rawEmb, meta.embeddingDim); // 패딩
|
| 108 |
+
|
| 109 |
+
final normalized = List<double>.generate(features.length, (i) {
|
| 110 |
+
final scale = meta.scalerScale[i] == 0 ? 1.0 : meta.scalerScale[i];
|
| 111 |
+
return (features[i] - meta.scalerMean[i]) / scale;
|
| 112 |
+
});
|
| 113 |
+
|
| 114 |
+
final inputVector = Float32List.fromList([
|
| 115 |
+
...normalized,
|
| 116 |
+
...userEmb.map((e) => e.toDouble()),
|
| 117 |
+
]);
|
| 118 |
+
|
| 119 |
+
final outputBuffer = Float32List(1);
|
| 120 |
+
interpreter.run(inputVector.reshape([1, inputVector.length]), outputBuffer);
|
| 121 |
+
final fatigueScore = outputBuffer[0];
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### 주의
|
| 125 |
+
- 최초 측정부터 바로 예측 가능하며, 더 이상 5개 윈도우 누적이 필요하지 않습니다.
|
| 126 |
+
- `fatigue_prev`는 직전 측정의 피로도 지표로, 값이 없다면 `0` 또는 직전 예측치로 초기화해 주세요.
|
| 127 |
+
- 피처 추출 로직과 임베딩 차원은 백엔드 학습 파이프라인과 동일해야 합니다.
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python 3.10+ 환경을 가정합니다.
|
| 2 |
+
typing_extensions>=4.8.0,<5.0.0
|
| 3 |
+
numpy>=1.23.0,<1.27.0
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
transformers>=4.30.0
|
| 6 |
+
datasets>=2.14.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
scikit-learn>=1.3.0
|
| 9 |
+
tqdm>=4.65.0
|
| 10 |
+
schedule>=1.2.0
|
| 11 |
+
huggingface-hub>=0.24.0
|
| 12 |
+
python-dotenv>=1.0.0
|
| 13 |
+
fastapi>=0.110.0
|
| 14 |
+
uvicorn[standard]>=0.23.0
|
| 15 |
+
|
| 16 |
+
# TFLite 변환용 패키지 (호환 버전)
|
| 17 |
+
# PyTorch와 TensorFlow가 함께 설치될 때 충돌 방지를 위해 순서 중요
|
| 18 |
+
onnx==1.15.0
|
| 19 |
+
onnx-tf==1.10.0
|
| 20 |
+
tensorflow==2.15.0
|
| 21 |
+
protobuf<4.0.0
|
| 22 |
+
tensorflow-probability>=0.23.0
|
run_local.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 로컬 실행 스크립트 (venv 없이)
|
| 3 |
+
|
| 4 |
+
echo "🚀 MuscleCare Train AI - 로컬 실행"
|
| 5 |
+
echo "=================================="
|
| 6 |
+
|
| 7 |
+
# Python 3.10 확인
|
| 8 |
+
PYTHON_CMD=""
|
| 9 |
+
if command -v python3.10 &> /dev/null; then
|
| 10 |
+
PYTHON_CMD="python3.10"
|
| 11 |
+
elif [ -f /usr/local/bin/python3.10 ]; then
|
| 12 |
+
PYTHON_CMD="/usr/local/bin/python3.10"
|
| 13 |
+
else
|
| 14 |
+
echo "❌ Python 3.10이 필요합니다."
|
| 15 |
+
echo " 설치: brew install python@3.10"
|
| 16 |
+
exit 1
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
echo "✅ Python 버전: $($PYTHON_CMD --version)"
|
| 20 |
+
echo ""
|
| 21 |
+
|
| 22 |
+
# 패키지 설치 확인
|
| 23 |
+
echo "📦 필수 패키지 확인 중..."
|
| 24 |
+
$PYTHON_CMD -c "import torch; import onnx; import tensorflow" 2>/dev/null
|
| 25 |
+
if [ $? -ne 0 ]; then
|
| 26 |
+
echo "⚠️ 일부 패키지가 설치되지 않았습니다."
|
| 27 |
+
echo " 설치: $PYTHON_CMD -m pip install --user -r requirements.txt"
|
| 28 |
+
echo ""
|
| 29 |
+
read -p "지금 설치하시겠습니까? [y/N]: " -n 1 -r
|
| 30 |
+
echo
|
| 31 |
+
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
| 32 |
+
$PYTHON_CMD -m pip install --user -r requirements.txt
|
| 33 |
+
fi
|
| 34 |
+
fi
|
| 35 |
+
|
| 36 |
+
echo ""
|
| 37 |
+
echo "▶️ 자동 학습 스케줄러 실행 중..."
|
| 38 |
+
$PYTHON_CMD start.py
|
| 39 |
+
|
start.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MuscleCare Train AI - 자동 학습 스케줄러 시작 스크립트
|
| 3 |
+
매주 일요일 자정에 모델을 자동으로 학습합니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from train_scheduler import main
|
| 7 |
+
|
| 8 |
+
if __name__ == "__main__":
|
| 9 |
+
main()
|
| 10 |
+
|
train_e2e.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-End 모델 학습 스크립트 (TensorFlow)
|
| 3 |
+
단일 윈도우(센서 특징 + user_emb)를 입력으로 받아 피로도를 예측하는
|
| 4 |
+
MLP 기반 회귀 모델을 학습하고 SavedModel/TFLite 형식으로 저장합니다.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from typing import Dict, Iterable, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import tensorflow as tf
|
| 14 |
+
from sklearn.preprocessing import StandardScaler
|
| 15 |
+
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, Input
|
| 16 |
+
from tensorflow.keras.models import Model
|
| 17 |
+
|
| 18 |
+
from load_dataset import load_musclecare_dataset
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
FEATURE_COLUMNS = [
|
| 22 |
+
'rms_acc',
|
| 23 |
+
'rms_gyro',
|
| 24 |
+
'mean_freq_acc',
|
| 25 |
+
'mean_freq_gyro',
|
| 26 |
+
'entropy_acc',
|
| 27 |
+
'entropy_gyro',
|
| 28 |
+
'jerk_mean',
|
| 29 |
+
'jerk_std',
|
| 30 |
+
'stability_index',
|
| 31 |
+
'fatigue_prev',
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
DEFAULT_EPOCHS = 30
|
| 35 |
+
DEFAULT_EMBED_DIM = 12
|
| 36 |
+
DEFAULT_BATCH_SIZE = 64
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def parse_user_emb(emb: Union[str, Iterable[float], np.ndarray]) -> np.ndarray:
|
| 40 |
+
"""사용자 임베딩을 numpy 배열로 변환"""
|
| 41 |
+
arr: Optional[np.ndarray] = None
|
| 42 |
+
|
| 43 |
+
if isinstance(emb, np.ndarray):
|
| 44 |
+
arr = emb.astype(np.float32)
|
| 45 |
+
elif isinstance(emb, str):
|
| 46 |
+
try:
|
| 47 |
+
arr = np.array(json.loads(emb), dtype=np.float32)
|
| 48 |
+
except (json.JSONDecodeError, TypeError):
|
| 49 |
+
arr = None
|
| 50 |
+
elif isinstance(emb, Iterable):
|
| 51 |
+
arr = np.array(list(emb), dtype=np.float32)
|
| 52 |
+
|
| 53 |
+
if arr is None or arr.ndim == 0:
|
| 54 |
+
arr = np.zeros(DEFAULT_EMBED_DIM, dtype=np.float32)
|
| 55 |
+
|
| 56 |
+
return arr
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def pad_embedding(embedding: np.ndarray, target_dim: int) -> np.ndarray:
|
| 60 |
+
"""임베딩 길이를 target_dim에 맞춰 패딩"""
|
| 61 |
+
padded = np.zeros(target_dim, dtype=np.float32)
|
| 62 |
+
length = min(target_dim, embedding.size)
|
| 63 |
+
padded[:length] = embedding[:length]
|
| 64 |
+
return padded
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def dataset_split_to_dataframe(dataset_split) -> pd.DataFrame:
|
| 68 |
+
"""HuggingFace Dataset split을 pandas DataFrame으로 변환"""
|
| 69 |
+
if hasattr(dataset_split, "to_pandas"):
|
| 70 |
+
return dataset_split.to_pandas()
|
| 71 |
+
return pd.DataFrame(dataset_split)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def build_dataframe_from_source(
|
| 75 |
+
dataset_source,
|
| 76 |
+
exclude_sessions: Optional[Iterable[str]] = None
|
| 77 |
+
) -> pd.DataFrame:
|
| 78 |
+
"""데이터 소스를 단일 DataFrame으로 통합"""
|
| 79 |
+
frames = []
|
| 80 |
+
exclude_sessions = set(exclude_sessions or [])
|
| 81 |
+
|
| 82 |
+
if hasattr(dataset_source, "items"):
|
| 83 |
+
iterator = dataset_source.items()
|
| 84 |
+
else:
|
| 85 |
+
iterator = [("all", dataset_source)]
|
| 86 |
+
|
| 87 |
+
for split_name, split_dataset in iterator:
|
| 88 |
+
df_split = dataset_split_to_dataframe(split_dataset)
|
| 89 |
+
if df_split.empty:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
if exclude_sessions:
|
| 93 |
+
if 'session_id' not in df_split.columns:
|
| 94 |
+
raise KeyError("데이터셋에 'session_id' 컬럼이 없습니다.")
|
| 95 |
+
df_split = df_split[~df_split['session_id'].isin(exclude_sessions)]
|
| 96 |
+
|
| 97 |
+
if not df_split.empty:
|
| 98 |
+
frames.append(df_split)
|
| 99 |
+
print(f" - {split_name}: {len(df_split)}개 샘플 (필터링 후)")
|
| 100 |
+
|
| 101 |
+
if not frames:
|
| 102 |
+
return pd.DataFrame()
|
| 103 |
+
|
| 104 |
+
return pd.concat(frames, ignore_index=True)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def prepare_training_arrays(
|
| 108 |
+
df: pd.DataFrame,
|
| 109 |
+
feature_cols: Iterable[str]
|
| 110 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 111 |
+
"""단일 윈도우 입력을 위한 학습 데이터를 생성"""
|
| 112 |
+
required_columns = set(feature_cols) | {'fatigue', 'user_emb'}
|
| 113 |
+
missing_columns = required_columns - set(df.columns)
|
| 114 |
+
if missing_columns:
|
| 115 |
+
raise KeyError(f"데이터셋에 누락된 컬럼이 있습니다: {sorted(missing_columns)}")
|
| 116 |
+
|
| 117 |
+
feature_values = (
|
| 118 |
+
df[list(feature_cols)]
|
| 119 |
+
.astype(np.float32)
|
| 120 |
+
.replace([np.inf, -np.inf], np.nan)
|
| 121 |
+
.fillna(0.0)
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
scaler = StandardScaler()
|
| 125 |
+
features_scaled = scaler.fit_transform(feature_values).astype(np.float32)
|
| 126 |
+
|
| 127 |
+
user_embeddings = np.stack([
|
| 128 |
+
emb.astype(np.float32) if isinstance(emb, np.ndarray) else np.zeros(DEFAULT_EMBED_DIM, dtype=np.float32)
|
| 129 |
+
for emb in df['user_emb']
|
| 130 |
+
])
|
| 131 |
+
|
| 132 |
+
X = np.concatenate([features_scaled, user_embeddings], axis=1).astype(np.float32)
|
| 133 |
+
y = df['fatigue'].astype(np.float32).to_numpy()
|
| 134 |
+
|
| 135 |
+
return X, y, scaler.mean_.astype(np.float32), scaler.scale_.astype(np.float32)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def build_dense_regression_model(
|
| 139 |
+
input_dim: int,
|
| 140 |
+
learning_rate: float = 0.001
|
| 141 |
+
) -> Model:
|
| 142 |
+
"""단일 윈도우 입력용 MLP 회귀 모델"""
|
| 143 |
+
inputs = Input(shape=(input_dim,), name="features")
|
| 144 |
+
|
| 145 |
+
x = Dense(128, activation='relu')(inputs)
|
| 146 |
+
x = BatchNormalization()(x)
|
| 147 |
+
x = Dropout(0.3)(x)
|
| 148 |
+
|
| 149 |
+
x = Dense(64, activation='relu')(x)
|
| 150 |
+
x = BatchNormalization()(x)
|
| 151 |
+
x = Dropout(0.2)(x)
|
| 152 |
+
|
| 153 |
+
outputs = Dense(1, activation='linear', name='fatigue')(x)
|
| 154 |
+
|
| 155 |
+
model = Model(inputs=inputs, outputs=outputs)
|
| 156 |
+
model.compile(
|
| 157 |
+
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
|
| 158 |
+
loss='mse',
|
| 159 |
+
metrics=['mae']
|
| 160 |
+
)
|
| 161 |
+
return model
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def ensure_embeddings(df: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
|
| 165 |
+
"""user_emb 컬럼을 numpy 배열로 정규화하고 통일된 차원으로 패딩"""
|
| 166 |
+
if 'user_emb' not in df.columns:
|
| 167 |
+
raise KeyError("데이터셋에 'user_emb' 컬럼이 없습니다.")
|
| 168 |
+
|
| 169 |
+
df = df.copy()
|
| 170 |
+
df['user_emb'] = df['user_emb'].apply(parse_user_emb)
|
| 171 |
+
|
| 172 |
+
dims = [
|
| 173 |
+
emb.size for emb in df['user_emb']
|
| 174 |
+
if isinstance(emb, np.ndarray) and emb.size > 0
|
| 175 |
+
]
|
| 176 |
+
target_dim = max(dims) if dims else DEFAULT_EMBED_DIM
|
| 177 |
+
|
| 178 |
+
df['user_emb'] = df['user_emb'].apply(lambda emb: pad_embedding(emb, target_dim))
|
| 179 |
+
|
| 180 |
+
return df, target_dim
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def main(
|
| 184 |
+
data_list: Optional[Iterable[Dict]] = None,
|
| 185 |
+
exclude_sessions: Optional[Iterable[str]] = None,
|
| 186 |
+
epochs: int = DEFAULT_EPOCHS
|
| 187 |
+
) -> Optional[Dict[str, str]]:
|
| 188 |
+
"""
|
| 189 |
+
메인 학습 함수
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
data_list: 사용할 데이터 리스트 (None이면 전체 데이터 사용)
|
| 193 |
+
exclude_sessions: 제외할 session_id 집합 (중복 방지용)
|
| 194 |
+
epochs: 학습 에포크 수
|
| 195 |
+
"""
|
| 196 |
+
print("=" * 80)
|
| 197 |
+
print("MuscleCare Train AI - TensorFlow Single-Window Training")
|
| 198 |
+
print("=" * 80)
|
| 199 |
+
|
| 200 |
+
tf.keras.utils.set_random_seed(42)
|
| 201 |
+
|
| 202 |
+
# 1️⃣ 데이터 로드
|
| 203 |
+
print("1️⃣ 데이터셋 로딩 중...")
|
| 204 |
+
if data_list is None:
|
| 205 |
+
dataset_source = load_musclecare_dataset()
|
| 206 |
+
df = build_dataframe_from_source(dataset_source, exclude_sessions)
|
| 207 |
+
else:
|
| 208 |
+
df = pd.DataFrame(data_list)
|
| 209 |
+
if exclude_sessions:
|
| 210 |
+
df = df[~df['session_id'].isin(set(exclude_sessions))]
|
| 211 |
+
|
| 212 |
+
if df.empty:
|
| 213 |
+
print("⚠️ 학습 가능한 데이터가 없습니다. 학습을 종료합니다.")
|
| 214 |
+
print("=" * 80)
|
| 215 |
+
return None
|
| 216 |
+
|
| 217 |
+
print(f"✅ 데이터 로드 완료: {len(df)}개 행")
|
| 218 |
+
|
| 219 |
+
# 2️⃣ 사용자 임베딩 정규화
|
| 220 |
+
print("2️⃣ 사용자 임베딩 정규화 중...")
|
| 221 |
+
df, emb_dim = ensure_embeddings(df)
|
| 222 |
+
print(f"✅ 임베딩 차원: {emb_dim}")
|
| 223 |
+
|
| 224 |
+
# 3️⃣ 학습 데이터 생성
|
| 225 |
+
print("3️⃣ 학습 데이터 생성 중...")
|
| 226 |
+
X, y, scaler_mean, scaler_scale = prepare_training_arrays(df, FEATURE_COLUMNS)
|
| 227 |
+
if X.size == 0:
|
| 228 |
+
print("⚠️ 학습할 입력 데이터가 없습니다. 학습을 종료합니다.")
|
| 229 |
+
print("=" * 80)
|
| 230 |
+
return None
|
| 231 |
+
num_samples, input_dim = X.shape
|
| 232 |
+
print(f"✅ 학습 데이터 생성 완료: {num_samples}개 샘플, 입력 차원 {input_dim}")
|
| 233 |
+
|
| 234 |
+
# 4️⃣ 모델 생성
|
| 235 |
+
print("4️⃣ 모델 생성 중...")
|
| 236 |
+
model = build_dense_regression_model(input_dim)
|
| 237 |
+
model.summary(print_fn=lambda x: print(" " + x))
|
| 238 |
+
print("✅ 모델 생성 완료")
|
| 239 |
+
|
| 240 |
+
# 5️⃣ 모델 학습
|
| 241 |
+
print("5️⃣ 모델 학습 시작...")
|
| 242 |
+
callbacks = [
|
| 243 |
+
tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
|
| 244 |
+
tf.keras.callbacks.ReduceLROnPlateau(patience=3, factor=0.5),
|
| 245 |
+
]
|
| 246 |
+
validation_split = 0.1 if num_samples >= 20 else 0.0
|
| 247 |
+
history = model.fit(
|
| 248 |
+
X,
|
| 249 |
+
y,
|
| 250 |
+
epochs=epochs,
|
| 251 |
+
batch_size=min(DEFAULT_BATCH_SIZE, num_samples),
|
| 252 |
+
shuffle=True,
|
| 253 |
+
validation_split=validation_split,
|
| 254 |
+
callbacks=callbacks,
|
| 255 |
+
verbose=1,
|
| 256 |
+
)
|
| 257 |
+
print("✅ 모델 학습 완료")
|
| 258 |
+
|
| 259 |
+
# 6️⃣ 모델 및 메타데이터 저장
|
| 260 |
+
print("6️⃣ 모델 저장 중...")
|
| 261 |
+
model_dir = './model'
|
| 262 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
keras_model_path = os.path.join(model_dir, 'cnn_gru_fatigue.keras')
|
| 265 |
+
model.save(keras_model_path)
|
| 266 |
+
|
| 267 |
+
metadata = {
|
| 268 |
+
"feature_columns": list(FEATURE_COLUMNS),
|
| 269 |
+
"embedding_dim": emb_dim,
|
| 270 |
+
"input_dim": input_dim,
|
| 271 |
+
"epochs": epochs,
|
| 272 |
+
"num_samples": int(num_samples),
|
| 273 |
+
"scaler": {
|
| 274 |
+
"mean": scaler_mean.tolist(),
|
| 275 |
+
"scale": scaler_scale.tolist(),
|
| 276 |
+
},
|
| 277 |
+
"history": {
|
| 278 |
+
"loss": history.history.get('loss', []),
|
| 279 |
+
"mae": history.history.get('mae', []),
|
| 280 |
+
"val_loss": history.history.get('val_loss', []),
|
| 281 |
+
"val_mae": history.history.get('val_mae', []),
|
| 282 |
+
},
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
metadata_path = os.path.join(model_dir, 'cnn_gru_fatigue_metadata.json')
|
| 286 |
+
with open(metadata_path, 'w', encoding='utf-8') as f:
|
| 287 |
+
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
| 288 |
+
|
| 289 |
+
print(f"✅ 모델 저장 완료: {keras_model_path}")
|
| 290 |
+
print(f" 메타데이터 저장: {metadata_path}")
|
| 291 |
+
|
| 292 |
+
# 7️⃣ TFLite 변환
|
| 293 |
+
print("7️⃣ TFLite 변환 중...")
|
| 294 |
+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
| 295 |
+
converter.target_spec.supported_ops = [
|
| 296 |
+
tf.lite.OpsSet.TFLITE_BUILTINS,
|
| 297 |
+
tf.lite.OpsSet.SELECT_TF_OPS,
|
| 298 |
+
]
|
| 299 |
+
converter._experimental_lower_tensor_list_ops = False
|
| 300 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 301 |
+
tflite_model = converter.convert()
|
| 302 |
+
|
| 303 |
+
tflite_model_path = os.path.join(model_dir, 'cnn_gru_fatigue.tflite')
|
| 304 |
+
with open(tflite_model_path, 'wb') as f:
|
| 305 |
+
f.write(tflite_model)
|
| 306 |
+
|
| 307 |
+
print(f"✅ TFLite 모델 저장 완료: {tflite_model_path}")
|
| 308 |
+
print("=" * 80)
|
| 309 |
+
|
| 310 |
+
return {
|
| 311 |
+
"keras": os.path.abspath(keras_model_path),
|
| 312 |
+
"tflite": os.path.abspath(tflite_model_path),
|
| 313 |
+
"metadata": os.path.abspath(metadata_path),
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
main()
|
| 319 |
+
|
train_scheduler.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
주 1회 자동 모델 학습 스케줄러
|
| 3 |
+
매주 일요일 자정에 실행되어 모델을 자동으로 업데이트합니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import schedule
|
| 7 |
+
import time
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import shutil
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, Optional
|
| 14 |
+
|
| 15 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 16 |
+
try:
|
| 17 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 18 |
+
except ImportError: # fallback for older versions
|
| 19 |
+
HfHubHTTPError = Exception # type: ignore
|
| 20 |
+
|
| 21 |
+
from train_e2e import main as train_main
|
| 22 |
+
from load_dataset import load_musclecare_dataset
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TrainingScheduler:
|
| 26 |
+
"""모델 학습 스케줄러 클래스"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, state_file: str = './model/training_state.json'):
|
| 29 |
+
"""
|
| 30 |
+
Args:
|
| 31 |
+
state_file: 학습 상태를 저장할 파일 경로
|
| 32 |
+
"""
|
| 33 |
+
self.state_file = state_file
|
| 34 |
+
self.state_dir = os.path.dirname(state_file)
|
| 35 |
+
os.makedirs(self.state_dir, exist_ok=True)
|
| 36 |
+
self._hf_token = os.getenv("HF_E2E_MODEL_TOKEN")
|
| 37 |
+
self._hf_repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
|
| 38 |
+
self._hf_state_filename = os.getenv("HF_E2E_MODEL_STATE_FILE", Path(state_file).name)
|
| 39 |
+
|
| 40 |
+
if not os.path.exists(self.state_file):
|
| 41 |
+
self._download_state_from_hub()
|
| 42 |
+
|
| 43 |
+
def load_training_state(self):
|
| 44 |
+
"""학습 상태 로드"""
|
| 45 |
+
if os.path.exists(self.state_file):
|
| 46 |
+
try:
|
| 47 |
+
with open(self.state_file, 'r', encoding='utf-8') as f:
|
| 48 |
+
state = json.load(f)
|
| 49 |
+
return state
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"⚠️ 학습 상태 로드 실패: {e}")
|
| 52 |
+
return self._get_default_state()
|
| 53 |
+
if self._download_state_from_hub():
|
| 54 |
+
return self.load_training_state()
|
| 55 |
+
return self._get_default_state()
|
| 56 |
+
|
| 57 |
+
def save_training_state(self, state):
|
| 58 |
+
"""학습 상태 저장"""
|
| 59 |
+
try:
|
| 60 |
+
with open(self.state_file, 'w', encoding='utf-8') as f:
|
| 61 |
+
json.dump(state, f, indent=2, ensure_ascii=False)
|
| 62 |
+
self._upload_state_to_hub()
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"⚠️ 학습 상태 저장 실패: {e}")
|
| 65 |
+
|
| 66 |
+
def _get_default_state(self):
|
| 67 |
+
"""기본 학습 상태"""
|
| 68 |
+
return {
|
| 69 |
+
'processed_sessions': [],
|
| 70 |
+
'last_training_date': None,
|
| 71 |
+
'model_version': 0,
|
| 72 |
+
'total_data_count': 0
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
def reset_training_state(self):
|
| 76 |
+
"""학습 상태 초기화"""
|
| 77 |
+
state = self._get_default_state()
|
| 78 |
+
self.save_training_state(state)
|
| 79 |
+
return state
|
| 80 |
+
|
| 81 |
+
def get_new_data(self, processed_sessions):
|
| 82 |
+
"""
|
| 83 |
+
새로운 데이터만 수집 (중복 방지)
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
processed_sessions: 이미 처리된 session_id 집합
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
list: 새로운 데이터 리스트
|
| 90 |
+
"""
|
| 91 |
+
print("📊 새로운 데이터 수집 중...")
|
| 92 |
+
dataset_dict = load_musclecare_dataset()
|
| 93 |
+
|
| 94 |
+
new_data = []
|
| 95 |
+
new_sessions = set()
|
| 96 |
+
|
| 97 |
+
for split_name in dataset_dict.keys():
|
| 98 |
+
for item in dataset_dict[split_name]:
|
| 99 |
+
session_id = item.get('session_id', '')
|
| 100 |
+
|
| 101 |
+
# 중복 체크
|
| 102 |
+
if session_id not in processed_sessions:
|
| 103 |
+
new_data.append(item)
|
| 104 |
+
new_sessions.add(session_id)
|
| 105 |
+
|
| 106 |
+
print(f"✅ 새로운 데이터: {len(new_data)}개 (새로운 세션: {len(new_sessions)}개)")
|
| 107 |
+
return new_data, new_sessions
|
| 108 |
+
|
| 109 |
+
def train_incremental_model(self, new_data, processed_sessions):
|
| 110 |
+
"""
|
| 111 |
+
증분 학습 수행 (전체 데이터로 재학습하되 중복 제외)
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
new_data: 새로운 데이터 리스트
|
| 115 |
+
processed_sessions: 이미 처리된 session_id 집합
|
| 116 |
+
"""
|
| 117 |
+
if not new_data:
|
| 118 |
+
print("⚠️ 새로운 데이터가 없어 학습을 건너뜁니다.")
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
print(f"\n🔄 모델 학습 시작 (새로운 데이터: {len(new_data)}개 포함)...")
|
| 122 |
+
|
| 123 |
+
# 전체 데이터를 가져오되, 중복된 세션은 제외
|
| 124 |
+
# train_e2e.py의 main 함수에 exclude_sessions 파라미터 전달
|
| 125 |
+
from train_e2e import main as train_main
|
| 126 |
+
training_outputs = train_main(data_list=None, exclude_sessions=processed_sessions)
|
| 127 |
+
|
| 128 |
+
if isinstance(training_outputs, dict):
|
| 129 |
+
return (
|
| 130 |
+
training_outputs.get('tflite')
|
| 131 |
+
or training_outputs.get('keras')
|
| 132 |
+
or training_outputs.get('metadata')
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return training_outputs
|
| 136 |
+
|
| 137 |
+
def run_scheduled_training(self) -> Dict[str, Optional[str]]:
|
| 138 |
+
"""스케줄된 학습 실행"""
|
| 139 |
+
print("=" * 80)
|
| 140 |
+
print(f"🕛 자동 학습 시작 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 141 |
+
print("=" * 80)
|
| 142 |
+
|
| 143 |
+
# 학습 상태 로드
|
| 144 |
+
state = self.load_training_state()
|
| 145 |
+
processed_sessions = set(state.get('processed_sessions', []))
|
| 146 |
+
|
| 147 |
+
print(f"📋 현재 상태:")
|
| 148 |
+
print(f" - 처리된 세션 수: {len(processed_sessions)}")
|
| 149 |
+
print(f" - 마지막 학습일: {state.get('last_training_date', '없음')}")
|
| 150 |
+
print(f" - 모델 버전: {state.get('model_version', 0)}")
|
| 151 |
+
|
| 152 |
+
# 새로운 데이터 수집
|
| 153 |
+
new_data, new_sessions = self.get_new_data(processed_sessions)
|
| 154 |
+
|
| 155 |
+
result: Dict[str, Optional[str]] = {
|
| 156 |
+
"status": "skipped",
|
| 157 |
+
"model_path": None,
|
| 158 |
+
"new_data_count": len(new_data),
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
if new_data:
|
| 162 |
+
# 증분 학습 수행 (전체 데이터로 재학습하되 중복 제외)
|
| 163 |
+
model_path = self.train_incremental_model(
|
| 164 |
+
new_data,
|
| 165 |
+
processed_sessions
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if model_path:
|
| 169 |
+
# 학습 상태 업데이트
|
| 170 |
+
processed_sessions.update(new_sessions)
|
| 171 |
+
state['processed_sessions'] = list(processed_sessions)
|
| 172 |
+
state['last_training_date'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 173 |
+
new_version = state.get('model_version', 0) + 1
|
| 174 |
+
state['model_version'] = new_version
|
| 175 |
+
state['total_data_count'] = state.get('total_data_count', 0) + len(new_data)
|
| 176 |
+
|
| 177 |
+
self.save_training_state(state)
|
| 178 |
+
|
| 179 |
+
print("\n✅ 자동 학습 완료!")
|
| 180 |
+
print(f" - 모델 경로: {model_path}")
|
| 181 |
+
print(f" - 새로운 모델 버전: {state['model_version']}")
|
| 182 |
+
print(f" - 총 처리된 데이터: {state['total_data_count']}개")
|
| 183 |
+
result.update({
|
| 184 |
+
"status": "trained",
|
| 185 |
+
"model_path": model_path,
|
| 186 |
+
"new_data_count": len(new_data),
|
| 187 |
+
"model_version": str(state['model_version']),
|
| 188 |
+
})
|
| 189 |
+
|
| 190 |
+
else:
|
| 191 |
+
print("\n⚠️ 새로운 데이터가 없어 학습을 건너뜁니다.")
|
| 192 |
+
|
| 193 |
+
print("=" * 80)
|
| 194 |
+
return result
|
| 195 |
+
|
| 196 |
+
def _get_hf_api(self) -> Optional[HfApi]:
|
| 197 |
+
if not self._hf_repo_id or not self._hf_token:
|
| 198 |
+
return None
|
| 199 |
+
return HfApi(token=self._hf_token)
|
| 200 |
+
|
| 201 |
+
def _download_state_from_hub(self) -> bool:
|
| 202 |
+
api = self._get_hf_api()
|
| 203 |
+
if api is None:
|
| 204 |
+
return False
|
| 205 |
+
try:
|
| 206 |
+
downloaded_path = hf_hub_download(
|
| 207 |
+
repo_id=self._hf_repo_id,
|
| 208 |
+
filename=self._hf_state_filename,
|
| 209 |
+
repo_type="model",
|
| 210 |
+
token=self._hf_token,
|
| 211 |
+
local_dir=self.state_dir,
|
| 212 |
+
local_dir_use_symlinks=False,
|
| 213 |
+
)
|
| 214 |
+
shutil.move(downloaded_path, self.state_file)
|
| 215 |
+
print(f"✅ Hugging Face Hub에서 학습 상태를 다운로드했습니다: {self._hf_state_filename}")
|
| 216 |
+
return True
|
| 217 |
+
except Exception as e:
|
| 218 |
+
status = getattr(getattr(e, "response", None), "status_code", None)
|
| 219 |
+
if status == 404:
|
| 220 |
+
print("ℹ️ Hugging Face Hub에 학습 상태 파일이 없어 새로 생성합니다.")
|
| 221 |
+
else:
|
| 222 |
+
print(f"⚠️ 학습 상태 다운로드 중 오류가 발생했습니다: {e}")
|
| 223 |
+
return False
|
| 224 |
+
|
| 225 |
+
def _upload_state_to_hub(self) -> None:
|
| 226 |
+
api = self._get_hf_api()
|
| 227 |
+
if api is None:
|
| 228 |
+
return
|
| 229 |
+
try:
|
| 230 |
+
api.create_repo(repo_id=self._hf_repo_id, repo_type="model", private=False, exist_ok=True)
|
| 231 |
+
api.upload_file(
|
| 232 |
+
path_or_fileobj=self.state_file,
|
| 233 |
+
path_in_repo=self._hf_state_filename,
|
| 234 |
+
repo_id=self._hf_repo_id,
|
| 235 |
+
repo_type="model",
|
| 236 |
+
commit_message="Update training state",
|
| 237 |
+
)
|
| 238 |
+
print("✅ 학습 상태를 Hugging Face Hub에 업로드했습니다.")
|
| 239 |
+
except Exception as e:
|
| 240 |
+
print(f"⚠️ 학습 상태 업로드 실패: {e}")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def main():
|
| 244 |
+
"""메인 함수"""
|
| 245 |
+
scheduler = TrainingScheduler()
|
| 246 |
+
|
| 247 |
+
# 매주 일요일 자정에 실행
|
| 248 |
+
schedule.every().day.at("00:00").do(scheduler.run_scheduled_training)
|
| 249 |
+
|
| 250 |
+
print("📅 자동 학습 스케줄러 시작")
|
| 251 |
+
print(" - 실행 시간: 매일 00:00")
|
| 252 |
+
print(" - 종료하려면 Ctrl+C를 누르세요\n")
|
| 253 |
+
|
| 254 |
+
# 스케줄러 실행
|
| 255 |
+
try:
|
| 256 |
+
while True:
|
| 257 |
+
schedule.run_pending()
|
| 258 |
+
time.sleep(60) # 1분마다 체크
|
| 259 |
+
except KeyboardInterrupt:
|
| 260 |
+
print("\n\n⏹️ 스케줄러 종료")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
main()
|
| 265 |
+
|