Spaces:
No application file
No application file
File size: 2,920 Bytes
5bb2330 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import threading
import queue
import time
import logging
from typing import Optional
from datetime import datetime
class TrainingQueue:
    """Simple in-memory, single-worker training queue for LoRA tasks.

    Project ids are pushed onto a FIFO ``queue.Queue`` with :meth:`add_task`
    and consumed one at a time by a daemon worker thread, which calls
    ``LoRATrainer().train_project(project_id)`` for each id. The constructor
    starts the worker, so creating an instance spawns a background thread.

    Attributes:
        task_queue: FIFO of pending project ids.
        current_task: Project id currently being trained, or ``None``.
        worker_thread: The background worker thread, once started.
        is_running: Run flag for the worker. It is set by
            :meth:`start_worker` and cleared by :meth:`stop_worker`, or by
            the worker itself if the trainer cannot be initialised.
        logger: Module-level logger.
    """

    def __init__(self):
        """Create an empty queue and start its worker thread."""
        self.task_queue = queue.Queue()
        self.current_task: Optional[int] = None
        self.worker_thread: Optional[threading.Thread] = None
        self.is_running = False
        self.logger = logging.getLogger(__name__)
        # Start eagerly so tasks added right after construction get processed.
        self.start_worker()

    def start_worker(self) -> None:
        """Start the worker thread unless one is already alive.

        Calling this after the worker has exited starts a fresh worker. That
        covers exits after :meth:`stop_worker` and after a trainer
        initialisation failure.

        If an earlier :meth:`stop_worker` timed out, the old thread is still
        alive. In that case no new worker is started and ``is_running`` is
        left unchanged.
        """
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.is_running = True
            self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
            self.worker_thread.start()
            self.logger.info("Training queue worker started")

    def stop_worker(self, timeout: Optional[float] = 5) -> None:
        """Ask the worker thread to stop, then wait up to *timeout* seconds.

        The worker checks its run flag only between tasks. When idle, that
        check happens about once per second. A training run already in
        progress is not interrupted, and pending tasks remain in
        ``task_queue``.

        Args:
            timeout: Seconds to wait for the thread to exit. ``None`` waits
                indefinitely.
        """
        self.is_running = False
        thread = self.worker_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=timeout)
            if thread.is_alive():
                # Don't claim the worker stopped when it is still running.
                self.logger.warning(
                    "Training queue worker did not stop within %s seconds; "
                    "it will exit after the current task",
                    timeout,
                )
                return
        self.logger.info("Training queue worker stopped")

    def add_task(self, project_id: int) -> None:
        """Enqueue *project_id* for training. Tasks are processed in FIFO order."""
        self.task_queue.put(project_id)
        self.logger.info("Added project %s to training queue", project_id)

    def get_queue_status(self) -> dict:
        """Return a snapshot of the queue state.

        Returns:
            A dict with three keys:

            - ``queue_size``: approximate number of pending tasks
              (``Queue.qsize``).
            - ``current_task``: the project id being trained, or ``None``.
            - ``is_running``: the worker's run flag.
        """
        return {
            'queue_size': self.task_queue.qsize(),
            'current_task': self.current_task,
            'is_running': self.is_running
        }

    def _worker_loop(self) -> None:
        """Consume tasks from ``task_queue`` until ``is_running`` is cleared.

        If the trainer cannot be imported or constructed, the error is
        logged, ``is_running`` is cleared and the thread exits. Status then
        reflects reality, and :meth:`start_worker` can retry.
        """
        try:
            # Imported lazily inside the worker thread. Presumably this avoids
            # an import cycle or a heavy import at module load (confirm).
            from src.services.lora_trainer import LoRATrainer
            trainer = LoRATrainer()
        except Exception as e:
            self.logger.exception(f"Error in worker loop: {str(e)}")
            self.is_running = False
            return

        while self.is_running:
            # A timeout lets the loop re-check is_running while idle.
            try:
                project_id = self.task_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._process_task(trainer, project_id)
            except Exception as e:
                # Defensive: _process_task already handles training errors,
                # so anything reaching here is unexpected (e.g. a failing
                # log handler).
                self.logger.exception(f"Error in worker loop: {str(e)}")
                time.sleep(1)  # Prevent tight loop on persistent errors
            finally:
                # Exactly one task_done() per successful get(), whatever happened.
                self.task_queue.task_done()

    def _process_task(self, trainer, project_id: int) -> None:
        """Train a single project, recording it as ``current_task`` meanwhile.

        Training exceptions are logged with a traceback and not re-raised,
        so one failing project does not stop the worker.
        """
        self.current_task = project_id
        self.logger.info("Starting training for project %s", project_id)
        try:
            trainer.train_project(project_id)
        except Exception as e:
            self.logger.exception("Training failed for project %s: %s", project_id, e)
        else:
            self.logger.info("Completed training for project %s", project_id)
        finally:
            self.current_task = None
|