Spaces:
Sleeping
Sleeping
File size: 16,212 Bytes
01d5a5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 |
import time
from werkzeug.utils import secure_filename
from flask import Blueprint, jsonify, Response, request
from charset_normalizer import from_path
from lpm_kernel.api.domains.trainprocess.trainprocess_service import TrainProcessService
from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager
from ...common.responses import APIResponse
from threading import Thread
from lpm_kernel.configs.logging import get_train_process_logger
# Module-wide logger, obtained from the project's train-process logging configuration.
logger = get_train_process_logger()
# Blueprint grouping every route defined in this module under the /api/trainprocess URL prefix.
trainprocess_bp = Blueprint("trainprocess", __name__, url_prefix="/api/trainprocess")
@trainprocess_bp.route("/start", methods=["POST"])
def start_process():
    """
    Start the training process in a background thread.

    Request parameters (JSON body):
        model_name: Model name (required)
        learning_rate: Learning rate for model training (optional)
        number_of_epochs: Number of training epochs (optional)
        concurrency_threads: Number of threads for concurrent processing (optional)
        data_synthesis_mode: Mode for data synthesis (optional)
        use_cuda: Whether to use CUDA for training (optional, defaults to False)
        is_cot: Whether to use Chain of Thought (optional)

    Includes the following steps:
        1. Health check
        2. Generate L0
        3. Generate document embeddings
        4. Process document chunks
        5. Generate chunk embeddings
        6. Analyze documents
        7. Generate L1
        8. Download model
        9. Prepare data
        10. Train model
        11. Merge weights
        12. Convert model

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": {
                "model_name": "Model name",
                "learning_rate": ...,
                "number_of_epochs": ...,
                "concurrency_threads": ...,
                "data_synthesis_mode": ...,
                "use_cuda": ...,
                "is_cot": ...
            }
        }
        An error with code 409 is returned when the current service instance
        still reports an "in_progress" status.
    """
    logger.info("Training process starting...")  # Log the startup
    try:
        data = request.get_json()
        if not data or "model_name" not in data:
            return jsonify(APIResponse.error(message="Missing required parameters"))

        model_name = data["model_name"]
        # Optional parameters; None means "not supplied by the caller".
        learning_rate = data.get("learning_rate", None)
        number_of_epochs = data.get("number_of_epochs", None)
        concurrency_threads = data.get("concurrency_threads", None)
        data_synthesis_mode = data.get("data_synthesis_mode", None)
        use_cuda = data.get("use_cuda", False)  # Default to False if not provided
        is_cot = data.get("is_cot", None)

        # Log every received parameter (use_cuda included, matching /retrain).
        logger.info(
            f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, "
            f"number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, "
            f"data_synthesis_mode={data_synthesis_mode}, use_cuda={use_cuda}, is_cot={is_cot}"
        )

        # Refuse to start while the current service instance still reports
        # an in-progress run.
        last_train_service = TrainProcessService.get_instance()
        if (
            last_train_service is not None
            and last_train_service.progress.progress.data["status"] == "in_progress"
        ):
            return jsonify(APIResponse.error(
                message="There is an existing training process that was interrupted.",
                code=409  # Conflict status code
            ))

        train_service = TrainProcessService(current_model_name=model_name)
        if not train_service.check_training_condition():
            train_service.reset_progress()

        # Save training parameters
        training_params = {
            "model_name": model_name,
            "learning_rate": learning_rate,
            "number_of_epochs": number_of_epochs,
            "concurrency_threads": concurrency_threads,
            "data_synthesis_mode": data_synthesis_mode,
            "use_cuda": use_cuda,
            "is_cot": is_cot,
        }
        # Snapshot taken before persisting so the response echoes exactly what
        # the caller sent, even if the params manager mutates its argument.
        response_data = dict(training_params)

        params_manager = TrainingParamsManager()
        params_manager.update_training_params(training_params)
        logger.info(f"Saved training parameters: {training_params}")

        # Run the pipeline off the request thread; daemon=True so the worker
        # does not keep the interpreter alive on shutdown.
        thread = Thread(target=train_service.start_process)
        thread.daemon = True
        thread.start()

        # Return success response with all parameters
        return jsonify(APIResponse.success(data=response_data))
    except Exception as e:
        # exc_info=True records the traceback, consistent with the other handlers.
        logger.error(f"Training process failed: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Training process error: {str(e)}"))
@trainprocess_bp.route("/logs", methods=["GET"])
def stream_logs():
    """
    Stream the training log file to the client as Server-Sent Events.

    Once per second the log file is re-opened, read from the last position
    reached, and every non-blank new line is emitted as an SSE ``data:`` event.
    When nothing new was read, an SSE comment (``:heartbeat``) is sent instead.
    Read errors are reported to the client as a ``data:`` event and the loop
    keeps running.

    Returns:
        Response: a ``text/event-stream`` response backed by an endless generator.
    """
    log_file_path = "logs/train/train.log"  # Log file path
    last_position = 0

    def generate_logs():
        """Yield SSE frames for new log lines, polling the file every second."""
        nonlocal last_position
        while True:
            try:
                # best() returns None when no plausible encoding is detected;
                # fall back to UTF-8 instead of failing on `.encoding`.
                best_match = from_path(log_file_path).best()
                encoding = best_match.encoding if best_match is not None else "utf-8"
                with open(log_file_path, 'r', encoding=encoding) as log_file:
                    log_file.seek(last_position)
                    new_lines = log_file.readlines()  # Read new lines
                    for line in new_lines:
                        text = line.strip()
                        # Skip empty lines
                        if not text:
                            continue
                        yield f"data: {text}\n\n"
                    # Remember where this read ended so the next poll resumes there.
                    last_position = log_file.tell()
                if not new_lines:
                    yield ":heartbeat\n\n"
            except Exception as e:
                # If file reading fails, report the error to the client and continue
                yield f"data: Error reading log file: {str(e)}\n\n"
            time.sleep(1)  # Check for new logs every second

    # NOTE(review): 'Connection' and 'Transfer-Encoding' are hop-by-hop headers,
    # which PEP 3333 says WSGI applications must not set. They are kept as-is to
    # preserve current behaviour -- confirm the deployed server tolerates them.
    return Response(
        generate_logs(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache, no-transform',
            'X-Accel-Buffering': 'no',
            'Connection': 'keep-alive',
            'Transfer-Encoding': 'chunked'
        }
    )
@trainprocess_bp.route("/progress/<model_name>", methods=["GET"])
def get_progress(model_name):
    """
    Return a snapshot of the training progress for ``model_name``.

    The model name from the URL is passed through ``secure_filename`` before
    it is handed to ``TrainProcessService``.

    Returns:
        Response: JSON success response whose data is ``progress.to_dict()``,
        or a JSON error response carrying the exception text.
    """
    safe_name = secure_filename(model_name)
    try:
        service = TrainProcessService(current_model_name=safe_name)
        snapshot = service.progress.progress.to_dict()
        return jsonify(APIResponse.success(data=snapshot))
    except Exception as exc:
        logger.error(f"Get progress failed: {exc!s}", exc_info=True)
        return jsonify(APIResponse.error(message=str(exc)))
@trainprocess_bp.route("/progress/reset", methods=["POST"])
def reset_progress():
    """
    Reset the progress of the current training service instance.

    When no instance exists only a warning is logged; the response is still
    a success.

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": null
        }
    """
    try:
        current_service = TrainProcessService.get_instance()
        if current_service is None:
            logger.warning("No active training process found")
        else:
            current_service.progress.reset_progress()
            logger.info("Progress reset successfully")
        response_body = APIResponse.success(message="Progress reset successfully")
        return jsonify(response_body)
    except Exception as exc:
        logger.error(f"Reset progress failed: {exc!s}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to reset progress: {exc!s}"))
@trainprocess_bp.route("/stop", methods=["POST"])
def stop_training():
    """
    Ask the current training service to stop, then block until its progress
    status reads "suspended" or "failed".

    Returns:
        Response: JSON error when there is no service instance or an exception
        occurs; otherwise a JSON success whose data is {"status": "suspended"}.
    """
    poll_seconds = 1  # Check interval in seconds
    stopped_states = ("suspended", "failed")
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            return jsonify(APIResponse.error(message="Failed to stop training: No active training process"))

        service.stop_process()

        # Poll the live progress object until it reaches one of the stopped states.
        # NOTE(review): there is no upper bound on this wait -- if the status
        # never reaches one of these values the request blocks indefinitely.
        while service.progress.progress.data["status"] not in stopped_states:
            time.sleep(poll_seconds)

        # NOTE(review): the payload reports "suspended" even when the observed
        # status was "failed" -- confirm whether clients rely on this.
        return jsonify(APIResponse.success(
            message="Training process has been stopped and status is confirmed as suspended",
            data={"status": "suspended"}
        ))
    except Exception as exc:
        logger.error(f"Error stopping training process: {exc!s}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error stopping training process: {exc!s}"))
@trainprocess_bp.route("/step_output_content", methods=["GET"])
def get_step_output_content():
    """
    Return the output content recorded for one training step.

    Query parameters:
        step_name: Name of the step to get content for,
            e.g. 'extract_dimensional_topics'

    Returns:
        Response: JSON response
        {
            "code": 0,
            "message": "Success",
            "data": {...}  // whatever the service returns for that step
        }
        A JSON error is returned when there is no service instance, when
        ``step_name`` is missing (code 400), or when an exception occurs.
    """
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            no_service_msg = "No active training process found."
            logger.error(no_service_msg)
            return jsonify(APIResponse.error(message=no_service_msg))

        requested_step = request.args.get('step_name')
        if not requested_step:
            return jsonify(APIResponse.error(message="Missing required parameter: step_name", code=400))

        content = service.get_step_output_content(requested_step)
        logger.info(f"Step output content: {content}")
        return jsonify(APIResponse.success(data=content))
    except Exception as exc:
        logger.error(f"Failed to get step output content: {exc!s}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to get step output content: {exc!s}"))
@trainprocess_bp.route("/training_params", methods=["GET"])
def get_training_params():
    """
    Return the most recently stored training parameters.

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": whatever TrainingParamsManager.get_latest_training_params()
                    returns (presumably the dict saved by /start or /retrain --
                    verify against the manager)
        }
    """
    try:
        latest_params = TrainingParamsManager().get_latest_training_params()
        return jsonify(APIResponse.success(data=latest_params))
    except Exception as exc:
        logger.error(f"Error getting training parameters: {exc!s}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error getting training parameters: {exc!s}"))
@trainprocess_bp.route("/retrain", methods=["POST"])
def retrain():
    """
    Restart training for a model in a background thread.

    If the service created for ``model_name`` reports an "in_progress" status,
    its progress is reset first. The supplied parameters are then stored with
    ``use_previous_params=False`` and the training process is started.

    Request parameters (JSON body):
        model_name: Model name (required)
        learning_rate: Learning rate for model training (optional)
        number_of_epochs: Number of training epochs (optional)
        concurrency_threads: Number of threads for concurrent processing (optional)
        data_synthesis_mode: Mode for data synthesis (optional)
        use_cuda: Whether to use CUDA for training (optional, defaults to False)
        is_cot: Whether to use Chain of Thought (optional)

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": the seven parameters above, echoed back
        }
    """
    try:
        body = request.get_json() or {}
        model_name = body.get("model_name")
        if not model_name:
            return jsonify(APIResponse.error(message="missing necessary parameter: model_name", code=400))

        # Collect every parameter once; insertion order matches the log line below.
        training_params = {
            "model_name": model_name,
            "learning_rate": body.get("learning_rate"),
            "number_of_epochs": body.get("number_of_epochs"),
            "concurrency_threads": body.get("concurrency_threads"),
            "data_synthesis_mode": body.get("data_synthesis_mode"),
            "use_cuda": body.get("use_cuda", False),
            "is_cot": body.get("is_cot"),
        }
        # Independent copy for the response, taken before the params manager
        # gets a chance to touch the dict it is given.
        echoed_params = dict(training_params)

        summary = ", ".join(f"{key}={value}" for key, value in training_params.items())
        logger.info(f"Retrain parameters: {summary}")

        service = TrainProcessService(current_model_name=model_name)
        # An in-progress status here is treated as a leftover from an
        # interrupted run: reset it and carry on.
        if service.progress.progress.data["status"] == "in_progress":
            logger.info("There is an existing training process that was interrupted.")
            service.reset_progress()

        # Store the new parameters without using the previous ones as a base.
        TrainingParamsManager().update_training_params(training_params, use_previous_params=False)
        logger.info(f"Saved training parameters: {training_params}")

        worker = Thread(target=service.start_process)
        worker.daemon = True
        worker.start()

        return jsonify(
            APIResponse.success(
                message="Successfully reset progress to data processing stage and started training process",
                data=echoed_params,
            )
        )
    except Exception as exc:
        logger.error(f"Retrain reset failed: {exc!s}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to reset progress to data processing stage: {exc!s}"))
|