Spaces:
Running
Running
| """ | |
| μ£Ό 1ν μλ λͺ¨λΈ νμ΅ μ€μΌμ€λ¬ | |
| λ§€μ£Ό μΌμμΌ μμ μ μ€νλμ΄ λͺ¨λΈμ μλμΌλ‘ μ λ°μ΄νΈν©λλ€. | |
| """ | |
| import schedule | |
| import time | |
| import os | |
| import json | |
| import shutil | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, Optional | |
| from huggingface_hub import HfApi, hf_hub_download | |
| try: | |
| from huggingface_hub.utils import HfHubHTTPError | |
| except ImportError: # fallback for older versions | |
| HfHubHTTPError = Exception # type: ignore | |
| from train_e2e import main as train_main | |
| from load_dataset import load_musclecare_dataset | |
class TrainingScheduler:
    """Scheduler that incrementally retrains the model on new dataset sessions.

    The set of already-processed session ids, the current model version and
    data counters are persisted as JSON in ``state_file``.  When the
    ``HF_E2E_MODEL_TOKEN`` and ``HF_E2E_MODEL_REPO_ID`` environment variables
    are both set, the state file is additionally mirrored to a Hugging Face
    Hub model repository so state survives restarts.
    """

    def __init__(self, state_file: str = './model/training_state.json'):
        """
        Args:
            state_file: Path of the JSON file that stores the training state.
        """
        self.state_file = state_file
        # Fall back to '.' so os.makedirs does not crash on a bare filename
        # (os.makedirs('') raises FileNotFoundError).
        self.state_dir = os.path.dirname(state_file) or '.'
        os.makedirs(self.state_dir, exist_ok=True)
        # Hub sync is optional: active only when both token and repo id are set.
        self._hf_token = os.getenv("HF_E2E_MODEL_TOKEN")
        self._hf_repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
        self._hf_state_filename = os.getenv("HF_E2E_MODEL_STATE_FILE", Path(state_file).name)
        if not os.path.exists(self.state_file):
            self._download_state_from_hub()

    def load_training_state(self) -> Dict:
        """Load the persisted training state.

        Returns:
            The parsed JSON state dict.  Falls back to a one-shot Hub download
            when the local file is missing, and to the default state on error.
        """
        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (OSError, ValueError) as e:
                print(f"Warning: failed to load training state: {e}")
                return self._get_default_state()
        # File missing locally: try the Hub once, then re-read.  The extra
        # existence check prevents unbounded recursion if the download reports
        # success but no file actually appears.
        if self._download_state_from_hub() and os.path.exists(self.state_file):
            return self.load_training_state()
        return self._get_default_state()

    def save_training_state(self, state) -> None:
        """Persist the training state locally and mirror it to the Hub (best effort)."""
        try:
            with open(self.state_file, 'w', encoding='utf-8') as f:
                json.dump(state, f, indent=2, ensure_ascii=False)
            self._upload_state_to_hub()
        except Exception as e:
            # Best-effort: a failed save must not abort the training run.
            print(f"Warning: failed to save training state: {e}")

    def _get_default_state(self) -> Dict:
        """Return a fresh default training state."""
        return {
            'processed_sessions': [],
            'last_training_date': None,
            'model_version': 0,
            'total_data_count': 0,
        }

    def reset_training_state(self) -> Dict:
        """Reset the training state to defaults, persist it, and return it."""
        state = self._get_default_state()
        self.save_training_state(state)
        return state

    def get_new_data(self, processed_sessions):
        """Collect only data items whose session has not yet been processed.

        Args:
            processed_sessions: Set of session_ids already used for training.

        Returns:
            tuple: (list of new data items, set of new session_ids).
        """
        print("Collecting new data...")
        dataset_dict = load_musclecare_dataset()
        new_data = []
        new_sessions = set()
        for split_name in dataset_dict.keys():
            for item in dataset_dict[split_name]:
                session_id = item.get('session_id', '')
                # Deduplicate: skip sessions that were already trained on.
                if session_id not in processed_sessions:
                    new_data.append(item)
                    new_sessions.add(session_id)
        print(f"New data: {len(new_data)} items (new sessions: {len(new_sessions)})")
        return new_data, new_sessions

    def train_incremental_model(self, new_data, processed_sessions):
        """Run incremental training (full retrain, excluding processed sessions).

        Args:
            new_data: List of new data items (used only to skip empty runs).
            processed_sessions: Set of session_ids to exclude from training.

        Returns:
            A model artifact path (tflite, keras or metadata) or None.
        """
        if not new_data:
            print("Warning: no new data; skipping training.")
            return None
        print(f"\nStarting model training (including {len(new_data)} new items)...")
        # train_main loads the full dataset itself and excludes the sessions
        # that were already processed (train_main is imported at module level).
        training_outputs = train_main(data_list=None, exclude_sessions=processed_sessions)
        if isinstance(training_outputs, dict):
            return (
                training_outputs.get('tflite')
                or training_outputs.get('keras')
                or training_outputs.get('metadata')
            )
        return training_outputs

    def run_scheduled_training(self) -> Dict[str, object]:
        """Execute one scheduled training run.

        Returns:
            Summary dict with at least ``status`` ("skipped"/"trained"),
            ``model_path`` and ``new_data_count`` keys.
        """
        print("=" * 80)
        print(f"Automatic training started - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 80)
        state = self.load_training_state()
        processed_sessions = set(state.get('processed_sessions', []))
        print("Current state:")
        print(f" - processed sessions: {len(processed_sessions)}")
        print(f" - last training date: {state.get('last_training_date', 'none')}")
        print(f" - model version: {state.get('model_version', 0)}")
        new_data, new_sessions = self.get_new_data(processed_sessions)
        # new_data_count is an int, so the value type is object, not Optional[str].
        result: Dict[str, object] = {
            "status": "skipped",
            "model_path": None,
            "new_data_count": len(new_data),
        }
        if new_data:
            model_path = self.train_incremental_model(new_data, processed_sessions)
            if model_path:
                # Only mark sessions as processed after a successful run, so a
                # failed run retries them next time.
                processed_sessions.update(new_sessions)
                state['processed_sessions'] = list(processed_sessions)
                state['last_training_date'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                state['model_version'] = state.get('model_version', 0) + 1
                state['total_data_count'] = state.get('total_data_count', 0) + len(new_data)
                self.save_training_state(state)
                print("\nAutomatic training finished!")
                print(f" - model path: {model_path}")
                print(f" - new model version: {state['model_version']}")
                print(f" - total processed data: {state['total_data_count']}")
                result.update({
                    "status": "trained",
                    "model_path": model_path,
                    "new_data_count": len(new_data),
                    "model_version": str(state['model_version']),
                })
        else:
            print("\nNo new data; skipping training.")
        print("=" * 80)
        return result

    def _get_hf_api(self) -> Optional["HfApi"]:
        """Return an authenticated HfApi, or None when Hub sync is disabled."""
        if not self._hf_repo_id or not self._hf_token:
            return None
        return HfApi(token=self._hf_token)

    def _download_state_from_hub(self) -> bool:
        """Download the state file from the Hub; return True on success."""
        api = self._get_hf_api()
        if api is None:
            return False
        try:
            downloaded_path = hf_hub_download(
                repo_id=self._hf_repo_id,
                filename=self._hf_state_filename,
                repo_type="model",
                token=self._hf_token,
                local_dir=self.state_dir,
                # NOTE(review): deprecated kwarg in recent huggingface_hub;
                # kept for compatibility with older versions.
                local_dir_use_symlinks=False,
            )
            # Avoid moving a file onto itself when the hub filename already
            # matches the local state filename.
            if os.path.abspath(downloaded_path) != os.path.abspath(self.state_file):
                shutil.move(downloaded_path, self.state_file)
            print(f"Downloaded training state from the Hugging Face Hub: {self._hf_state_filename}")
            return True
        except Exception as e:
            # A 404 simply means no state file has been uploaded yet.
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404:
                print("No training state on the Hugging Face Hub; creating a new one.")
            else:
                print(f"Warning: error while downloading training state: {e}")
            return False

    def _upload_state_to_hub(self) -> None:
        """Upload the local state file to the Hub (no-op when sync is disabled)."""
        api = self._get_hf_api()
        if api is None:
            return
        try:
            # NOTE(review): private=False publishes the training state to a
            # PUBLIC repo — confirm this is intended.
            api.create_repo(repo_id=self._hf_repo_id, repo_type="model", private=False, exist_ok=True)
            api.upload_file(
                path_or_fileobj=self.state_file,
                path_in_repo=self._hf_state_filename,
                repo_id=self._hf_repo_id,
                repo_type="model",
                commit_message="Update training state",
            )
            print("Uploaded training state to the Hugging Face Hub.")
        except Exception as e:
            print(f"Warning: failed to upload training state: {e}")
def main():
    """Entry point: start the scheduler loop (blocks until Ctrl+C)."""
    scheduler = TrainingScheduler()
    # NOTE: despite the module docstring's "weekly, Sunday" wording, the job
    # is scheduled DAILY at midnight — matching the startup message below.
    schedule.every().day.at("00:00").do(scheduler.run_scheduled_training)
    print("Automatic training scheduler started")
    print(" - run time: daily at 00:00")
    print(" - press Ctrl+C to stop\n")
    try:
        while True:
            schedule.run_pending()
            time.sleep(60)  # poll pending jobs once per minute
    except KeyboardInterrupt:
        print("\n\nScheduler stopped")


if __name__ == "__main__":
    main()