| """
|
| Evaluation Worker / Consumer for Job Queue
|
|
|
| Provides worker functionality for processing evaluation jobs from the queue.
|
| Handles job execution, checkpointing, and progress reporting.
|
| """
|
|
|
| import asyncio
|
| import os
|
| import socket
|
| import uuid
|
| from dataclasses import dataclass
|
| from datetime import datetime
|
| from typing import Optional
|
|
|
| from backend.core.config import settings
|
| from backend.logging.logger import get_logger
|
|
|
| from .job_schema import (
|
| JobProgressUpdate,
|
| JobStatus,
|
| EvaluationJob,
|
| )
|
| from .producer import _job_queue, get_job_producer
|
| from .status_tracker import get_status_tracker
|
|
|
|
|
# Module-level logger for the queue consumer; call sites in this module pass
# structured context as keyword arguments (e.g. ``worker_id=...``, ``job_id=...``).
logger = get_logger("queue.consumer", component="queue")
|
|
|
|
|
@dataclass
class WorkerConfig:
    """Settings for an evaluation worker.

    Only ``worker_id`` must be supplied; every other field has a default.
    Fields are positional in the order declared below.
    """

    worker_id: str  # identifier used in log records and handed to the status tracker
    max_concurrent_jobs: int = 1  # upper bound on the worker's set of active jobs
    job_timeout_seconds: int = 3600  # intended per-job time limit, in seconds
    heartbeat_interval_seconds: int = 30  # heartbeat period in seconds (NOTE(review): unused here)
    enable_checkpointing: bool = True  # checkpointing toggle (NOTE(review): unused here)
|
|
|
|
|
class EvaluationWorker:
    """
    Worker that processes evaluation jobs from the in-process job queue.

    Responsibilities:
    - Poll ``_job_queue`` for jobs whose status is ``JobStatus.QUEUED``
    - Execute each job through ``EvaluationOrchestrator``
    - Enforce ``WorkerConfig.job_timeout_seconds`` per job
    - Report start / completion / failure to the status tracker
    - Remove finished or failed jobs from the queue

    Note: ``start`` awaits each job inline, so one polling loop runs at most
    one job at a time even if ``max_concurrent_jobs`` is larger.
    NOTE(review): ``heartbeat_interval_seconds`` and ``enable_checkpointing``
    are not read by this class.
    """

    def __init__(self, config: Optional[WorkerConfig] = None):
        """Create a worker.

        Args:
            config: Worker settings. When omitted, a config is built whose
                ``worker_id`` is ``worker-<hostname>-<pid>``.
        """
        self.config = config or WorkerConfig(
            worker_id=f"worker-{socket.gethostname()}-{os.getpid()}",
        )
        self._status_tracker = get_status_tracker()
        self._producer = get_job_producer()
        # Tasks currently executing jobs, keyed by job id.
        self._active_jobs: dict[uuid.UUID, asyncio.Task] = {}
        self._running = False
        # Job being processed right now (None when idle); read by
        # ``get_current_job_status``.
        self._current_job: Optional[EvaluationJob] = None

    async def start(self) -> None:
        """Run the polling loop until ``stop`` is called or the task is cancelled.

        Each iteration processes at most one queued job and then sleeps one
        second. Unexpected errors are logged and followed by a five-second
        back-off. However the loop ends, any job tasks still registered are
        cancelled and the worker is marked as not running.
        """
        self._running = True
        logger.info(
            "Worker started",
            worker_id=self.config.worker_id,
            max_concurrent=self.config.max_concurrent_jobs,
        )

        try:
            while self._running:
                try:
                    await self._poll_and_process()
                    await asyncio.sleep(1)
                except asyncio.CancelledError:
                    logger.info("Worker cancelled", worker_id=self.config.worker_id)
                    break
                except Exception as e:
                    logger.error(
                        "Worker error",
                        worker_id=self.config.worker_id,
                        error=str(e),
                    )
                    await asyncio.sleep(5)
        finally:
            # Reset even when exiting via cancellation so get_worker_status()
            # does not keep reporting a stopped worker as running.
            self._running = False
            # Snapshot: finished jobs remove themselves from _active_jobs.
            for job_id, task in list(self._active_jobs.items()):
                if not task.done():
                    task.cancel()
                    logger.info("Cancelled active job", job_id=str(job_id))

            logger.info("Worker stopped", worker_id=self.config.worker_id)

    async def stop(self) -> None:
        """Ask the polling loop to exit after its current iteration."""
        self._running = False

    async def _poll_and_process(self) -> None:
        """Process the first queued job not already active, if capacity allows."""
        if len(self._active_jobs) >= self.config.max_concurrent_jobs:
            return

        job = next(
            (
                candidate
                for candidate in _job_queue
                if candidate.status == JobStatus.QUEUED
                and candidate.job_id not in self._active_jobs
            ),
            None,
        )
        if job is not None:
            await self._process_job(job)

    async def _process_job(self, job: EvaluationJob) -> None:
        """Run *job* to completion and record its outcome.

        Marks the job running (in the status tracker and on the job object),
        executes it in a task registered in ``_active_jobs``, and reports any
        failure to the status tracker. The job is always removed from
        ``_job_queue`` afterwards.

        Raises:
            asyncio.CancelledError: re-raised after the job is recorded as
                failed, so cancelling the worker actually stops it.
        """
        job_id_str = str(job.job_id)

        try:
            self._current_job = job
            await self._status_tracker.start_job(job.job_id, self.config.worker_id)
            job.status = JobStatus.RUNNING
            job.started_at = datetime.utcnow()

            logger.info(
                "Processing job",
                job_id=job_id_str,
                worker_id=self.config.worker_id,
                model=job.model_name,
            )

            task = asyncio.create_task(self._execute_job(job))
            self._active_jobs[job.job_id] = task
            await self._await_job(task)

            logger.info(
                "Job completed",
                job_id=job_id_str,
                worker_id=self.config.worker_id,
            )

        except asyncio.CancelledError:
            logger.info("Job cancelled", job_id=job_id_str)
            await self._status_tracker.fail_job(
                job.job_id,
                "Job cancelled by worker",
            )
            # Swallowing the cancellation would keep the worker loop running
            # after its task was cancelled; propagate it once recorded.
            raise
        except Exception as e:
            logger.error(
                "Job failed",
                job_id=job_id_str,
                error=str(e),
            )
            await self._status_tracker.fail_job(
                job.job_id,
                str(e),
            )
        finally:
            self._active_jobs.pop(job.job_id, None)
            self._current_job = None
            # Drop the job from the shared queue whether it succeeded or failed.
            _job_queue[:] = [j for j in _job_queue if j.job_id != job.job_id]

    async def _await_job(self, task: asyncio.Task) -> None:
        """Wait for a job *task*, enforcing ``config.job_timeout_seconds``.

        A non-positive timeout disables the limit. When the limit is exceeded,
        ``asyncio.wait_for`` cancels the task and a ``TimeoutError`` with a
        descriptive message is raised. Cancellation of the caller is
        propagated to *task* in both paths.
        """
        timeout = self.config.job_timeout_seconds
        if timeout is None or timeout <= 0:
            await task
            return

        try:
            await asyncio.wait_for(task, timeout=timeout)
        except asyncio.TimeoutError as exc:
            # wait_for cancels the task when the deadline passes; a
            # TimeoutError raised by the job itself leaves it not-cancelled.
            if task.cancelled():
                raise TimeoutError(
                    f"Job exceeded timeout of {timeout} seconds"
                ) from exc
            raise

    async def _execute_job(self, job: EvaluationJob) -> None:
        """Run the evaluation described by *job* and store its results.

        Optional overrides are read from ``job.metadata``:
        ``mutation_depth`` (default 2), ``attack_types`` (default
        ``["jailbreak"]``) and ``max_concurrency`` (default 4).
        On success the status tracker is told the job completed, and the job
        object receives status, completion time, score, metrics, progress and
        (when metrics are reported) sample counts.
        """
        # Local import; NOTE(review): presumably avoids an import cycle with
        # the orchestrator — confirm.
        from backend.core.orchestrator import (
            EvaluationInput,
            EvaluationOrchestrator,
        )

        metadata = job.metadata or {}
        eval_input = EvaluationInput(
            model_name=job.model_name,
            model_version=job.model_version,
            dataset_name=job.dataset_name,
            dataset_version=job.dataset_version,
            mutation_depth=metadata.get("mutation_depth", 2),
            attack_types=metadata.get("attack_types", ["jailbreak"]),
            max_concurrency=metadata.get("max_concurrency", 4),
        )

        orchestrator = EvaluationOrchestrator()
        output = await orchestrator.start_run(eval_input)

        await self._status_tracker.complete_job(
            job.job_id,
            output.composite_score,
            output.metrics,
        )

        job.status = JobStatus.COMPLETED
        job.completed_at = datetime.utcnow()
        job.composite_score = output.composite_score
        job.metrics = output.metrics
        job.progress = 100.0

        # Sample counts are only available when the orchestrator reports metrics.
        if output.metrics:
            job.total_samples = output.metrics.get("total_samples", 0)
            job.completed_samples = output.metrics.get("successful_samples", 0)
            job.failed_samples = output.metrics.get("failed_samples", 0)

    async def get_current_job_status(self) -> Optional[dict]:
        """Return a summary of the job being processed, or None when idle."""
        if self._current_job is None:
            return None

        job = self._current_job
        return {
            "job_id": str(job.job_id),
            "status": job.status.value,
            "progress": job.progress,
            "completed_samples": job.completed_samples,
            "total_samples": job.total_samples,
        }

    def get_worker_status(self) -> dict:
        """Return the worker id, running flag, and active/maximum job counts."""
        return {
            "worker_id": self.config.worker_id,
            "running": self._running,
            "active_jobs": len(self._active_jobs),
            "max_concurrent_jobs": self.config.max_concurrent_jobs,
        }
|
|
|
|
|
|
|
# Process-wide worker instance; created lazily by ``get_worker`` and reset to
# None by ``stop_worker``.
_worker: Optional[EvaluationWorker] = None
|
|
|
|
|
def get_worker(config: Optional[WorkerConfig] = None) -> EvaluationWorker:
    """Return the process-wide worker, creating it on first use.

    ``config`` is only used when the worker does not exist yet; once it has
    been created, later calls return the same instance and ignore ``config``.
    """
    global _worker
    if _worker is not None:
        return _worker
    _worker = EvaluationWorker(config)
    return _worker
|
|
|
|
|
# Strong references to background tasks running ``EvaluationWorker.start``.
# The event loop keeps only weak references to tasks, so an unreferenced task
# may be garbage-collected before it finishes.
_worker_tasks: "set[asyncio.Task]" = set()


async def start_worker() -> EvaluationWorker:
    """Start the global worker's polling loop in the background and return it.

    The loop runs in a new task that is held in ``_worker_tasks`` until it
    finishes. Each call creates another task running ``worker.start()``, so
    callers should invoke this once per worker.

    Returns:
        The worker returned by ``get_worker()``.
    """
    worker = get_worker()
    task = asyncio.create_task(worker.start())
    _worker_tasks.add(task)
    # Release the reference once the loop exits so the set does not grow.
    task.add_done_callback(_worker_tasks.discard)
    return worker
|
|
|
|
|
async def stop_worker() -> None:
    """Signal the global worker to stop, then forget it.

    No-op when no worker has been created. The worker's polling loop notices
    the stop request on its next iteration; this coroutine does not wait for
    that to happen.
    """
    global _worker
    if _worker is None:
        return
    await _worker.stop()
    _worker = None
|
|
|
|
|
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "WorkerConfig",
    "EvaluationWorker",
    "get_worker",
    "start_worker",
    "stop_worker",
]
|
|
|