| """ | |
| Training parameters management module. | |
| This module provides functions for managing and accessing training parameters. | |
| """ | |
| import logging | |
| import json | |
| import os | |
| # Configure logger | |
| logger = logging.getLogger(__name__) | |


class TrainingParamsManager:
    """
    Training parameters manager class.
    """

    # Default training parameters
    _default_training_params = {
        "model_name": "Qwen2.5-0.5B-Instruct",
        "learning_rate": 1e-4,
        "number_of_epochs": 3,
        "concurrency_threads": 2,
        "data_synthesis_mode": "low",
| "use_cuda": False, # Default to using CUDA when available | |
| "is_cot": False | |
| } | |
| # Parameters file path | |
| _params_file_path = None | |

    @classmethod
    def _get_params_file_path(cls):
        """
        Get the training parameters file path.
        """
        if cls._params_file_path is None:
            # Set the parameters file path, creating the directory if needed
            progress_dir = os.path.join(os.getcwd(), "data", "progress")
            os.makedirs(progress_dir, exist_ok=True)
            cls._params_file_path = os.path.join(progress_dir, "training_params.json")
        return cls._params_file_path

    @classmethod
    def update_training_params(cls, params, use_previous_params=True):
        """
        Update the latest training parameters and save them to file.

        Args:
            params: Dictionary containing training parameters
            use_previous_params: Whether to use the previous training parameters as the base
        """
        # First try to load existing parameters
        current_params = cls.get_latest_training_params() if use_previous_params else cls._default_training_params.copy()

        # Update parameters, ignoring keys that are not recognized
        for key, value in params.items():
            if key in cls._default_training_params:
                current_params[key] = value
                logger.debug(f"Updated training parameter {key} to {value}")
            else:
                logger.warning(f"Ignoring unknown parameter: {key}")

        # Save to file
        params_file = cls._get_params_file_path()
        try:
            with open(params_file, 'w', encoding='utf-8') as f:
                json.dump(current_params, f, indent=2)
            logger.info(f"Training parameters saved to {params_file}")
        except Exception as e:
            logger.error(f"Failed to save training parameters to file: {str(e)}", exc_info=True)

    @classmethod
    def get_latest_training_params(cls):
        """
        Get the latest training parameters from file.

        Returns:
            dict: Dictionary containing the latest training parameters
        """
        params_file = cls._get_params_file_path()

        # If the file exists, read from it
        if os.path.exists(params_file):
            try:
                with open(params_file, 'r', encoding='utf-8') as f:
                    params = json.load(f)
                # Replace missing or null values with default values
                for key, value in cls._default_training_params.items():
                    if key not in params or params[key] is None:
                        params[key] = value
                logger.debug(f"Loaded training parameters from {params_file}")
                return params
            except Exception as e:
                logger.error(f"Failed to load training parameters from file: {str(e)}", exc_info=True)
                # If reading fails, return default parameters
                return cls._default_training_params.copy()
        else:
            # If the file does not exist, return default parameters
            return cls._default_training_params.copy()
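

# --- Usage sketch (illustrative only) ---
# A minimal example of how this manager might be driven, assuming the module
# is run or imported as-is. The parameter values below are hypothetical; they
# simply exercise the round trip of updating parameters and reading them back.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Override a subset of parameters; unknown keys are logged and ignored.
    TrainingParamsManager.update_training_params({
        "learning_rate": 5e-5,
        "number_of_epochs": 5,
        "unknown_key": "dropped",  # triggers the "unknown parameter" warning
    })

    # Read back the merged parameters (saved values with defaults filled in).
    print(TrainingParamsManager.get_latest_training_params())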