| """
|
| Worker Registry for Distributed Worker Coordination
|
|
|
| Provides worker registration, heartbeat management, and worker health tracking.
|
| Handles worker lifecycle: REGISTERED -> ACTIVE -> DEGRADED -> OFFLINE
|
| """
|
|
|
| import socket
|
| import uuid
|
| from datetime import datetime, timedelta
|
| from typing import Dict, List, Optional
|
|
|
| from sqlalchemy import select, update
|
| from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
| from backend.db.models import Worker
|
| from backend.db.session import get_db_context
|
| from backend.logging.logger import get_logger
|
|
|
| from .job_schema import JobStatus
|
| from .worker_schema import (
|
| GPUInfo,
|
| HeartbeatRequest,
|
| HeartbeatResponse,
|
| WorkerRegistrationRequest,
|
| WorkerRegistrationResponse,
|
| WorkerStatus,
|
| WorkerStatusResponse,
|
| WorkerListResponse,
|
| WorkerMetricsResponse,
|
| )
|
|
|
# Module-level structured logger, tagged with the "queue" component.
logger = get_logger("queue.worker_registry", component="queue")


# Heartbeat timing defaults. Both values are returned to workers in the
# registration response; the timeout is also the staleness window (in seconds)
# used by WorkerRegistry.cleanup_offline_workers.
# NOTE(review): the interval is presumably also in seconds -- confirm against
# the worker client.
DEFAULT_HEARTBEAT_INTERVAL = 30
DEFAULT_HEARTBEAT_TIMEOUT = 120
|
|
|
|
|
class WorkerRegistry:
    """
    Manages worker registration, heartbeat, and health monitoring.

    Responsibilities:
    - Worker registration
    - Heartbeat processing and load-based status updates
    - Worker status, listing and cluster-wide metrics
    - Offline worker detection and re-queueing of their running jobs
    - Load factor calculation

    All timestamps are produced with ``datetime.utcnow()`` (naive UTC).
    """

    def __init__(self):
        # In-process cache of the most recently seen Worker rows, keyed by
        # worker_id. Consulted first by get_worker_status().
        self._cache: Dict[str, Worker] = {}
        # NOTE(review): the two attributes below are not read anywhere in this
        # class -- presumably meant for throttling periodic cleanup; confirm.
        self._last_cleanup = datetime.utcnow()
        self._cleanup_interval = 60

    def _generate_worker_id(self, hostname: str) -> str:
        """Generate a unique worker ID based on hostname and UUID."""
        return f"{hostname}-{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _load_factor(
        active_jobs: Optional[int],
        max_concurrent_jobs: Optional[int],
    ) -> float:
        """
        Return active_jobs / max_concurrent_jobs as a float.

        Returns 0.0 when the capacity is missing or not positive; a missing
        active job count is treated as 0.
        """
        if not max_concurrent_jobs or max_concurrent_jobs <= 0:
            return 0.0
        return (active_jobs or 0) / max_concurrent_jobs

    @staticmethod
    def _count_with_status(workers, status: WorkerStatus) -> int:
        """Count the workers whose stored status equals ``status``."""
        return sum(1 for w in workers if w.status == status.value)

    async def register_worker(
        self,
        request: WorkerRegistrationRequest,
    ) -> WorkerRegistrationResponse:
        """
        Register a new worker in the cluster.

        The worker is stored with status ACTIVE and zero active jobs. When
        per-GPU info is supplied it takes precedence over the scalar
        ``gpu_count`` / ``gpu_memory_total`` fields of the request.

        Args:
            request: Worker registration request

        Returns:
            Worker registration response with assigned worker_id

        Raises:
            Exception: Any error raised while persisting the worker is logged
                and re-raised.
        """
        try:
            worker_id = self._generate_worker_id(request.hostname)

            if request.gpu_info:
                gpu_count = len(request.gpu_info)
                gpu_memory_total = sum(
                    gpu.memory_total_mb for gpu in request.gpu_info
                )
            else:
                gpu_count = request.gpu_count
                gpu_memory_total = request.gpu_memory_total

            async with get_db_context() as session:
                worker = Worker(
                    worker_id=worker_id,
                    hostname=request.hostname,
                    gpu_count=gpu_count,
                    gpu_memory_total=gpu_memory_total,
                    gpu_memory_used=0,
                    status=WorkerStatus.ACTIVE.value,
                    last_heartbeat=datetime.utcnow(),
                    active_jobs=0,
                    max_concurrent_jobs=request.max_concurrent_jobs,
                    capabilities=request.capabilities,
                    worker_metadata=request.worker_metadata,
                )
                session.add(worker)
                await session.commit()

                # Reload the row so defaults generated at insert time (such as
                # registered_at) are populated while the session is still
                # open; an async session cannot lazy-load them on access.
                await session.refresh(worker)
                registered_at = worker.registered_at

                logger.info(
                    "Worker registered",
                    worker_id=worker_id,
                    hostname=request.hostname,
                    gpu_count=gpu_count,
                )

                self._cache[worker_id] = worker

                return WorkerRegistrationResponse(
                    worker_id=worker_id,
                    status=WorkerStatus.ACTIVE,
                    registered_at=registered_at,
                    heartbeat_interval=DEFAULT_HEARTBEAT_INTERVAL,
                    heartbeat_timeout=DEFAULT_HEARTBEAT_TIMEOUT,
                )

        except Exception as e:
            logger.error(
                "Failed to register worker",
                error=str(e),
                hostname=request.hostname,
            )
            raise

    async def heartbeat(
        self,
        request: HeartbeatRequest,
    ) -> HeartbeatResponse:
        """
        Process worker heartbeat and update worker status.

        Updates ``last_heartbeat`` and any metrics present in the request.
        Per-GPU info, when given, overrides the scalar GPU fields. The status
        is then forced to DEGRADED when GPU usage exceeds 90% or the worker is
        at its job capacity; otherwise an OFFLINE worker is moved back to
        ACTIVE.

        Args:
            request: Heartbeat request with worker status

        Returns:
            Heartbeat response

        Raises:
            ValueError: If no worker with ``request.worker_id`` exists.
            Exception: Any other error is logged and re-raised.
        """
        try:
            worker_id = request.worker_id

            async with get_db_context() as session:
                stmt = select(Worker).where(Worker.worker_id == worker_id)
                result = await session.execute(stmt)
                worker = result.scalar_one_or_none()

                if worker is None:
                    logger.warning(
                        "Heartbeat from unknown worker",
                        worker_id=worker_id,
                    )
                    raise ValueError(f"Worker {worker_id} not found")

                now = datetime.utcnow()
                worker.last_heartbeat = now

                # Scalar metrics reported directly by the worker.
                if request.gpu_usage is not None:
                    worker.gpu_usage_percent = request.gpu_usage
                if request.gpu_memory_used is not None:
                    worker.gpu_memory_used = request.gpu_memory_used
                if request.active_jobs is not None:
                    worker.active_jobs = request.active_jobs

                # Per-GPU details override the scalar values above.
                if request.gpu_info:
                    gpus = request.gpu_info
                    worker.gpu_memory_used = sum(
                        gpu.memory_used_mb for gpu in gpus
                    )
                    worker.gpu_usage_percent = (
                        sum(gpu.utilization_percent for gpu in gpus) / len(gpus)
                    )

                if request.status is not None:
                    worker.status = request.status.value

                # gpu_usage_percent is not set at registration, so it may
                # still be None here; treat "never reported" as 0% instead of
                # failing the heartbeat on a None comparison.
                gpu_usage = worker.gpu_usage_percent or 0
                overloaded = (
                    gpu_usage > 90
                    or worker.active_jobs >= worker.max_concurrent_jobs
                )
                if overloaded:
                    worker.status = WorkerStatus.DEGRADED.value
                elif worker.status == WorkerStatus.OFFLINE.value:
                    # A heartbeat proves the worker is alive again.
                    worker.status = WorkerStatus.ACTIVE.value
                # NOTE(review): a worker marked DEGRADED here only returns to
                # ACTIVE when it reports a status itself -- confirm intended.

                await session.commit()

                self._cache[worker_id] = worker

                logger.debug(
                    "Heartbeat processed",
                    worker_id=worker_id,
                    active_jobs=worker.active_jobs,
                    gpu_usage=worker.gpu_usage_percent,
                )

                return HeartbeatResponse(
                    worker_id=worker_id,
                    status=WorkerStatus(worker.status),
                    timestamp=now,
                    assigned_jobs=worker.active_jobs,
                )

        except ValueError:
            # Unknown worker was already logged as a warning; propagate as is.
            raise
        except Exception as e:
            logger.error(
                "Failed to process heartbeat",
                worker_id=request.worker_id,
                error=str(e),
            )
            raise

    async def get_worker_status(
        self,
        worker_id: str,
    ) -> Optional[WorkerStatusResponse]:
        """
        Get worker status by ID.

        A cached entry is returned without querying the database; otherwise
        the worker is looked up in the database.

        Args:
            worker_id: Worker ID

        Returns:
            Worker status response, or None if the worker is not found or the
            database lookup fails (the failure is logged).
        """
        cached = self._cache.get(worker_id)
        if cached is not None:
            return self._build_worker_status_response(cached)

        try:
            async with get_db_context() as session:
                stmt = select(Worker).where(Worker.worker_id == worker_id)
                result = await session.execute(stmt)
                worker = result.scalar_one_or_none()

                if worker is None:
                    return None

                return self._build_worker_status_response(worker)

        except Exception as e:
            logger.error(
                "Failed to get worker status",
                worker_id=worker_id,
                error=str(e),
            )
            return None

    async def list_workers(
        self,
        status_filter: Optional[WorkerStatus] = None,
    ) -> WorkerListResponse:
        """
        List all workers in the cluster, newest registration first.

        Every returned worker is also written to the in-process cache.

        Args:
            status_filter: Optional status filter

        Returns:
            Worker list response; an empty response (all counts 0) if the
            query fails (the failure is logged).
        """
        try:
            async with get_db_context() as session:
                stmt = select(Worker).order_by(Worker.registered_at.desc())
                if status_filter is not None:
                    stmt = stmt.where(Worker.status == status_filter.value)

                result = await session.execute(stmt)
                workers = result.scalars().all()

                for worker in workers:
                    self._cache[worker.worker_id] = worker

                return WorkerListResponse(
                    workers=[
                        self._build_worker_status_response(w) for w in workers
                    ],
                    total=len(workers),
                    active=self._count_with_status(
                        workers, WorkerStatus.ACTIVE
                    ),
                    offline=self._count_with_status(
                        workers, WorkerStatus.OFFLINE
                    ),
                )

        except Exception as e:
            logger.error(
                "Failed to list workers",
                error=str(e),
            )
            return WorkerListResponse(workers=[], total=0, active=0, offline=0)

    async def get_worker_metrics(self) -> WorkerMetricsResponse:
        """
        Get cluster-wide worker metrics.

        The average load factor only considers workers with a positive
        ``max_concurrent_jobs``. The queue length is the number of QUEUED
        jobs in the producer's in-memory job queue.

        Returns:
            Worker metrics response; an all-zero response if anything fails
            (the failure is logged).
        """
        try:
            async with get_db_context() as session:
                result = await session.execute(select(Worker))
                workers = result.scalars().all()

                # Numeric columns may be NULL; count missing values as 0 so a
                # single incomplete row does not zero out the whole report.
                total_gpus = sum(w.gpu_count or 0 for w in workers)
                total_gpu_memory = sum(
                    w.gpu_memory_total or 0 for w in workers
                )
                used_gpu_memory = sum(w.gpu_memory_used or 0 for w in workers)
                total_active_jobs = sum(w.active_jobs or 0 for w in workers)

                load_factors = [
                    self._load_factor(w.active_jobs, w.max_concurrent_jobs)
                    for w in workers
                    if w.max_concurrent_jobs and w.max_concurrent_jobs > 0
                ]
                avg_load = (
                    sum(load_factors) / len(load_factors)
                    if load_factors
                    else 0.0
                )

                # Imported lazily, as in requeue_worker_jobs().
                from backend.queue.producer import _job_queue

                queue_length = sum(
                    1 for j in _job_queue if j.status == JobStatus.QUEUED
                )

                return WorkerMetricsResponse(
                    total_workers=len(workers),
                    active_workers=self._count_with_status(
                        workers, WorkerStatus.ACTIVE
                    ),
                    offline_workers=self._count_with_status(
                        workers, WorkerStatus.OFFLINE
                    ),
                    degraded_workers=self._count_with_status(
                        workers, WorkerStatus.DEGRADED
                    ),
                    total_gpus=total_gpus,
                    total_gpu_memory_mb=total_gpu_memory,
                    used_gpu_memory_mb=used_gpu_memory,
                    total_active_jobs=total_active_jobs,
                    average_load_factor=avg_load,
                    queue_length=queue_length,
                )

        except Exception as e:
            logger.error(
                "Failed to get worker metrics",
                error=str(e),
            )
            return WorkerMetricsResponse(
                total_workers=0,
                active_workers=0,
                offline_workers=0,
                degraded_workers=0,
                total_gpus=0,
                total_gpu_memory_mb=0,
                used_gpu_memory_mb=0,
                total_active_jobs=0,
                average_load_factor=0.0,
                queue_length=0,
            )

    async def cleanup_offline_workers(self) -> int:
        """
        Mark workers as offline if they haven't sent heartbeat within timeout.

        The timeout is DEFAULT_HEARTBEAT_TIMEOUT seconds. When any row is
        updated, the in-process cache is cleared so later status lookups read
        the new state from the database.

        Returns:
            Number of workers marked offline; 0 if the update fails (the
            failure is logged).
        """
        try:
            now = datetime.utcnow()
            timeout_threshold = now - timedelta(
                seconds=DEFAULT_HEARTBEAT_TIMEOUT
            )

            async with get_db_context() as session:
                stmt = (
                    update(Worker)
                    .where(Worker.last_heartbeat < timeout_threshold)
                    .where(Worker.status != WorkerStatus.OFFLINE.value)
                    .values(status=WorkerStatus.OFFLINE.value)
                )
                result = await session.execute(stmt)
                await session.commit()

                count = result.rowcount

                if count:
                    # The bulk UPDATE does not touch cached Worker objects;
                    # without this, get_worker_status() would keep reporting
                    # the pre-timeout status for workers that went silent.
                    self._cache.clear()

                if count > 0:
                    logger.info(
                        "Marked workers as offline",
                        count=count,
                    )

                return count

        except Exception as e:
            logger.error(
                "Failed to cleanup offline workers",
                error=str(e),
            )
            return 0

    async def requeue_worker_jobs(self, worker_id: str) -> int:
        """
        Requeue jobs from an offline worker.

        Every RUNNING job in the producer's in-memory job queue that is
        assigned to ``worker_id`` is reset to QUEUED with its worker and start
        time cleared.

        Args:
            worker_id: Worker ID

        Returns:
            Number of jobs requeued; 0 if anything fails (the failure is
            logged).
        """
        try:
            from backend.queue.producer import _job_queue

            requeued_count = 0

            for job in _job_queue:
                if job.worker_id != worker_id:
                    continue
                if job.status != JobStatus.RUNNING:
                    continue

                job.status = JobStatus.QUEUED
                job.worker_id = None
                job.started_at = None
                requeued_count += 1

                logger.info(
                    "Requeued job from offline worker",
                    job_id=str(job.job_id),
                    worker_id=worker_id,
                )

            return requeued_count

        except Exception as e:
            logger.error(
                "Failed to requeue worker jobs",
                worker_id=worker_id,
                error=str(e),
            )
            return 0

    def _build_worker_status_response(self, worker: Worker) -> WorkerStatusResponse:
        """
        Build worker status response from worker model.

        Adds two derived fields: the load factor and the number of whole
        seconds elapsed since ``registered_at``.
        """
        load_factor = self._load_factor(
            worker.active_jobs, worker.max_concurrent_jobs
        )
        uptime_seconds = int(
            (datetime.utcnow() - worker.registered_at).total_seconds()
        )

        return WorkerStatusResponse(
            worker_id=worker.worker_id,
            hostname=worker.hostname,
            status=WorkerStatus(worker.status),
            gpu_count=worker.gpu_count,
            gpu_memory_total=worker.gpu_memory_total,
            gpu_memory_used=worker.gpu_memory_used,
            gpu_usage_percent=worker.gpu_usage_percent,
            active_jobs=worker.active_jobs,
            max_concurrent_jobs=worker.max_concurrent_jobs,
            load_factor=load_factor,
            last_heartbeat=worker.last_heartbeat,
            registered_at=worker.registered_at,
            capabilities=worker.capabilities,
            uptime_seconds=uptime_seconds,
        )

    async def get_available_workers(
        self,
        gpu_required: int = 0,
    ) -> List[Worker]:
        """
        Get available workers that can accept new jobs.

        A worker qualifies when its status is ACTIVE or DEGRADED and it has
        fewer active jobs than its maximum.

        Args:
            gpu_required: Number of GPUs required (0 for CPU-only)

        Returns:
            List of available workers sorted by load factor (least loaded
            first); an empty list if the query fails (the failure is logged).
        """
        try:
            async with get_db_context() as session:
                stmt = (
                    select(Worker)
                    .where(
                        Worker.status.in_(
                            [
                                WorkerStatus.ACTIVE.value,
                                WorkerStatus.DEGRADED.value,
                            ]
                        )
                    )
                    .where(Worker.active_jobs < Worker.max_concurrent_jobs)
                )

                if gpu_required > 0:
                    stmt = stmt.where(Worker.gpu_count >= gpu_required)

                result = await session.execute(stmt)
                workers = result.scalars().all()

                return sorted(
                    workers,
                    key=lambda w: self._load_factor(
                        w.active_jobs, w.max_concurrent_jobs
                    ),
                )

        except Exception as e:
            logger.error(
                "Failed to get available workers",
                error=str(e),
            )
            return []
|
|
|
|
|
|
|
# Process-wide registry instance, created on first use by get_worker_registry().
_worker_registry: Optional[WorkerRegistry] = None


def get_worker_registry() -> WorkerRegistry:
    """Return the shared WorkerRegistry, creating it on the first call."""
    global _worker_registry
    registry = _worker_registry
    if registry is None:
        registry = WorkerRegistry()
        _worker_registry = registry
    return registry
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "WorkerRegistry",
    "get_worker_registry",
    "DEFAULT_HEARTBEAT_INTERVAL",
    "DEFAULT_HEARTBEAT_TIMEOUT",
]
|
|
|