secondme-api / lpm_kernel /api /domains /trainprocess /training_params_manager.py
Gemini
feat: add detailed logging
01d5a5d
"""
Training parameters management module.
This module provides functions for managing and accessing training parameters.
"""
import logging
import json
import os
# Configure logger
logger = logging.getLogger(__name__)
class TrainingParamsManager:
"""
Training parameters manager class.
"""
# Default training parameters
_default_training_params = {
"model_name": "Qwen2.5-0.5B-Instruct",
"learning_rate": 1e-4,
"number_of_epochs": 3,
"concurrency_threads": 2,
"data_synthesis_mode": "low",
"use_cuda": False, # Default to using CUDA when available
"is_cot": False
}
# Parameters file path
_params_file_path = None
@classmethod
def _get_params_file_path(cls):
"""
Get the training parameters file path
"""
if cls._params_file_path is None:
# Set the parameters file path
progress_dir = os.path.join(os.getcwd(), "data", "progress")
if not os.path.exists(progress_dir):
os.makedirs(progress_dir)
cls._params_file_path = os.path.join(progress_dir, "training_params.json")
return cls._params_file_path
@classmethod
def update_training_params(cls, params, use_previous_params=True):
"""
Update the latest training parameters and save to file
Args:
params: Dictionary containing training parameters
use_previous_params: Whether to use previous training parameters as base
"""
# First try to load existing parameters
current_params = cls.get_latest_training_params() if use_previous_params else cls._default_training_params.copy()
# Update parameters
for key, value in params.items():
if key in cls._default_training_params:
current_params[key] = value
logger.debug(f"Updated training parameter {key} to {value}")
else:
logger.warning(f"Ignoring unknown parameter: {key}")
# Save to file
params_file = cls._get_params_file_path()
try:
with open(params_file, 'w', encoding='utf-8') as f:
json.dump(current_params, f, indent=2)
logger.info(f"Training parameters saved to {params_file}")
except Exception as e:
logger.error(f"Failed to save training parameters to file: {str(e)}", exc_info=True)
@classmethod
def get_latest_training_params(cls):
"""
Get the latest training parameters from file
Returns:
dict: Dictionary containing the latest training parameters
"""
params_file = cls._get_params_file_path()
# If file exists, read from file
if os.path.exists(params_file):
try:
with open(params_file, 'r', encoding='utf-8') as f:
params = json.load(f)
# Replace null values with default values
default_params = cls._default_training_params.copy()
for key, value in default_params.items():
if key not in params or params[key] is None:
params[key] = value
logger.debug(f"Loaded training parameters from {params_file}")
return params
except Exception as e:
logger.error(f"Failed to load training parameters from file: {str(e)}", exc_info=True)
# If reading fails, return default parameters
return cls._default_training_params.copy()
else:
# If file does not exist, return default parameters
return cls._default_training_params.copy()