File size: 2,920 Bytes
5bb2330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import threading
import queue
import time
import logging
from typing import Optional
from datetime import datetime

class TrainingQueue:
    """Simple in-memory training queue for LoRA tasks.

    Project ids are placed on a FIFO ``queue.Queue`` and consumed one at a
    time by a single daemon worker thread, which passes each id to
    ``LoRATrainer.train_project``. The queue lives only in memory, so any
    ids still waiting are lost when the process exits.
    """

    def __init__(self) -> None:
        """Create an empty queue and immediately start the worker thread."""
        self.task_queue = queue.Queue()
        # Project id currently being trained, or None when the worker is idle.
        self.current_task: Optional[int] = None
        self.worker_thread: Optional[threading.Thread] = None
        # Cooperative stop flag polled by the worker loop.
        self.is_running = False
        self.logger = logging.getLogger(__name__)

        # Start worker thread
        self.start_worker()

    def start_worker(self) -> None:
        """Start the worker thread unless one is already alive."""
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.is_running = True
            self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
            self.worker_thread.start()
            self.logger.info("Training queue worker started")

    def stop_worker(self, timeout: float = 5) -> None:
        """Ask the worker thread to stop and wait up to ``timeout`` seconds.

        The worker only checks the stop flag between tasks, so a training run
        that is already in progress is not interrupted. If the worker is
        still alive when the wait expires, a warning is logged instead of
        the "stopped" message.

        Args:
            timeout: Maximum number of seconds to wait for the worker to exit.
        """
        self.is_running = False
        worker = self.worker_thread
        if worker is None or not worker.is_alive():
            return

        worker.join(timeout=timeout)
        if worker.is_alive():
            # join() timed out: the worker is most likely inside a training run.
            self.logger.warning(
                "Training queue worker did not stop within %s seconds; "
                "it will exit after the current task",
                timeout,
            )
        else:
            self.logger.info("Training queue worker stopped")

    def add_task(self, project_id: int) -> None:
        """Add a training task to the queue.

        Args:
            project_id: Identifier later passed to ``train_project``.
        """
        self.task_queue.put(project_id)
        self.logger.info(f"Added project {project_id} to training queue")

    def get_queue_status(self) -> dict:
        """Return a snapshot of the queue state.

        Returns:
            A dict with ``queue_size`` (approximate number of waiting tasks),
            ``current_task`` (project id in progress, or None) and
            ``is_running`` (the worker's run flag).
        """
        return {
            'queue_size': self.task_queue.qsize(),
            'current_task': self.current_task,
            'is_running': self.is_running
        }

    def _worker_loop(self) -> None:
        """Main worker loop that processes training tasks until stopped."""
        # The trainer is imported and built inside the worker thread. If that
        # fails, record the error and clear the run flag so the reported
        # status does not claim a worker that has already exited.
        try:
            from src.services.lora_trainer import LoRATrainer

            trainer = LoRATrainer()
        except Exception as e:
            self.logger.error(
                f"Failed to initialize LoRA trainer: {str(e)}", exc_info=True
            )
            self.is_running = False
            return

        while self.is_running:
            try:
                # Get next task from queue (with timeout to allow checking is_running)
                try:
                    project_id = self.task_queue.get(timeout=1)
                except queue.Empty:
                    continue

                self._run_task(trainer, project_id)
            except Exception as e:
                self.logger.error(f"Error in worker loop: {str(e)}", exc_info=True)
                time.sleep(1)  # Prevent tight loop on persistent errors

    def _run_task(self, trainer, project_id: int) -> None:
        """Train one project, always clearing state and marking the task done."""
        self.current_task = project_id
        try:
            self.logger.info(f"Starting training for project {project_id}")

            # Process the training task; a failure is logged with its
            # traceback and must not take the worker thread down.
            try:
                trainer.train_project(project_id)
                self.logger.info(f"Completed training for project {project_id}")
            except Exception as e:
                self.logger.error(
                    f"Training failed for project {project_id}: {str(e)}",
                    exc_info=True,
                )
        finally:
            # Runs for every dequeued task so the status and the queue's
            # unfinished-task counter stay consistent.
            self.current_task = None
            self.task_queue.task_done()