"""
Evaluation Worker / Consumer for Job Queue

Provides worker functionality for processing evaluation jobs from the queue.
Handles job execution, per-job timeouts, and status reporting. Checkpointing
is not implemented yet (see ``EvaluationWorker._execute_job``).
"""

import asyncio
import os
import socket
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from backend.core.config import settings
from backend.logging.logger import get_logger

from .job_schema import (
    JobProgressUpdate,
    JobStatus,
    EvaluationJob,
)
from .producer import _job_queue, get_job_producer
from .status_tracker import get_status_tracker

logger = get_logger("queue.consumer", component="queue")


@dataclass
class WorkerConfig:
    """Configuration for evaluation worker."""

    worker_id: str
    # Upper bound on jobs this worker runs at the same time.
    max_concurrent_jobs: int = 1
    # Wall-clock limit for a single job, in seconds; 0 disables the limit.
    job_timeout_seconds: int = 3600
    # NOTE(review): not read anywhere in this module.
    heartbeat_interval_seconds: int = 30
    # NOTE(review): not read anywhere in this module; checkpointing is not
    # implemented yet.
    enable_checkpointing: bool = True


class EvaluationWorker:
    """
    Worker that processes evaluation jobs from the queue.

    Responsibilities:
    - Poll the shared in-process queue for QUEUED jobs
    - Run up to ``config.max_concurrent_jobs`` jobs at once, each in its own
      asyncio task
    - Enforce ``config.job_timeout_seconds`` per job
    - Report job lifecycle transitions to the status tracker
    """

    def __init__(self, config: Optional[WorkerConfig] = None):
        self.config = config or WorkerConfig(
            worker_id=f"worker-{socket.gethostname()}-{os.getpid()}",
        )
        self._status_tracker = get_status_tracker()
        self._producer = get_job_producer()
        # job_id -> task running _process_job for that job. Holding the task
        # here also keeps a strong reference to it while it runs.
        self._active_jobs: dict[uuid.UUID, asyncio.Task] = {}
        self._running = False
        # Most recently started job. With max_concurrent_jobs > 1 this is only
        # one of the running jobs.
        self._current_job: Optional[EvaluationJob] = None

    async def start(self) -> None:
        """Run the poll loop until stop() is called or this task is cancelled.

        After stop(), in-flight jobs are allowed to run to completion before
        this coroutine returns. On cancellation, in-flight jobs are cancelled
        (each is reported as failed by _process_job), then the CancelledError
        is propagated.
        """
        self._running = True
        logger.info(
            "Worker started",
            worker_id=self.config.worker_id,
            max_concurrent=self.config.max_concurrent_jobs,
        )

        try:
            while self._running:
                try:
                    # Poll for jobs (starts them in background tasks)
                    await self._poll_and_process()
                    # Brief sleep to prevent CPU spinning
                    await asyncio.sleep(1)
                except asyncio.CancelledError:
                    logger.info("Worker cancelled", worker_id=self.config.worker_id)
                    # Propagate so the owning task ends as cancelled instead
                    # of silently continuing.
                    raise
                except Exception as e:
                    logger.error(
                        "Worker error",
                        worker_id=self.config.worker_id,
                        error=str(e),
                    )
                    await asyncio.sleep(5)

            # Graceful stop: let in-flight jobs finish.
            await self._wait_for_active_jobs()
        finally:
            # No-op after a graceful drain. On cancellation (including one
            # arriving during the error back-off sleep) this still runs.
            await self._cancel_active_jobs()
            self._running = False
            logger.info("Worker stopped", worker_id=self.config.worker_id)

    async def stop(self) -> None:
        """Ask the poll loop to exit after its current iteration."""
        self._running = False

    async def _wait_for_active_jobs(self) -> None:
        """Wait until every in-flight job task has finished."""
        tasks = list(self._active_jobs.values())
        if tasks:
            # asyncio.wait (unlike gather) does not cancel the job tasks if
            # this coroutine is itself cancelled; _cancel_active_jobs handles
            # that case exactly once.
            await asyncio.wait(tasks)

    async def _cancel_active_jobs(self) -> None:
        """Cancel in-flight job tasks and wait for their cleanup to finish."""
        # Snapshot first: each task's cleanup pops itself from _active_jobs.
        pending = [
            (job_id, task)
            for job_id, task in self._active_jobs.items()
            if not task.done()
        ]
        for job_id, task in pending:
            task.cancel()
            logger.info("Cancelled active job", job_id=str(job_id))
        if pending:
            # Let each job's CancelledError handler report the failure.
            await asyncio.gather(
                *(task for _, task in pending), return_exceptions=True
            )

    async def _poll_and_process(self) -> None:
        """Claim at most one QUEUED job and start processing it in the background."""
        # Check if we can accept more jobs
        if len(self._active_jobs) >= self.config.max_concurrent_jobs:
            return

        for job in _job_queue:
            if job.status != JobStatus.QUEUED or job.job_id in self._active_jobs:
                continue
            # Register the task before any await so a later poll cannot claim
            # the same job while it is still QUEUED.
            task = asyncio.create_task(self._process_job(job))
            self._active_jobs[job.job_id] = task
            task.add_done_callback(
                lambda t, job_id=job.job_id: self._on_job_task_done(job_id, t)
            )
            break

    def _on_job_task_done(self, job_id: uuid.UUID, task: asyncio.Task) -> None:
        """Done-callback for job tasks: drop bookkeeping and surface crashes."""
        # Normally _process_job already removed the entry. This covers a task
        # that was cancelled before it ever started running.
        if self._active_jobs.get(job_id) is task:
            del self._active_jobs[job_id]
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # _process_job handles execution errors itself. Getting here means
            # its failure-reporting path raised (e.g. the status tracker call).
            logger.error("Job task crashed", job_id=str(job_id), error=str(exc))

    async def _process_job(self, job: EvaluationJob) -> None:
        """Run a single job to completion and report its outcome.

        Marks the job started and executes it under
        ``config.job_timeout_seconds``. Failures and timeouts are reported via
        ``fail_job``. The job is always removed from ``_active_jobs`` and the
        shared queue afterwards. Cancellation is reported as a failure and
        then re-raised.
        """
        job_id_str = str(job.job_id)
        # 0/None disables the limit: wait_for treats timeout=None as "no limit".
        timeout = self.config.job_timeout_seconds or None

        try:
            # Mark job as started
            self._current_job = job
            await self._status_tracker.start_job(job.job_id, self.config.worker_id)
            job.status = JobStatus.RUNNING
            job.started_at = datetime.utcnow()

            logger.info(
                "Processing job",
                job_id=job_id_str,
                worker_id=self.config.worker_id,
                model=job.model_name,
            )

            await asyncio.wait_for(self._execute_job(job), timeout=timeout)

            logger.info(
                "Job completed",
                job_id=job_id_str,
                worker_id=self.config.worker_id,
            )

        except asyncio.CancelledError:
            logger.info("Job cancelled", job_id=job_id_str)
            await self._status_tracker.fail_job(
                job.job_id,
                "Job cancelled by worker",
            )
            # Re-raise: swallowing CancelledError hides the cancellation from
            # whoever requested it.
            raise

        except asyncio.TimeoutError as e:
            # wait_for's own timeout carries no message. Keep any message a
            # timeout raised from inside the job itself carries.
            message = str(e) or f"Job exceeded timeout of {timeout} seconds"
            logger.error("Job failed", job_id=job_id_str, error=message)
            await self._status_tracker.fail_job(job.job_id, message)

        except Exception as e:
            logger.error(
                "Job failed",
                job_id=job_id_str,
                error=str(e),
            )
            await self._status_tracker.fail_job(
                job.job_id,
                str(e),
            )

        finally:
            # Clean up
            self._active_jobs.pop(job.job_id, None)
            # Only clear the slot if a concurrently started job hasn't taken it.
            if self._current_job is job:
                self._current_job = None
            # Remove from queue in place (slice assignment), so the list
            # object shared with .producer is updated.
            _job_queue[:] = [j for j in _job_queue if j.job_id != job.job_id]

    async def _execute_job(self, job: EvaluationJob) -> None:
        """Run the evaluation for ``job`` and record its results.

        Builds an ``EvaluationInput`` from the job fields. Optional overrides
        come from ``job.metadata``: ``mutation_depth`` (default 2),
        ``attack_types`` (default ["jailbreak"]) and ``max_concurrency``
        (default 4). The method awaits ``EvaluationOrchestrator.start_run``,
        reports completion to the status tracker, and copies score, metrics
        and sample counts onto the job.
        """
        # Import orchestrator here to avoid circular imports
        from backend.core.orchestrator import (
            EvaluationInput,
            EvaluationOrchestrator,
        )

        metadata = job.metadata or {}

        eval_input = EvaluationInput(
            model_name=job.model_name,
            model_version=job.model_version,
            dataset_name=job.dataset_name,
            dataset_version=job.dataset_version,
            mutation_depth=metadata.get("mutation_depth", 2),
            attack_types=metadata.get("attack_types", ["jailbreak"]),
            max_concurrency=metadata.get("max_concurrency", 4),
        )

        # TODO: checkpointing (job.checkpoint_interval /
        # config.enable_checkpointing) is not implemented; the run executes in
        # one shot and no intermediate progress is reported.
        orchestrator = EvaluationOrchestrator()
        output = await orchestrator.start_run(eval_input)

        # Mark job as complete
        await self._status_tracker.complete_job(
            job.job_id,
            output.composite_score,
            output.metrics,
        )

        job.status = JobStatus.COMPLETED
        job.completed_at = datetime.utcnow()
        job.composite_score = output.composite_score
        job.metrics = output.metrics
        job.progress = 100.0

        # Update total/completed samples
        if output.metrics:
            job.total_samples = output.metrics.get("total_samples", 0)
            job.completed_samples = output.metrics.get("successful_samples", 0)
            job.failed_samples = output.metrics.get("failed_samples", 0)

    async def get_current_job_status(self) -> Optional[dict]:
        """Return a summary of the most recently started job, or None if idle."""
        if self._current_job is None:
            return None

        job = self._current_job
        return {
            "job_id": str(job.job_id),
            "status": job.status.value,
            "progress": job.progress,
            "completed_samples": job.completed_samples,
            "total_samples": job.total_samples,
        }

    def get_worker_status(self) -> dict:
        """Get worker status."""
        return {
            "worker_id": self.config.worker_id,
            "running": self._running,
            "active_jobs": len(self._active_jobs),
            "max_concurrent_jobs": self.config.max_concurrent_jobs,
        }


# Global worker instance
_worker: Optional[EvaluationWorker] = None

# Strong references to running worker-loop tasks. The event loop keeps only
# weak references to tasks, so an otherwise unreferenced task could be
# garbage-collected mid-run.
_worker_tasks: "set[asyncio.Task]" = set()


def get_worker(config: Optional[WorkerConfig] = None) -> EvaluationWorker:
    """Get the global worker instance.

    ``config`` is only used when the instance is first created; later calls
    return the existing worker unchanged.
    """
    global _worker
    if _worker is None:
        _worker = EvaluationWorker(config)
    return _worker


async def start_worker() -> EvaluationWorker:
    """Start the global worker's poll loop in a background task and return it."""
    worker = get_worker()
    task = asyncio.create_task(worker.start())
    _worker_tasks.add(task)
    task.add_done_callback(_worker_tasks.discard)
    return worker


async def stop_worker() -> None:
    """Ask the global worker to stop and forget the global instance.

    Does not wait for the loop to exit. In-flight jobs finish in the
    background before the worker's ``start()`` returns.
    """
    global _worker
    if _worker is not None:
        await _worker.stop()
        _worker = None


__all__ = [
    "WorkerConfig",
    "EvaluationWorker",
    "get_worker",
    "start_worker",
    "stop_worker",
]