| """
|
| Status Tracker for Evaluation Jobs
|
|
|
| Provides job status tracking with database persistence.
|
| Handles job state transitions and progress updates.
|
| """
|
|
|
| import uuid
|
| from datetime import datetime
|
| from typing import Dict, Optional
|
|
|
| from sqlalchemy import select, update
|
| from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
| from backend.db.models import EvaluationRun
|
| from backend.db.session import get_db_context
|
| from backend.logging.logger import get_logger
|
|
|
| from .job_schema import (
|
| JobProgressUpdate,
|
| JobStatus,
|
| JobType,
|
| JobPriority,
|
| EvaluationJob,
|
| JobStatusResponse,
|
| )
|
|
|
|
|
# Module-level logger shared by every JobStatusTracker method; obtained from
# the project's get_logger with the "queue" component tag.
logger = get_logger("queue.status_tracker", component="queue")
|
|
|
|
|
class JobStatusTracker:
    """
    Tracks job status and manages state transitions.

    Each lifecycle method first writes the new state to the ``EvaluationRun``
    row whose ``id`` equals the job ID, then mirrors the change onto the
    ``EvaluationJob`` held in the in-process cache (when one is cached).

    After creation, the only columns this class updates in the database are
    ``status`` and ``composite_score``. Progress counters, metrics, error
    text, timestamps and the worker ID are kept on the cached job only.

    Provides methods for:
    - Creating new jobs
    - Updating job progress
    - Querying job status
    - Managing job lifecycle

    NOTE(review): nothing in this class removes entries from the cache, so it
    grows with every job created or cached. Confirm whether eviction is
    needed for long-running processes.
    """

    def __init__(self):
        # Maps str(job_id) -> job object for jobs created or cached here.
        self._cache: Dict[str, EvaluationJob] = {}

    async def _update_run(self, job_id: uuid.UUID, **values) -> None:
        """
        Apply column values to the ``EvaluationRun`` row for a job and commit.

        Args:
            job_id: ID of the row to update
            **values: Column name / new value pairs passed to ``.values()``
        """
        async with get_db_context() as session:
            stmt = (
                update(EvaluationRun)
                .where(EvaluationRun.id == job_id)
                .values(**values)
            )
            await session.execute(stmt)
            await session.commit()

    async def create_job(
        self,
        job: EvaluationJob,
    ) -> EvaluationJob:
        """
        Create a new job in the database.

        Inserts an ``EvaluationRun`` row built from the job's identifying
        fields, then stores the job in the in-process cache.

        Args:
            job: The job to create

        Returns:
            The created job (the same object that was passed in)

        Raises:
            Exception: Any error raised while inserting is logged and
                re-raised; the job is not cached in that case.
        """
        try:
            async with get_db_context() as session:
                run = EvaluationRun(
                    id=job.job_id,
                    model_name=job.model_name,
                    model_version=job.model_version,
                    dataset_version=job.dataset_version,
                    status=job.status.value,
                    config_hash=job.config_hash,
                )
                session.add(run)
                await session.commit()

            logger.info(
                "Job created",
                job_id=str(job.job_id),
                job_type=job.job_type,
                priority=job.priority,
            )

            # Cache only after the insert has committed.
            self._cache[str(job.job_id)] = job

            return job

        except Exception as e:
            logger.error(
                "Failed to create job",
                job_id=str(job.job_id),
                error=str(e),
            )
            raise

    async def update_job_status(
        self,
        job_id: uuid.UUID,
        status: JobStatus,
        error: Optional[str] = None,
    ) -> None:
        """
        Update job status in database.

        Args:
            job_id: The job ID
            status: New status
            error: Optional error message. It is stored on the cached job
                only; it is not written to the database.

        Raises:
            Exception: Any error raised while updating is logged and
                re-raised.
        """
        try:
            await self._update_run(job_id, status=status.value)

            job_id_str = str(job_id)
            job = self._cache.get(job_id_str)
            if job is not None:
                job.status = status
                if error:
                    job.error = error

            logger.info(
                "Job status updated",
                job_id=job_id_str,
                status=status.value,
            )

        except Exception as e:
            logger.error(
                "Failed to update job status",
                job_id=str(job_id),
                error=str(e),
            )
            raise

    async def update_job_progress(
        self,
        progress_update: JobProgressUpdate,
    ) -> None:
        """
        Update job progress in the cache and, when a score is supplied, in
        the database.

        The cached job (if any) receives the sample counters, checkpoint
        time, and any non-``None`` score/metrics. Only ``composite_score`` is
        persisted, so the database is touched only when the update carries
        one.

        Args:
            progress_update: The progress update

        Raises:
            Exception: Any error raised while updating is logged and
                re-raised. The cache is updated before the database write,
                so it may already reflect the update when this happens.
        """
        try:
            job_id_str = str(progress_update.job_id)

            job = self._cache.get(job_id_str)
            if job is not None:
                job.completed_samples = progress_update.completed_samples
                job.failed_samples = progress_update.failed_samples

                if progress_update.composite_score is not None:
                    job.composite_score = progress_update.composite_score
                if progress_update.metrics is not None:
                    job.metrics = progress_update.metrics

                if job.total_samples > 0:
                    total_done = (
                        progress_update.completed_samples
                        + progress_update.failed_samples
                    )
                    # Clamp so over-reported counts never yield more than 100%.
                    job.progress = min(
                        100.0, (total_done / job.total_samples) * 100
                    )

                job.last_checkpoint_at = progress_update.checkpoint_at

            # composite_score is the only persisted progress field; skip the
            # database round-trip entirely when there is nothing to write.
            if progress_update.composite_score is not None:
                await self._update_run(
                    progress_update.job_id,
                    composite_score=progress_update.composite_score,
                )

            logger.debug(
                "Job progress updated",
                job_id=job_id_str,
                completed=progress_update.completed_samples,
                failed=progress_update.failed_samples,
            )

        except Exception as e:
            logger.error(
                "Failed to update job progress",
                job_id=str(progress_update.job_id),
                error=str(e),
            )
            raise

    async def get_job_status(
        self,
        job_id: uuid.UUID,
    ) -> Optional[JobStatusResponse]:
        """
        Get job status, preferring the cache over the database.

        A cached job yields a fully populated response. Otherwise the
        ``EvaluationRun`` row is read; because only a few fields are
        persisted, that response carries the stored status, composite score
        and timestamp, with fixed placeholder values for everything else.

        Args:
            job_id: The job ID

        Returns:
            Job status response, or None if the job is not found or the
            database lookup fails (the failure is logged, not raised)
        """
        job_id_str = str(job_id)
        if job_id_str in self._cache:
            job = self._cache[job_id_str]
            return JobStatusResponse(
                job_id=job.job_id,
                job_type=JobType(job.job_type),
                status=JobStatus(job.status),
                progress=job.progress,
                total_samples=job.total_samples,
                completed_samples=job.completed_samples,
                failed_samples=job.failed_samples,
                composite_score=job.composite_score,
                metrics=job.metrics,
                error=job.error,
                created_at=job.created_at,
                started_at=job.started_at,
                completed_at=job.completed_at,
                worker_id=job.worker_id,
            )

        # Cache miss: fall back to the persisted row.
        try:
            async with get_db_context() as session:
                stmt = select(EvaluationRun).where(
                    EvaluationRun.id == job_id
                )
                result = await session.execute(stmt)
                run = result.scalar_one_or_none()

                if run is None:
                    return None

                # Fields that are not read from the row are filled with
                # fixed defaults.
                return JobStatusResponse(
                    job_id=run.id,
                    job_type=JobType.BENCHMARK,
                    status=JobStatus(run.status),
                    progress=0.0,
                    total_samples=0,
                    completed_samples=0,
                    failed_samples=0,
                    composite_score=run.composite_score,
                    metrics=None,
                    error=None,
                    created_at=run.timestamp,
                    started_at=None,
                    completed_at=None,
                    worker_id=None,
                )

        except Exception as e:
            logger.error(
                "Failed to get job status",
                job_id=str(job_id),
                error=str(e),
            )
            return None

    async def complete_job(
        self,
        job_id: uuid.UUID,
        composite_score: Optional[float] = None,
        metrics: Optional[dict] = None,
    ) -> None:
        """
        Mark job as completed.

        Args:
            job_id: The job ID
            composite_score: Final composite score. When omitted, the score
                already stored (for example by a progress update) is kept.
            metrics: Final metrics. When omitted, the cached metrics are
                kept. Metrics are stored on the cached job only.

        Raises:
            Exception: Any error raised while updating is logged and
                re-raised.
        """
        try:
            job_id_str = str(job_id)

            # Only overwrite the stored score when a final one is supplied,
            # so completing without a score does not erase earlier results.
            values = {"status": JobStatus.COMPLETED.value}
            if composite_score is not None:
                values["composite_score"] = composite_score
            await self._update_run(job_id, **values)

            job = self._cache.get(job_id_str)
            if job is not None:
                job.status = JobStatus.COMPLETED
                if composite_score is not None:
                    job.composite_score = composite_score
                if metrics is not None:
                    job.metrics = metrics
                job.completed_at = datetime.utcnow()
                job.progress = 100.0

            logger.info(
                "Job completed",
                job_id=job_id_str,
                composite_score=composite_score,
            )

        except Exception as e:
            logger.error(
                "Failed to complete job",
                job_id=str(job_id),
                error=str(e),
            )
            raise

    async def fail_job(
        self,
        job_id: uuid.UUID,
        error: str,
        error_details: Optional[dict] = None,
    ) -> None:
        """
        Mark job as failed.

        Only the status is written to the database; the error message and
        details are stored on the cached job.

        Args:
            job_id: The job ID
            error: Error message
            error_details: Optional error details

        Raises:
            Exception: Any error raised while updating is logged and
                re-raised.
        """
        try:
            job_id_str = str(job_id)

            await self._update_run(job_id, status=JobStatus.FAILED.value)

            job = self._cache.get(job_id_str)
            if job is not None:
                job.status = JobStatus.FAILED
                job.error = error
                job.error_details = error_details
                job.completed_at = datetime.utcnow()

            logger.error(
                "Job failed",
                job_id=job_id_str,
                error=error,
            )

        except Exception as e:
            logger.error(
                "Failed to mark job as failed",
                job_id=str(job_id),
                error=str(e),
            )
            raise

    async def start_job(
        self,
        job_id: uuid.UUID,
        worker_id: str,
    ) -> None:
        """
        Mark job as started.

        Only the status is written to the database; the worker ID and start
        time are stored on the cached job.

        Args:
            job_id: The job ID
            worker_id: ID of the worker starting the job

        Raises:
            Exception: Any error raised while updating is logged and
                re-raised.
        """
        try:
            job_id_str = str(job_id)
            # Captured before the database write so the recorded start time
            # does not include the write latency.
            now = datetime.utcnow()

            await self._update_run(job_id, status=JobStatus.RUNNING.value)

            job = self._cache.get(job_id_str)
            if job is not None:
                job.status = JobStatus.RUNNING
                job.worker_id = worker_id
                job.started_at = now

            logger.info(
                "Job started",
                job_id=job_id_str,
                worker_id=worker_id,
            )

        except Exception as e:
            logger.error(
                "Failed to start job",
                job_id=str(job_id),
                error=str(e),
            )
            raise

    def get_cached_job(self, job_id: uuid.UUID) -> Optional[EvaluationJob]:
        """Get job from cache, or None when it is not cached."""
        return self._cache.get(str(job_id))

    def set_cached_job(self, job: EvaluationJob) -> None:
        """Set job in cache, replacing any entry with the same job ID."""
        self._cache[str(job.job_id)] = job
|
|
|
|
|
|
|
# Module-level tracker instance, created at import time and returned by
# get_status_tracker().
status_tracker = JobStatusTracker()
|
|
|
|
|
def get_status_tracker() -> JobStatusTracker:
    """
    Return the module-level ``JobStatusTracker`` instance.

    Returns:
        The ``status_tracker`` object defined in this module; every call
        returns the same instance.
    """
    return status_tracker
|
|
|
|
|
# Public API of this module. JobProgressUpdate is imported from .job_schema
# and re-exported here.
__all__ = [
    "JobStatusTracker",
    "status_tracker",
    "get_status_tracker",
    "JobProgressUpdate",
]
|
|
|