""" Status Tracker for Evaluation Jobs Provides job status tracking with database persistence. Handles job state transitions and progress updates. """ import uuid from datetime import datetime from typing import Dict, Optional from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from backend.db.models import EvaluationRun from backend.db.session import get_db_context from backend.logging.logger import get_logger from .job_schema import ( JobProgressUpdate, JobStatus, JobType, JobPriority, EvaluationJob, JobStatusResponse, ) logger = get_logger("queue.status_tracker", component="queue") class JobStatusTracker: """ Tracks job status and manages state transitions. Provides methods for: - Creating new jobs - Updating job progress - Querying job status - Managing job lifecycle """ def __init__(self): self._cache: Dict[str, EvaluationJob] = {} async def create_job( self, job: EvaluationJob, ) -> EvaluationJob: """ Create a new job in the database. Args: job: The job to create Returns: The created job """ try: async with get_db_context() as session: # Create evaluation run record run = EvaluationRun( id=job.job_id, model_name=job.model_name, model_version=job.model_version, dataset_version=job.dataset_version, status=job.status.value, config_hash=job.config_hash, ) session.add(run) await session.commit() logger.info( "Job created", job_id=str(job.job_id), job_type=job.job_type, priority=job.priority, ) # Cache the job self._cache[str(job.job_id)] = job return job except Exception as e: logger.error( "Failed to create job", job_id=str(job.job_id), error=str(e), ) raise async def update_job_status( self, job_id: uuid.UUID, status: JobStatus, error: Optional[str] = None, ) -> None: """ Update job status in database. Args: job_id: The job ID status: New status error: Optional error message """ try: async with get_db_context() as session: stmt = ( update(EvaluationRun) .where(EvaluationRun.id == job_id) .values(status=status.value) ) await session.execute(stmt) await session.commit() # Update cache job_id_str = str(job_id) if job_id_str in self._cache: self._cache[job_id_str].status = status if error: self._cache[job_id_str].error = error logger.info( "Job status updated", job_id=job_id_str, status=status.value, ) except Exception as e: logger.error( "Failed to update job status", job_id=str(job_id), error=str(e), ) raise async def update_job_progress( self, progress_update: JobProgressUpdate, ) -> None: """ Update job progress in database. Args: progress_update: The progress update """ try: job_id_str = str(progress_update.job_id) # Update cached job if job_id_str in self._cache: job = self._cache[job_id_str] job.completed_samples = progress_update.completed_samples job.failed_samples = progress_update.failed_samples if progress_update.composite_score is not None: job.composite_score = progress_update.composite_score if progress_update.metrics is not None: job.metrics = progress_update.metrics # Calculate progress percentage if job.total_samples > 0: total_done = progress_update.completed_samples + progress_update.failed_samples job.progress = (total_done / job.total_samples) * 100 job.last_checkpoint_at = progress_update.checkpoint_at # Update database async with get_db_context() as session: # Update composite score if available stmt = select(EvaluationRun).where( EvaluationRun.id == progress_update.job_id ) result = await session.execute(stmt) run = result.scalar_one_or_none() if run: if progress_update.composite_score is not None: run.composite_score = progress_update.composite_score await session.commit() logger.debug( "Job progress updated", job_id=job_id_str, completed=progress_update.completed_samples, failed=progress_update.failed_samples, ) except Exception as e: logger.error( "Failed to update job progress", job_id=str(progress_update.job_id), error=str(e), ) raise async def get_job_status( self, job_id: uuid.UUID, ) -> Optional[JobStatusResponse]: """ Get job status from database. Args: job_id: The job ID Returns: Job status response or None if not found """ # Check cache first job_id_str = str(job_id) if job_id_str in self._cache: job = self._cache[job_id_str] return JobStatusResponse( job_id=job.job_id, job_type=JobType(job.job_type), status=JobStatus(job.status), progress=job.progress, total_samples=job.total_samples, completed_samples=job.completed_samples, failed_samples=job.failed_samples, composite_score=job.composite_score, metrics=job.metrics, error=job.error, created_at=job.created_at, started_at=job.started_at, completed_at=job.completed_at, worker_id=job.worker_id, ) # Fetch from database try: async with get_db_context() as session: stmt = select(EvaluationRun).where( EvaluationRun.id == job_id ) result = await session.execute(stmt) run = result.scalar_one_or_none() if run is None: return None # Build response from DB return JobStatusResponse( job_id=run.id, job_type=JobType.BENCHMARK, status=JobStatus(run.status), progress=0.0, total_samples=0, completed_samples=0, failed_samples=0, composite_score=run.composite_score, metrics=None, error=None, created_at=run.timestamp, started_at=None, completed_at=None, worker_id=None, ) except Exception as e: logger.error( "Failed to get job status", job_id=str(job_id), error=str(e), ) return None async def complete_job( self, job_id: uuid.UUID, composite_score: Optional[float] = None, metrics: Optional[dict] = None, ) -> None: """ Mark job as completed. Args: job_id: The job ID composite_score: Final composite score metrics: Final metrics """ try: job_id_str = str(job_id) async with get_db_context() as session: stmt = ( update(EvaluationRun) .where(EvaluationRun.id == job_id) .values( status=JobStatus.COMPLETED.value, composite_score=composite_score, ) ) await session.execute(stmt) await session.commit() # Update cache if job_id_str in self._cache: job = self._cache[job_id_str] job.status = JobStatus.COMPLETED job.composite_score = composite_score job.metrics = metrics job.completed_at = datetime.utcnow() job.progress = 100.0 logger.info( "Job completed", job_id=job_id_str, composite_score=composite_score, ) except Exception as e: logger.error( "Failed to complete job", job_id=str(job_id), error=str(e), ) raise async def fail_job( self, job_id: uuid.UUID, error: str, error_details: Optional[dict] = None, ) -> None: """ Mark job as failed. Args: job_id: The job ID error: Error message error_details: Optional error details """ try: job_id_str = str(job_id) async with get_db_context() as session: stmt = ( update(EvaluationRun) .where(EvaluationRun.id == job_id) .values( status=JobStatus.FAILED.value, ) ) await session.execute(stmt) await session.commit() # Update cache if job_id_str in self._cache: job = self._cache[job_id_str] job.status = JobStatus.FAILED job.error = error job.error_details = error_details job.completed_at = datetime.utcnow() logger.error( "Job failed", job_id=job_id_str, error=error, ) except Exception as e: logger.error( "Failed to mark job as failed", job_id=str(job_id), error=str(e), ) raise async def start_job( self, job_id: uuid.UUID, worker_id: str, ) -> None: """ Mark job as started. Args: job_id: The job ID worker_id: ID of the worker starting the job """ try: job_id_str = str(job_id) now = datetime.utcnow() async with get_db_context() as session: stmt = ( update(EvaluationRun) .where(EvaluationRun.id == job_id) .values( status=JobStatus.RUNNING.value, ) ) await session.execute(stmt) await session.commit() # Update cache if job_id_str in self._cache: job = self._cache[job_id_str] job.status = JobStatus.RUNNING job.worker_id = worker_id job.started_at = now logger.info( "Job started", job_id=job_id_str, worker_id=worker_id, ) except Exception as e: logger.error( "Failed to start job", job_id=str(job_id), error=str(e), ) raise def get_cached_job(self, job_id: uuid.UUID) -> Optional[EvaluationJob]: """Get job from cache.""" return self._cache.get(str(job_id)) def set_cached_job(self, job: EvaluationJob) -> None: """Set job in cache.""" self._cache[str(job.job_id)] = job # Global instance status_tracker = JobStatusTracker() def get_status_tracker() -> JobStatusTracker: """Get the global status tracker instance.""" return status_tracker __all__ = [ "JobStatusTracker", "status_tracker", "get_status_tracker", "JobProgressUpdate", ]