Spaces:
Sleeping
Sleeping
| import time | |
| from werkzeug.utils import secure_filename | |
| from flask import Blueprint, jsonify, Response, request | |
| from charset_normalizer import from_path | |
| from lpm_kernel.api.domains.trainprocess.trainprocess_service import TrainProcessService | |
| from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager | |
| from ...common.responses import APIResponse | |
| from threading import Thread | |
| from lpm_kernel.configs.logging import get_train_process_logger | |
# Module-level logger shared by the training-process handlers below.
| logger = get_train_process_logger() | |
# Flask blueprint for the training-process endpoints; its URL prefix is /api/trainprocess.
| trainprocess_bp = Blueprint("trainprocess", __name__, url_prefix="/api/trainprocess") | |
def start_process():
    """
    Start the training process in a background thread.

    Request parameters (JSON body):
        model_name: Model name (required)
        learning_rate: Learning rate for model training (optional)
        number_of_epochs: Number of training epochs (optional)
        concurrency_threads: Number of threads for concurrent processing (optional)
        data_synthesis_mode: Mode for data synthesis (optional)
        use_cuda: Whether to use CUDA for training (optional, defaults to False)
        is_cot: Whether to use Chain of Thought (optional)

    The training pipeline itself runs inside TrainProcessService.start_process;
    it was documented here as consisting of the following steps:
        1. Health check
        2. Generate L0
        3. Generate document embeddings
        4. Process document chunks
        5. Generate chunk embeddings
        6. Analyze documents
        7. Generate L1
        8. Download model
        9. Prepare data
        10. Train model
        11. Merge weights
        12. Convert model

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure (409 when the
                    current service instance is still "in_progress"),
            "message": "Error message",
            "data": {
                "model_name": ..., "learning_rate": ..., "number_of_epochs": ...,
                "concurrency_threads": ..., "data_synthesis_mode": ...,
                "use_cuda": ..., "is_cot": ...
            }
        }
    """
    logger.info("Training process starting...")  # Log the startup
    try:
        data = request.get_json()
        if not data or "model_name" not in data:
            return jsonify(APIResponse.error(message="Missing required parameters"))

        model_name = data["model_name"]

        # Optional parameters; None means "not provided by the caller".
        learning_rate = data.get("learning_rate", None)
        number_of_epochs = data.get("number_of_epochs", None)
        concurrency_threads = data.get("concurrency_threads", None)
        data_synthesis_mode = data.get("data_synthesis_mode", None)
        use_cuda = data.get("use_cuda", False)  # Default to False if not provided
        is_cot = data.get("is_cot", None)

        # Log every received parameter (use_cuda included, matching retrain()).
        logger.info(f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}, use_cuda={use_cuda}, is_cot={is_cot}")

        # Refuse to start while the current service instance reports a run in progress.
        last_train_service = TrainProcessService.get_instance()
        if last_train_service is not None and last_train_service.progress.progress.data["status"] == "in_progress":
            return jsonify(APIResponse.error(
                message="There is an existing training process that was interrupted.",
                code=409  # Conflict status code
            ))

        train_service = TrainProcessService(current_model_name=model_name)
        if not train_service.check_training_condition():
            train_service.reset_progress()

        # Save training parameters
        training_params = {
            "model_name": model_name,
            "learning_rate": learning_rate,
            "number_of_epochs": number_of_epochs,
            "concurrency_threads": concurrency_threads,
            "data_synthesis_mode": data_synthesis_mode,
            "use_cuda": use_cuda,
            "is_cot": is_cot,
        }
        # Snapshot taken before the params manager sees the dict, so the
        # response echoes exactly what the caller sent.
        response_data = dict(training_params)

        params_manager = TrainingParamsManager()
        params_manager.update_training_params(training_params)
        logger.info(f"Saved training parameters: {training_params}")

        # Run the pipeline in a daemon thread so this request returns immediately.
        thread = Thread(target=train_service.start_process)
        thread.daemon = True
        thread.start()

        return jsonify(APIResponse.success(data=response_data))
    except Exception as e:
        # exc_info=True keeps the traceback in the log, like the other handlers.
        logger.error(f"Training process failed: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Training process error: {str(e)}"))
def stream_logs():
    """
    Stream the training log file to the client in real time.

    The log file at logs/train/train.log is polled once per second. Every
    non-blank new line is emitted as a "data:" event on a text/event-stream
    response; when a poll finds no new lines, a ":heartbeat" comment line is
    emitted instead. Read failures are reported to the client as a "data:"
    event and polling continues.

    Returns:
        Response: a streaming response with mimetype text/event-stream.
    """
    log_file_path = "logs/train/train.log"  # Log file path
    last_position = 0

    def generate_logs():
        nonlocal last_position
        while True:
            try:
                # best() returns None when no encoding can be detected;
                # fall back to UTF-8 rather than failing on ".encoding".
                best_match = from_path(log_file_path).best()
                encoding = best_match.encoding if best_match is not None else "utf-8"

                # errors="replace": one undecodable byte must not stall the
                # stream forever with the same decode error on every poll.
                with open(log_file_path, "r", encoding=encoding, errors="replace") as log_file:
                    log_file.seek(last_position)
                    new_lines = log_file.readlines()  # Read new lines

                    for line in new_lines:
                        # Skip empty lines
                        if not line.strip():
                            continue
                        yield f"data: {line.strip()}\n\n"

                    # Remember where we stopped so the next poll resumes from here.
                    last_position = log_file.tell()

                if not new_lines:
                    yield ":heartbeat\n\n"
            except Exception as e:
                # If file reading fails, report the error and keep polling
                yield f"data: Error reading log file: {str(e)}\n\n"

            time.sleep(1)  # Check for new logs every second

    return Response(
        generate_logs(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache, no-transform',
            'X-Accel-Buffering': 'no',
            'Connection': 'keep-alive',
            'Transfer-Encoding': 'chunked'
        }
    )
def get_progress(model_name):
    """
    Return the current (non-streaming) progress for a model.

    Args:
        model_name: Model name; it is passed through secure_filename before
            being handed to TrainProcessService.

    Returns:
        Response: JSON success response whose data is progress.to_dict(),
        or a JSON error response carrying the exception text.
    """
    safe_name = secure_filename(model_name)  # Sanitize model_name
    try:
        service = TrainProcessService(current_model_name=safe_name)
        snapshot = service.progress.progress.to_dict()
        payload = APIResponse.success(data=snapshot)
        return jsonify(payload)
    except Exception as exc:
        logger.error(f"Get progress failed: {str(exc)}", exc_info=True)
        return jsonify(APIResponse.error(message=str(exc)))
def reset_progress():
    """
    Reset the progress of the current TrainProcessService instance.

    When no instance exists, a warning is logged and the success response
    is still returned.

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": null
        }
    """
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            logger.warning("No active training process found")
        else:
            service.progress.reset_progress()
            logger.info("Progress reset successfully")

        payload = APIResponse.success(message="Progress reset successfully")
        return jsonify(payload)
    except Exception as exc:
        logger.error(f"Reset progress failed: {str(exc)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to reset progress: {str(exc)}"))
def stop_training():
    """
    Stop the training process and block until its status is terminal.

    After calling stop_process() the handler polls the progress status once
    per second and returns as soon as it reads "suspended" or "failed". The
    response data always reports {"status": "suspended"}.

    Returns:
        Response: JSON success response once the status is terminal, or a
        JSON error response when no service instance exists or an exception
        is raised.
    """
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            return jsonify(APIResponse.error(message="Failed to stop training: No active training process"))

        service.stop_process()

        poll_seconds = 1  # Check interval in seconds
        terminal_statuses = ("suspended", "failed")
        # NOTE(review): this wait has no upper bound; if the status never
        # becomes terminal, the request never returns.
        while service.progress.progress.data["status"] not in terminal_statuses:
            time.sleep(poll_seconds)

        return jsonify(APIResponse.success(
            message="Training process has been stopped and status is confirmed as suspended",
            data={"status": "suspended"}
        ))
    except Exception as exc:
        logger.error(f"Error stopping training process: {str(exc)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error stopping training process: {str(exc)}"))
def get_step_output_content():
    """
    Return the output content of one training step.

    Query parameters:
        step_name: Name of the step to get content for,
            e.g. 'extract_dimensional_topics' (required)

    Returns:
        Response: JSON response
        {
            "code": 0 on success, non-zero on failure (400 when step_name is missing),
            "message": "...",
            "data": whatever TrainProcessService.get_step_output_content returns
        }
    """
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            logger.error("No active training process found.")
            return jsonify(APIResponse.error(message="No active training process found."))

        # The step name comes from the query string, not the request body.
        requested_step = request.args.get('step_name')
        if not requested_step:
            missing = APIResponse.error(message="Missing required parameter: step_name", code=400)
            return jsonify(missing)

        content = service.get_step_output_content(requested_step)
        logger.info(f"Step output content: {content}")
        return jsonify(APIResponse.success(data=content))
    except Exception as exc:
        logger.error(f"Failed to get step output content: {str(exc)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to get step output content: {str(exc)}"))
def get_training_params():
    """
    Return the latest training parameters.

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": the value returned by
                    TrainingParamsManager.get_latest_training_params()
                    (presumably model_name, learning_rate, number_of_epochs,
                    concurrency_threads, data_synthesis_mode, ... -- confirm
                    against TrainingParamsManager)
        }
    """
    try:
        latest_params = TrainingParamsManager().get_latest_training_params()
        payload = APIResponse.success(data=latest_params)
        return jsonify(payload)
    except Exception as exc:
        logger.error(f"Error getting training parameters: {str(exc)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error getting training parameters: {str(exc)}"))
def retrain():
    """
    Reset any in-progress state for the model and start training again.

    Request parameters (JSON body):
        model_name: Model name (required)
        learning_rate: Learning rate for model training (optional)
        number_of_epochs: Number of training epochs (optional)
        concurrency_threads: Number of threads for concurrent processing (optional)
        data_synthesis_mode: Mode for data synthesis (optional)
        use_cuda: Whether to use CUDA for training (optional, defaults to False)
        is_cot: Whether to use Chain of Thought (optional)

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure (400 when model_name is missing),
            "message": "...",
            "data": the training parameters listed above, echoed back
        }
    """
    try:
        body = request.get_json() or {}

        model_name = body.get("model_name")
        if not model_name:
            return jsonify(APIResponse.error(message="missing necessary parameter: model_name", code=400))

        # Optional parameters default to None, except use_cuda which defaults to False.
        learning_rate, number_of_epochs, concurrency_threads, data_synthesis_mode, is_cot = (
            body.get(key, None)
            for key in (
                "learning_rate",
                "number_of_epochs",
                "concurrency_threads",
                "data_synthesis_mode",
                "is_cot",
            )
        )
        use_cuda = body.get("use_cuda", False)

        logger.info(f"Retrain parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}, use_cuda={use_cuda}, is_cot={is_cot}")

        service = TrainProcessService(current_model_name=model_name)

        # A status still marked "in_progress" is reset before starting over.
        current_status = service.progress.progress.data["status"]
        if current_status == "in_progress":
            logger.info("There is an existing training process that was interrupted.")
            service.reset_progress()

        training_params = dict(
            model_name=model_name,
            learning_rate=learning_rate,
            number_of_epochs=number_of_epochs,
            concurrency_threads=concurrency_threads,
            data_synthesis_mode=data_synthesis_mode,
            use_cuda=use_cuda,
            is_cot=is_cot,
        )
        # Copy made before the params manager receives the dict, so the
        # response is built from the request values alone.
        echoed_params = dict(training_params)

        # Store the new parameters without using the previous ones as a base.
        TrainingParamsManager().update_training_params(training_params, use_previous_params=False)
        logger.info(f"Saved training parameters: {training_params}")

        # Daemon thread: the pipeline runs in the background while this request returns.
        Thread(target=service.start_process, daemon=True).start()

        return jsonify(
            APIResponse.success(
                message="Successfully reset progress to data processing stage and started training process",
                data=echoed_params,
            )
        )
    except Exception as exc:
        logger.error(f"Retrain reset failed: {str(exc)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to reset progress to data processing stage: {str(exc)}"))