# Lora-trainer / training_queue.py
# Source: uploaded by Allex21 — "Upload 24 files" (commit 5bb2330, verified)
import threading
import queue
import time
import logging
from typing import Optional
from datetime import datetime
class TrainingQueue:
    """Simple in-memory training queue for LoRA tasks.

    Project ids submitted via ``add_task`` are processed one at a time, in FIFO
    order, by a single daemon worker thread that ``__init__`` starts. The queue
    is held only in process memory (a ``queue.Queue``); nothing is persisted.
    """

    def __init__(self):
        self.task_queue = queue.Queue()  # pending project ids, FIFO
        self.current_task: Optional[int] = None  # project id being trained, or None
        self.worker_thread: Optional[threading.Thread] = None
        # Flag polled by the worker between tasks; clearing it asks the worker to exit.
        self.is_running = False
        self.logger = logging.getLogger(__name__)
        # Start worker thread
        self.start_worker()

    def start_worker(self) -> None:
        """Start the worker thread, or keep an existing live one running.

        ``is_running`` is set unconditionally. After ``stop_worker`` times out,
        the old thread may still be finishing a task. If the flag were only set
        when a new thread is spawned, this restart request would be ignored and
        that thread would exit after its task, leaving the queue unserviced.
        """
        self.is_running = True
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
            self.worker_thread.start()
            self.logger.info("Training queue worker started")

    def stop_worker(self, timeout: Optional[float] = 5) -> None:
        """Ask the worker thread to stop and wait up to ``timeout`` seconds.

        The worker only checks ``is_running`` between tasks (and about once a
        second while idle), so a training run in progress is not interrupted.
        If the thread is still alive after ``timeout``, a warning is logged and
        the thread finishes its current task before exiting.

        Args:
            timeout: Seconds to wait for the thread to exit (``None`` waits
                indefinitely). Defaults to 5, the previous hard-coded value.
        """
        self.is_running = False
        thread = self.worker_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=timeout)
            if thread.is_alive():
                # join() returned because of the timeout, not because the thread exited.
                self.logger.warning(
                    "Training queue worker still busy; it will stop after the current task"
                )
                return
        self.logger.info("Training queue worker stopped")

    def add_task(self, project_id: int) -> None:
        """Add a training task (a project id) to the end of the queue."""
        self.task_queue.put(project_id)
        self.logger.info(f"Added project {project_id} to training queue")

    def get_queue_status(self) -> dict:
        """Return a snapshot of the queue state.

        Returns:
            A dict with ``queue_size`` (approximate pending count, as reported by
            ``queue.Queue.qsize``), ``current_task`` (project id being trained,
            or None) and ``is_running`` (the worker's run flag).
        """
        return {
            'queue_size': self.task_queue.qsize(),
            'current_task': self.current_task,
            'is_running': self.is_running
        }

    def _worker_loop(self) -> None:
        """Process queued project ids until ``is_running`` is cleared.

        Runs on the worker thread. A failure while training one project is
        logged with its traceback and does not stop the loop.
        """
        # The trainer is imported lazily, on the worker thread. This is presumably
        # to avoid an import cycle or heavy start-up cost (TODO confirm).
        try:
            from src.services.lora_trainer import LoRATrainer
            trainer = LoRATrainer()
        except Exception:
            # Without this, the thread would die while is_running stayed True, and
            # get_queue_status() would report a worker that no longer exists.
            self.logger.exception(
                "Failed to initialize LoRA trainer; training queue worker exiting"
            )
            self.is_running = False
            return

        while self.is_running:
            try:
                # Get next task from queue (with timeout to allow checking is_running)
                try:
                    project_id = self.task_queue.get(timeout=1)
                except queue.Empty:
                    continue
                self.current_task = project_id
                self.logger.info(f"Starting training for project {project_id}")
                # Process the training task
                try:
                    trainer.train_project(project_id)
                    self.logger.info(f"Completed training for project {project_id}")
                except Exception as e:
                    # logger.exception keeps the traceback, unlike error(str(e)).
                    self.logger.exception(f"Training failed for project {project_id}: {e}")
                finally:
                    # Always balance the get() so Queue.join() callers are released.
                    self.current_task = None
                    self.task_queue.task_done()
            except Exception as e:
                self.logger.exception(f"Error in worker loop: {e}")
                time.sleep(1)  # Prevent tight loop on persistent errors