File size: 16,212 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
import time
from werkzeug.utils import secure_filename
from flask import Blueprint, jsonify, Response, request
from charset_normalizer import from_path

from lpm_kernel.api.domains.trainprocess.trainprocess_service import TrainProcessService
from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager
from ...common.responses import APIResponse
from threading import Thread

from lpm_kernel.configs.logging import get_train_process_logger
logger = get_train_process_logger()

trainprocess_bp = Blueprint("trainprocess", __name__, url_prefix="/api/trainprocess")

@trainprocess_bp.route("/start", methods=["POST"])
def start_process():
    """
    Start the training process in a background daemon thread.

    Request parameters (JSON body):
        model_name: Model name (required)
        learning_rate: Learning rate for model training (optional)
        number_of_epochs: Number of training epochs (optional)
        concurrency_threads: Number of threads for concurrent processing (optional)
        data_synthesis_mode: Mode for data synthesis (optional)
        use_cuda: Whether to use CUDA for training (optional, defaults to False)
        is_cot: Whether to use Chain of Thought (optional)

    Rejects the request (code 409) when an existing training process is
    still marked "in_progress".

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": {echo of the submitted training parameters}
        }
    """
    logger.info("Training process starting...")
    try:
        data = request.get_json()
        if not data or "model_name" not in data:
            return jsonify(APIResponse.error(message="Missing required parameters"))

        model_name = data["model_name"]

        # Optional tuning parameters; None means "let the service use its defaults".
        learning_rate = data.get("learning_rate")
        number_of_epochs = data.get("number_of_epochs")
        concurrency_threads = data.get("concurrency_threads")
        data_synthesis_mode = data.get("data_synthesis_mode")
        use_cuda = data.get("use_cuda", False)  # CUDA is opt-in
        is_cot = data.get("is_cot")

        # FIX: use_cuda was missing from this log line while the /retrain
        # endpoint logs it — log the same parameter set in both places.
        logger.info(f"Training parameters: model_name={model_name}, learning_rate={learning_rate}, number_of_epochs={number_of_epochs}, concurrency_threads={concurrency_threads}, data_synthesis_mode={data_synthesis_mode}, use_cuda={use_cuda}, is_cot={is_cot}")

        # Refuse to start while a previous run is still marked in progress.
        last_train_service = TrainProcessService.get_instance()
        if last_train_service is not None and last_train_service.progress.progress.data["status"] == "in_progress":
            return jsonify(APIResponse.error(
                message="There is an existing training process that was interrupted.",
                code=409  # Conflict status code
            ))

        train_service = TrainProcessService(current_model_name=model_name)
        # If preconditions for training are not met, start from a clean slate.
        if not train_service.check_training_condition():
            train_service.reset_progress()

        # Persist the parameters so the training thread (and later
        # /training_params queries) can read them back.
        training_params = {
            "model_name": model_name,
            "learning_rate": learning_rate,
            "number_of_epochs": number_of_epochs,
            "concurrency_threads": concurrency_threads,
            "data_synthesis_mode": data_synthesis_mode,
            "use_cuda": use_cuda,
            "is_cot": is_cot
        }

        params_manager = TrainingParamsManager()
        params_manager.update_training_params(training_params)
        logger.info(f"Saved training parameters: {training_params}")

        # Run the (long) training pipeline off the request thread; daemon so
        # it does not block interpreter shutdown.
        thread = Thread(target=train_service.start_process)
        thread.daemon = True
        thread.start()

        # Echo the accepted parameters back to the caller.
        return jsonify(
            APIResponse.success(
                data={
                    "model_name": model_name,
                    "learning_rate": learning_rate,
                    "number_of_epochs": number_of_epochs,
                    "concurrency_threads": concurrency_threads,
                    "data_synthesis_mode": data_synthesis_mode,
                    "use_cuda": use_cuda,
                    "is_cot": is_cot
                }
            )
        )

    except Exception as e:
        logger.error(f"Training process failed: {str(e)}")
        return jsonify(APIResponse.error(message=f"Training process error: {str(e)}"))

@trainprocess_bp.route("/logs", methods=["GET"])
def stream_logs():
    """Tail the training log file as a server-sent-event (SSE) stream.

    Polls logs/train/train.log once per second, emitting each new non-empty
    line as a `data:` event and an SSE comment heartbeat when nothing new
    arrived. Read errors are reported as `data:` events and polling continues.
    """
    log_file_path = "logs/train/train.log"  # Log file path
    last_position = 0

    def generate_logs():
        nonlocal last_position
        # FIX: the original re-ran charset detection every second, and
        # from_path() reads the whole file — O(file size) per poll. Detect
        # once and cache the result.
        encoding = None
        while True:
            try:
                if encoding is None:
                    match = from_path(log_file_path).best()
                    # FIX: best() may return None on inconclusive detection;
                    # the original would raise AttributeError every tick.
                    encoding = match.encoding if match is not None else "utf-8"
                with open(log_file_path, 'r', encoding=encoding) as log_file:
                    log_file.seek(last_position)
                    new_lines = log_file.readlines()  # Read new lines

                    for line in new_lines:
                        # Skip empty lines
                        if not line.strip():
                            continue
                        yield f"data: {line.strip()}\n\n"

                    last_position = log_file.tell()
                    if not new_lines:
                        # SSE comment keeps the connection alive through proxies.
                        yield ":heartbeat\n\n"
            except Exception as e:
                # If file reading fails, report the error and keep polling.
                yield f"data: Error reading log file: {str(e)}\n\n"

            time.sleep(1)  # Check for new logs every second

    return Response(
        generate_logs(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache, no-transform',
            'X-Accel-Buffering': 'no',
            'Connection': 'keep-alive',
            'Transfer-Encoding': 'chunked'
        }
    )

@trainprocess_bp.route("/progress/<model_name>", methods=["GET"])
def get_progress(model_name):
    """Return the current (non-streaming) progress snapshot for a model."""
    # Sanitize the URL segment before it is used to locate a progress file.
    safe_name = secure_filename(model_name)
    try:
        service = TrainProcessService(current_model_name=safe_name)
        snapshot = service.progress.progress.to_dict()
        return jsonify(APIResponse.success(data=snapshot))
    except Exception as e:
        logger.error(f"Get progress failed: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=str(e)))

@trainprocess_bp.route("/progress/reset", methods=["POST"])
def reset_progress():
    """Reset the progress of the active training process, if one exists.

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": null
        }
    """
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            # Nothing to reset; still report success to keep the call idempotent.
            logger.warning("No active training process found")
        else:
            service.progress.reset_progress()
            logger.info("Progress reset successfully")

        return jsonify(APIResponse.success(message="Progress reset successfully"))
    except Exception as e:
        logger.error(f"Reset progress failed: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to reset progress: {str(e)}"))


@trainprocess_bp.route("/stop", methods=["POST"])
def stop_training():
    """Stop the training process and wait until it reports suspended/failed.

    Polls the process status once per second after requesting the stop.

    Returns:
        Response: JSON success once status is "suspended" or "failed";
        an error response if no process is active, the wait times out,
        or an exception occurs.
    """
    try:
        train_service = TrainProcessService.get_instance()
        if train_service is None:
            return jsonify(APIResponse.error(message="Failed to stop training: No active training process"))

        # Ask the service to stop; the status flips asynchronously.
        train_service.stop_process()

        # FIX: the original polled forever — if the process never reached
        # "suspended"/"failed", this HTTP request hung indefinitely. Bound
        # the wait so a stuck process cannot block the client forever.
        wait_interval = 1   # seconds between status checks
        max_wait = 300      # give up after 5 minutes
        waited = 0

        while waited < max_wait:
            status = train_service.progress.progress.data["status"]
            if status == "suspended" or status == "failed":
                return jsonify(APIResponse.success(
                    message="Training process has been stopped and status is confirmed as suspended",
                    data={"status": "suspended"}
                ))
            time.sleep(wait_interval)
            waited += wait_interval

        return jsonify(APIResponse.error(
            message="Timed out waiting for training process to stop"
        ))

    except Exception as e:
        logger.error(f"Error stopping training process: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error stopping training process: {str(e)}"))

@trainprocess_bp.route("/step_output_content", methods=["GET"])
def get_step_output_content():
    """
    Get the content of the output file produced by one training step.

    Request parameters:
        step_name: Name of the step, e.g. 'extract_dimensional_topics'

    Returns:
        Response: JSON response
        {
            "code": 0,
            "message": "Success",
            "data": {...}  // Content of the output file, or null if not found
        }
    """
    try:
        service = TrainProcessService.get_instance()
        if service is None:
            logger.error("No active training process found.")
            return jsonify(APIResponse.error(message="No active training process found."))

        # The step to look up comes from the query string.
        step_name = request.args.get('step_name')
        if not step_name:
            return jsonify(APIResponse.error(message="Missing required parameter: step_name", code=400))

        content = service.get_step_output_content(step_name)
        # NOTE: logs the full content at info level — may be verbose.
        logger.info(f"Step output content: {content}")
        return jsonify(APIResponse.success(data=content))
    except Exception as e:
        logger.error(f"Failed to get step output content: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to get step output content: {str(e)}"))

@trainprocess_bp.route("/training_params", methods=["GET"])
def get_training_params():
    """
    Get the most recently saved training parameters.

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": {
                "model_name": "Model name",
                "learning_rate": "Learning rate",
                "number_of_epochs": "Number of epochs",
                "concurrency_threads": "Concurrency threads",
                "data_synthesis_mode": "Data synthesis mode"
            }
        }
    """
    try:
        latest = TrainingParamsManager().get_latest_training_params()
        return jsonify(APIResponse.success(data=latest))
    except Exception as e:
        logger.error(f"Error getting training parameters: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error getting training parameters: {str(e)}"))


@trainprocess_bp.route("/retrain", methods=["POST"])
def retrain():
    """
    Reset progress back to the data-processing stage and immediately start
    a new training run in a background daemon thread.

    Request parameters:
        model_name: Model name (required)
        learning_rate: Learning rate for model training (optional)
        number_of_epochs: Number of training epochs (optional)
        concurrency_threads: Number of threads for concurrent processing (optional)
        data_synthesis_mode: Mode for data synthesis (optional)
        use_cuda: Whether to use CUDA for training (optional)
        is_cot: Whether to use Chain of Thought (optional)

    Returns:
        Response: JSON response
        {
            "code": 0 for success, non-zero for failure,
            "message": "Error message",
            "data": {echo of the submitted training parameters}
        }
    """
    try:
        payload = request.get_json() or {}
        model_name = payload.get("model_name")
        if not model_name:
            return jsonify(APIResponse.error(message="missing necessary parameter: model_name", code=400))

        # Collect the full parameter set up front; absent keys become None
        # (use_cuda defaults to False).
        params = {
            "model_name": model_name,
            "learning_rate": payload.get("learning_rate"),
            "number_of_epochs": payload.get("number_of_epochs"),
            "concurrency_threads": payload.get("concurrency_threads"),
            "data_synthesis_mode": payload.get("data_synthesis_mode"),
            "use_cuda": payload.get("use_cuda", False),
            "is_cot": payload.get("is_cot"),
        }

        logger.info(f"Retrain parameters: model_name={params['model_name']}, learning_rate={params['learning_rate']}, number_of_epochs={params['number_of_epochs']}, concurrency_threads={params['concurrency_threads']}, data_synthesis_mode={params['data_synthesis_mode']}, use_cuda={params['use_cuda']}, is_cot={params['is_cot']}")

        service = TrainProcessService(current_model_name=model_name)

        # An interrupted run is merely noted — retrain always resets below.
        if service.progress.progress.data["status"] == "in_progress":
            logger.info("There is an existing training process that was interrupted.")
        service.reset_progress()

        # Persist the parameters without inheriting any from a previous run.
        TrainingParamsManager().update_training_params(params, use_previous_params=False)
        logger.info(f"Saved training parameters: {params}")

        # Launch the pipeline off the request thread.
        worker = Thread(target=service.start_process)
        worker.daemon = True
        worker.start()

        return jsonify(
            APIResponse.success(
                message="Successfully reset progress to data processing stage and started training process",
                data=dict(params),
            )
        )
    except Exception as e:
        logger.error(f"Retrain reset failed: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Failed to reset progress to data processing stage: {str(e)}"))