""" 주 1회 자동 모델 학습 스케줄러 매주 일요일 자정에 실행되어 모델을 자동으로 업데이트합니다. """ import schedule import time import os import json import shutil from datetime import datetime from pathlib import Path from typing import Dict, Optional from huggingface_hub import HfApi, hf_hub_download try: from huggingface_hub.utils import HfHubHTTPError except ImportError: # fallback for older versions HfHubHTTPError = Exception # type: ignore from train_e2e import main as train_main from load_dataset import load_musclecare_dataset class TrainingScheduler: """모델 학습 스케줄러 클래스""" def __init__(self, state_file: str = './model/training_state.json'): """ Args: state_file: 학습 상태를 저장할 파일 경로 """ self.state_file = state_file self.state_dir = os.path.dirname(state_file) os.makedirs(self.state_dir, exist_ok=True) self._hf_token = os.getenv("HF_E2E_MODEL_TOKEN") self._hf_repo_id = os.getenv("HF_E2E_MODEL_REPO_ID") self._hf_state_filename = os.getenv("HF_E2E_MODEL_STATE_FILE", Path(state_file).name) if not os.path.exists(self.state_file): self._download_state_from_hub() def load_training_state(self): """학습 상태 로드""" if os.path.exists(self.state_file): try: with open(self.state_file, 'r', encoding='utf-8') as f: state = json.load(f) return state except Exception as e: print(f"⚠️ 학습 상태 로드 실패: {e}") return self._get_default_state() if self._download_state_from_hub(): return self.load_training_state() return self._get_default_state() def save_training_state(self, state): """학습 상태 저장""" try: with open(self.state_file, 'w', encoding='utf-8') as f: json.dump(state, f, indent=2, ensure_ascii=False) self._upload_state_to_hub() except Exception as e: print(f"⚠️ 학습 상태 저장 실패: {e}") def _get_default_state(self): """기본 학습 상태""" return { 'processed_sessions': [], 'last_training_date': None, 'model_version': 0, 'total_data_count': 0 } def reset_training_state(self): """학습 상태 초기화""" state = self._get_default_state() self.save_training_state(state) return state def get_new_data(self, processed_sessions): """ 새로운 데이터만 수집 (중복 방지) Args: processed_sessions: 이미 처리된 session_id 집합 Returns: list: 새로운 데이터 리스트 """ print("📊 새로운 데이터 수집 중...") dataset_dict = load_musclecare_dataset() new_data = [] new_sessions = set() for split_name in dataset_dict.keys(): for item in dataset_dict[split_name]: session_id = item.get('session_id', '') # 중복 체크 if session_id not in processed_sessions: new_data.append(item) new_sessions.add(session_id) print(f"✅ 새로운 데이터: {len(new_data)}개 (새로운 세션: {len(new_sessions)}개)") return new_data, new_sessions def train_incremental_model(self, new_data, processed_sessions): """ 증분 학습 수행 (전체 데이터로 재학습하되 중복 제외) Args: new_data: 새로운 데이터 리스트 processed_sessions: 이미 처리된 session_id 집합 """ if not new_data: print("⚠️ 새로운 데이터가 없어 학습을 건너뜁니다.") return None print(f"\n🔄 모델 학습 시작 (새로운 데이터: {len(new_data)}개 포함)...") # 전체 데이터를 가져오되, 중복된 세션은 제외 # train_e2e.py의 main 함수에 exclude_sessions 파라미터 전달 from train_e2e import main as train_main training_outputs = train_main(data_list=None, exclude_sessions=processed_sessions) if isinstance(training_outputs, dict): return ( training_outputs.get('tflite') or training_outputs.get('keras') or training_outputs.get('metadata') ) return training_outputs def run_scheduled_training(self) -> Dict[str, Optional[str]]: """스케줄된 학습 실행""" print("=" * 80) print(f"🕛 자동 학습 시작 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print("=" * 80) # 학습 상태 로드 state = self.load_training_state() processed_sessions = set(state.get('processed_sessions', [])) print(f"📋 현재 상태:") print(f" - 처리된 세션 수: {len(processed_sessions)}") print(f" - 마지막 학습일: {state.get('last_training_date', '없음')}") print(f" - 모델 버전: {state.get('model_version', 0)}") # 새로운 데이터 수집 new_data, new_sessions = self.get_new_data(processed_sessions) result: Dict[str, Optional[str]] = { "status": "skipped", "model_path": None, "new_data_count": len(new_data), } if new_data: # 증분 학습 수행 (전체 데이터로 재학습하되 중복 제외) model_path = self.train_incremental_model( new_data, processed_sessions ) if model_path: # 학습 상태 업데이트 processed_sessions.update(new_sessions) state['processed_sessions'] = list(processed_sessions) state['last_training_date'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') new_version = state.get('model_version', 0) + 1 state['model_version'] = new_version state['total_data_count'] = state.get('total_data_count', 0) + len(new_data) self.save_training_state(state) print("\n✅ 자동 학습 완료!") print(f" - 모델 경로: {model_path}") print(f" - 새로운 모델 버전: {state['model_version']}") print(f" - 총 처리된 데이터: {state['total_data_count']}개") result.update({ "status": "trained", "model_path": model_path, "new_data_count": len(new_data), "model_version": str(state['model_version']), }) else: print("\n⚠️ 새로운 데이터가 없어 학습을 건너뜁니다.") print("=" * 80) return result def _get_hf_api(self) -> Optional[HfApi]: if not self._hf_repo_id or not self._hf_token: return None return HfApi(token=self._hf_token) def _download_state_from_hub(self) -> bool: api = self._get_hf_api() if api is None: return False try: downloaded_path = hf_hub_download( repo_id=self._hf_repo_id, filename=self._hf_state_filename, repo_type="model", token=self._hf_token, local_dir=self.state_dir, local_dir_use_symlinks=False, ) shutil.move(downloaded_path, self.state_file) print(f"✅ Hugging Face Hub에서 학습 상태를 다운로드했습니다: {self._hf_state_filename}") return True except Exception as e: status = getattr(getattr(e, "response", None), "status_code", None) if status == 404: print("ℹ️ Hugging Face Hub에 학습 상태 파일이 없어 새로 생성합니다.") else: print(f"⚠️ 학습 상태 다운로드 중 오류가 발생했습니다: {e}") return False def _upload_state_to_hub(self) -> None: api = self._get_hf_api() if api is None: return try: api.create_repo(repo_id=self._hf_repo_id, repo_type="model", private=False, exist_ok=True) api.upload_file( path_or_fileobj=self.state_file, path_in_repo=self._hf_state_filename, repo_id=self._hf_repo_id, repo_type="model", commit_message="Update training state", ) print("✅ 학습 상태를 Hugging Face Hub에 업로드했습니다.") except Exception as e: print(f"⚠️ 학습 상태 업로드 실패: {e}") def main(): """메인 함수""" scheduler = TrainingScheduler() # 매주 일요일 자정에 실행 schedule.every().day.at("00:00").do(scheduler.run_scheduled_training) print("📅 자동 학습 스케줄러 시작") print(" - 실행 시간: 매일 00:00") print(" - 종료하려면 Ctrl+C를 누르세요\n") # 스케줄러 실행 try: while True: schedule.run_pending() time.sleep(60) # 1분마다 체크 except KeyboardInterrupt: print("\n\n⏹️ 스케줄러 종료") if __name__ == "__main__": main()