# MuscleCare-Train-AI / train_scheduler.py
# Merry99's picture
# prevent hugging face spaces pause
# ece3e89
"""
μ£Ό 1회 μžλ™ λͺ¨λΈ ν•™μŠ΅ μŠ€μΌ€μ€„λŸ¬
λ§€μ£Ό μΌμš”μΌ μžμ •μ— μ‹€ν–‰λ˜μ–΄ λͺ¨λΈμ„ μžλ™μœΌλ‘œ μ—…λ°μ΄νŠΈν•©λ‹ˆλ‹€.
"""
import schedule
import time
import os
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
from huggingface_hub import HfApi, hf_hub_download
try:
from huggingface_hub.utils import HfHubHTTPError
except ImportError: # fallback for older versions
HfHubHTTPError = Exception # type: ignore
from train_e2e import main as train_main
from load_dataset import load_musclecare_dataset
class TrainingScheduler:
    """Model training scheduler.

    Keeps a JSON training-state file (processed session ids, last training
    date, model version, total data count) and retrains the model when the
    dataset contains sessions that have not been processed yet.

    The state file is mirrored to a Hugging Face Hub model repo only when both
    the ``HF_E2E_MODEL_TOKEN`` and ``HF_E2E_MODEL_REPO_ID`` environment
    variables are set; ``HF_E2E_MODEL_STATE_FILE`` overrides the file name
    used inside that repo.
    """

    def __init__(self, state_file: str = './model/training_state.json'):
        """
        Args:
            state_file: Path of the JSON file that stores the training state.
        """
        self.state_file = state_file
        # dirname() is '' for a bare file name and os.makedirs('') raises,
        # so fall back to the current directory.
        self.state_dir = os.path.dirname(state_file) or '.'
        os.makedirs(self.state_dir, exist_ok=True)
        self._hf_token = os.getenv("HF_E2E_MODEL_TOKEN")
        self._hf_repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
        self._hf_state_filename = os.getenv("HF_E2E_MODEL_STATE_FILE", Path(state_file).name)
        # Seed the local state file from the Hub when it does not exist locally.
        if not os.path.exists(self.state_file):
            self._download_state_from_hub()

    def load_training_state(self):
        """Return the training state as a dict.

        Reads the local state file; when it is missing, tries once to fetch it
        from the Hub. Falls back to the default state if the file cannot be
        obtained, cannot be parsed, or does not contain a JSON object.
        """
        if not os.path.exists(self.state_file) and not self._download_state_from_hub():
            return self._get_default_state()
        try:
            with open(self.state_file, 'r', encoding='utf-8') as f:
                state = json.load(f)
            if not isinstance(state, dict):
                # Callers access the state with state.get(...), so anything
                # other than a JSON object is unusable.
                raise ValueError("training state must be a JSON object")
            return state
        except Exception as e:
            print(f"⚠️ ν•™μŠ΅ μƒνƒœ λ‘œλ“œ μ‹€νŒ¨: {e}")
            return self._get_default_state()

    def save_training_state(self, state):
        """Persist ``state`` to the state file, then upload it to the Hub.

        The JSON is written to a temporary sibling file and swapped in with
        ``os.replace`` so a failed or interrupted write never leaves a
        truncated state file behind. Errors are printed, not raised.
        """
        tmp_path = f"{self.state_file}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(state, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, self.state_file)
            self._upload_state_to_hub()
        except Exception as e:
            print(f"⚠️ ν•™μŠ΅ μƒνƒœ μ €μž₯ μ‹€νŒ¨: {e}")
            # Best-effort cleanup of a partially written temp file.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _get_default_state(self):
        """Return a fresh, empty training state."""
        return {
            'processed_sessions': [],
            'last_training_date': None,
            'model_version': 0,
            'total_data_count': 0
        }

    def reset_training_state(self):
        """Overwrite the stored state with the default state and return it."""
        state = self._get_default_state()
        self.save_training_state(state)
        return state

    def get_new_data(self, processed_sessions):
        """Collect dataset rows whose session has not been processed yet.

        Args:
            processed_sessions: Collection (typically a set) of ``session_id``
                values that were already used for training.

        Returns:
            tuple: ``(new_data, new_sessions)`` - the list of unprocessed rows
            across all splits returned by ``load_musclecare_dataset()``, and the
            set of their session ids. Rows without a ``session_id`` key are
            grouped under ``''``.
        """
        print("πŸ“Š μƒˆλ‘œμš΄ 데이터 μˆ˜μ§‘ 쀑...")
        dataset_dict = load_musclecare_dataset()
        new_data = []
        new_sessions = set()
        for split_name in dataset_dict.keys():
            for item in dataset_dict[split_name]:
                session_id = item.get('session_id', '')
                # Skip sessions that were already trained on (deduplication).
                if session_id not in processed_sessions:
                    new_data.append(item)
                    new_sessions.add(session_id)
        print(f"βœ… μƒˆλ‘œμš΄ 데이터: {len(new_data)}개 (μƒˆλ‘œμš΄ μ„Έμ…˜: {len(new_sessions)}개)")
        return new_data, new_sessions

    def train_incremental_model(self, new_data, processed_sessions):
        """Run training via ``train_e2e.main``.

        Args:
            new_data: Unprocessed rows. Only used to decide whether to train
                and for logging; ``train_main`` loads the data itself.
            processed_sessions: Session ids passed to ``train_main`` as
                ``exclude_sessions``.

        Returns:
            When ``train_main`` returns a dict, its ``'tflite'`` path, else
            ``'keras'``, else ``'metadata'`` (first truthy one); otherwise
            whatever ``train_main`` returned. ``None`` when ``new_data`` is empty.
        """
        if not new_data:
            print("⚠️ μƒˆλ‘œμš΄ 데이터가 μ—†μ–΄ ν•™μŠ΅μ„ κ±΄λ„ˆλœλ‹ˆλ‹€.")
            return None
        print(f"\nπŸ”„ λͺ¨λΈ ν•™μŠ΅ μ‹œμž‘ (μƒˆλ‘œμš΄ 데이터: {len(new_data)}개 포함)...")
        # NOTE(review): already-processed sessions are passed as exclude_sessions,
        # so presumably only new sessions are trained on - confirm against
        # train_e2e.main, since the intent stated here was full retraining.
        training_outputs = train_main(data_list=None, exclude_sessions=processed_sessions)
        if isinstance(training_outputs, dict):
            return (
                training_outputs.get('tflite')
                or training_outputs.get('keras')
                or training_outputs.get('metadata')
            )
        return training_outputs

    def run_scheduled_training(self) -> Dict[str, object]:
        """Run one scheduled training pass.

        Loads the state, collects unprocessed sessions, retrains when there are
        any, and persists the updated state only if training produced a model.

        Returns:
            dict with ``status`` (``"trained"`` or ``"skipped"``),
            ``model_path``, ``new_data_count`` (int) and, after a successful
            run, ``model_version`` (str).
        """
        print("=" * 80)
        print(f"πŸ•› μžλ™ ν•™μŠ΅ μ‹œμž‘ - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 80)

        state = self.load_training_state()
        # `or []` also tolerates an explicit null stored in the state file.
        processed_sessions = set(state.get('processed_sessions') or [])
        print(f"πŸ“‹ ν˜„μž¬ μƒνƒœ:")
        print(f" - 처리된 μ„Έμ…˜ 수: {len(processed_sessions)}")
        # The default state stores None here; `or` shows the placeholder instead of "None".
        print(f" - λ§ˆμ§€λ§‰ ν•™μŠ΅μΌ: {state.get('last_training_date') or 'μ—†μŒ'}")
        print(f" - λͺ¨λΈ 버전: {state.get('model_version', 0)}")

        new_data, new_sessions = self.get_new_data(processed_sessions)
        result: Dict[str, object] = {
            "status": "skipped",
            "model_path": None,
            "new_data_count": len(new_data),
        }

        if not new_data:
            print("\n⚠️ μƒˆλ‘œμš΄ 데이터가 μ—†μ–΄ ν•™μŠ΅μ„ κ±΄λ„ˆλœλ‹ˆλ‹€.")
            print("=" * 80)
            return result

        model_path = self.train_incremental_model(new_data, processed_sessions)
        if model_path:
            processed_sessions.update(new_sessions)
            # Sorted (by str, so mixed id types still compare) for a stable,
            # diff-friendly state file.
            state['processed_sessions'] = sorted(processed_sessions, key=str)
            state['last_training_date'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            state['model_version'] = (state.get('model_version') or 0) + 1
            state['total_data_count'] = (state.get('total_data_count') or 0) + len(new_data)
            self.save_training_state(state)

            print("\nβœ… μžλ™ ν•™μŠ΅ μ™„λ£Œ!")
            print(f" - λͺ¨λΈ 경둜: {model_path}")
            print(f" - μƒˆλ‘œμš΄ λͺ¨λΈ 버전: {state['model_version']}")
            print(f" - 총 처리된 데이터: {state['total_data_count']}개")
            result.update({
                "status": "trained",
                "model_path": model_path,
                "new_data_count": len(new_data),
                "model_version": str(state['model_version']),
            })
        else:
            # Previously silent: new data existed but no artifact came back.
            print("\nTraining produced no model artifact; training state was not updated.")

        print("=" * 80)
        return result

    def _get_hf_api(self) -> Optional["HfApi"]:
        """Return an authenticated HfApi, or None if Hub sync is not configured."""
        if not self._hf_repo_id or not self._hf_token:
            return None
        return HfApi(token=self._hf_token)

    def _download_state_from_hub(self) -> bool:
        """Download the state file from the Hub repo into ``self.state_file``.

        Returns:
            True on success; False when Hub sync is not configured or the
            download fails (a 404 is reported as "no state stored yet").
        """
        if self._get_hf_api() is None:
            return False
        try:
            downloaded_path = hf_hub_download(
                repo_id=self._hf_repo_id,
                filename=self._hf_state_filename,
                repo_type="model",
                token=self._hf_token,
                local_dir=self.state_dir,
                local_dir_use_symlinks=False,
            )
            # With local_dir the file typically lands at state_dir/<filename>,
            # which is state_file itself in the default configuration. Moving a
            # file onto itself raises on some platforms (e.g. Windows), so only
            # move when the paths differ.
            if os.path.abspath(downloaded_path) != os.path.abspath(self.state_file):
                shutil.move(downloaded_path, self.state_file)
            print(f"βœ… Hugging Face Hubμ—μ„œ ν•™μŠ΅ μƒνƒœλ₯Ό λ‹€μš΄λ‘œλ“œν–ˆμŠ΅λ‹ˆλ‹€: {self._hf_state_filename}")
            return True
        except Exception as e:
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404:
                print("ℹ️ Hugging Face Hub에 ν•™μŠ΅ μƒνƒœ 파일이 μ—†μ–΄ μƒˆλ‘œ μƒμ„±ν•©λ‹ˆλ‹€.")
            else:
                print(f"⚠️ ν•™μŠ΅ μƒνƒœ λ‹€μš΄λ‘œλ“œ 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {e}")
            return False

    def _upload_state_to_hub(self) -> None:
        """Upload the local state file to the Hub repo (best-effort).

        Creates the repo as a public model repo if it does not exist. Errors
        are printed, not raised. No-op when Hub sync is not configured.
        """
        api = self._get_hf_api()
        if api is None:
            return
        try:
            api.create_repo(repo_id=self._hf_repo_id, repo_type="model", private=False, exist_ok=True)
            api.upload_file(
                path_or_fileobj=self.state_file,
                path_in_repo=self._hf_state_filename,
                repo_id=self._hf_repo_id,
                repo_type="model",
                commit_message="Update training state",
            )
            print("βœ… ν•™μŠ΅ μƒνƒœλ₯Ό Hugging Face Hub에 μ—…λ‘œλ“œν–ˆμŠ΅λ‹ˆλ‹€.")
        except Exception as e:
            print(f"⚠️ ν•™μŠ΅ μƒνƒœ μ—…λ‘œλ“œ μ‹€νŒ¨: {e}")
def main():
    """Start the blocking training-scheduler loop.

    Registers a daily 00:00 job that runs
    ``TrainingScheduler.run_scheduled_training`` and polls the ``schedule``
    queue once a minute until interrupted with Ctrl+C.
    """
    import traceback

    scheduler = TrainingScheduler()

    def _run_training_job():
        # Exceptions raised by a job propagate out of schedule.run_pending();
        # left uncaught they would end the loop below and stop every future
        # run, so log them and keep the scheduler alive.
        try:
            return scheduler.run_scheduled_training()
        except Exception as e:
            print(f"Scheduled training failed: {e!r}")
            traceback.print_exc()
            return None

    # Run every day at 00:00; run_scheduled_training only retrains when new sessions exist.
    schedule.every().day.at("00:00").do(_run_training_job)
    print("πŸ“… μžλ™ ν•™μŠ΅ μŠ€μΌ€μ€„λŸ¬ μ‹œμž‘")
    print(" - μ‹€ν–‰ μ‹œκ°„: 맀일 00:00")
    print(" - μ’…λ£Œν•˜λ €λ©΄ Ctrl+Cλ₯Ό λˆ„λ₯΄μ„Έμš”\n")
    try:
        while True:
            schedule.run_pending()
            time.sleep(60)  # check pending jobs once a minute
    except KeyboardInterrupt:
        print("\n\n⏹️ μŠ€μΌ€μ€„λŸ¬ μ’…λ£Œ")
if __name__ == "__main__":
main()