# aegislm/backend/queue/worker_registry.py
# Repository page header captured with this file: "ACA050's picture",
# "Upload 50 files", commit 1a4aa87 (verified).
"""
Worker Registry for Distributed Worker Coordination
Provides worker registration, heartbeat management, and worker health tracking.
Handles worker lifecycle: REGISTERED -> ACTIVE -> DEGRADED -> OFFLINE
"""
import socket
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from backend.db.models import Worker
from backend.db.session import get_db_context
from backend.logging.logger import get_logger
from .job_schema import JobStatus
from .worker_schema import (
GPUInfo,
HeartbeatRequest,
HeartbeatResponse,
WorkerRegistrationRequest,
WorkerRegistrationResponse,
WorkerStatus,
WorkerStatusResponse,
WorkerListResponse,
WorkerMetricsResponse,
)
# Module-level logger; every method of WorkerRegistry below logs through it.
logger = get_logger("queue.worker_registry", component="queue")
# Configuration
# Both values are handed back to a worker in its registration response.
DEFAULT_HEARTBEAT_INTERVAL = 30 # seconds (returned to workers as heartbeat_interval)
DEFAULT_HEARTBEAT_TIMEOUT = 120 # seconds (worker marked offline by cleanup after this long without a heartbeat)
class WorkerRegistry:
    """
    Manages worker registration, heartbeat, and health monitoring.

    Responsibilities:
    - Worker registration
    - Heartbeat processing
    - Worker health tracking (ACTIVE / DEGRADED / OFFLINE)
    - Offline worker detection
    - Load factor calculation

    The registry keeps a small in-process cache of the most recently seen
    ``Worker`` rows, keyed by worker_id. The database remains the source of
    truth; the cache is dropped whenever a bulk status change bypasses it.
    """

    # GPU utilisation (percent) above which a worker is considered degraded.
    _GPU_DEGRADED_THRESHOLD = 90

    def __init__(self):
        self._cache: Dict[str, Worker] = {}
        self._last_cleanup = datetime.utcnow()
        self._cleanup_interval = 60  # Run cleanup every 60 seconds

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------
    def _generate_worker_id(self, hostname: str) -> str:
        """Generate a unique worker ID based on hostname and UUID."""
        return f"{hostname}-{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _load_factor(active_jobs, max_concurrent_jobs) -> float:
        """
        Return ``active_jobs / max_concurrent_jobs``.

        Yields 0.0 when the capacity is missing or not positive, and treats
        a missing ``active_jobs`` as zero, so callers never divide by zero
        or by ``None``.
        """
        if not max_concurrent_jobs or max_concurrent_jobs <= 0:
            return 0.0
        return (active_jobs or 0) / max_concurrent_jobs

    @classmethod
    def _is_overloaded(cls, gpu_usage_percent, active_jobs, max_concurrent_jobs) -> bool:
        """
        Tell whether the reported load should put a worker in DEGRADED.

        A worker is overloaded when its GPU usage exceeds the threshold or
        when it is running at (or above) its job capacity. ``None`` values
        (e.g. a worker that never reported GPU usage) count as "no load".
        """
        if (gpu_usage_percent or 0) > cls._GPU_DEGRADED_THRESHOLD:
            return True
        if max_concurrent_jobs is None:
            return False
        return (active_jobs or 0) >= max_concurrent_jobs

    @classmethod
    def _resolve_status(
        cls,
        current_status,
        reported_status,
        gpu_usage_percent,
        active_jobs,
        max_concurrent_jobs,
    ) -> str:
        """
        Compute the status string to store after a heartbeat.

        Args:
            current_status: Status value currently stored for the worker
            reported_status: WorkerStatus sent with the heartbeat, or None
            gpu_usage_percent: Latest GPU usage (may be None)
            active_jobs: Latest active job count
            max_concurrent_jobs: Worker job capacity

        Returns:
            The status value (string) to persist
        """
        status = reported_status.value if reported_status is not None else current_status
        if cls._is_overloaded(gpu_usage_percent, active_jobs, max_concurrent_jobs):
            return WorkerStatus.DEGRADED.value
        if status == WorkerStatus.OFFLINE.value:
            # A heartbeat proves the worker is alive again: reactivate it.
            return WorkerStatus.ACTIVE.value
        if status == WorkerStatus.DEGRADED.value and reported_status is None:
            # The DEGRADED flag was not re-asserted by the worker and the
            # load is back under the limits, so let the worker recover
            # instead of staying degraded forever.
            return WorkerStatus.ACTIVE.value
        return status

    # ------------------------------------------------------------------
    # Registration / heartbeat
    # ------------------------------------------------------------------
    async def register_worker(
        self,
        request: WorkerRegistrationRequest,
    ) -> WorkerRegistrationResponse:
        """
        Register a new worker in the cluster.

        Args:
            request: Worker registration request

        Returns:
            Worker registration response with assigned worker_id

        Raises:
            Exception: Any failure while persisting the worker is logged
                and re-raised unchanged.
        """
        try:
            worker_id = self._generate_worker_id(request.hostname)

            # Per-GPU details, when provided, take precedence over the
            # aggregate numbers in the request.
            gpu_count = request.gpu_count
            gpu_memory_total = request.gpu_memory_total
            if request.gpu_info:
                gpu_count = len(request.gpu_info)
                gpu_memory_total = sum(
                    gpu.memory_total_mb for gpu in request.gpu_info
                )

            async with get_db_context() as session:
                # Create worker record
                worker = Worker(
                    worker_id=worker_id,
                    hostname=request.hostname,
                    gpu_count=gpu_count,
                    gpu_memory_total=gpu_memory_total,
                    gpu_memory_used=0,
                    status=WorkerStatus.ACTIVE.value,
                    last_heartbeat=datetime.utcnow(),
                    active_jobs=0,
                    max_concurrent_jobs=request.max_concurrent_jobs,
                    capabilities=request.capabilities,
                    worker_metadata=request.worker_metadata,
                )
                session.add(worker)
                await session.commit()
                # Reload the row while the session is still open so that
                # generated columns (registered_at) are populated; reading
                # an unloaded attribute later would need implicit IO, which
                # an async session cannot do.
                await session.refresh(worker)
                registered_at = worker.registered_at

            logger.info(
                "Worker registered",
                worker_id=worker_id,
                hostname=request.hostname,
                gpu_count=gpu_count,
            )

            # Cache the worker
            self._cache[worker_id] = worker

            return WorkerRegistrationResponse(
                worker_id=worker_id,
                status=WorkerStatus.ACTIVE,
                registered_at=registered_at,
                heartbeat_interval=DEFAULT_HEARTBEAT_INTERVAL,
                heartbeat_timeout=DEFAULT_HEARTBEAT_TIMEOUT,
            )
        except Exception as e:
            logger.error(
                "Failed to register worker",
                error=str(e),
                hostname=request.hostname,
            )
            raise

    async def heartbeat(
        self,
        request: HeartbeatRequest,
    ) -> HeartbeatResponse:
        """
        Process worker heartbeat and update worker status.

        Args:
            request: Heartbeat request with worker status

        Returns:
            Heartbeat response

        Raises:
            ValueError: If no worker with ``request.worker_id`` exists.
            Exception: Any other failure is logged and re-raised.
        """
        try:
            worker_id = request.worker_id
            async with get_db_context() as session:
                # Get worker from database
                stmt = select(Worker).where(Worker.worker_id == worker_id)
                result = await session.execute(stmt)
                worker = result.scalar_one_or_none()
                if worker is None:
                    logger.warning(
                        "Heartbeat from unknown worker",
                        worker_id=worker_id,
                    )
                    raise ValueError(f"Worker {worker_id} not found")

                now = datetime.utcnow()
                worker.last_heartbeat = now

                # Update optional fields if provided
                if request.gpu_usage is not None:
                    worker.gpu_usage_percent = request.gpu_usage
                if request.gpu_memory_used is not None:
                    worker.gpu_memory_used = request.gpu_memory_used
                if request.active_jobs is not None:
                    worker.active_jobs = request.active_jobs

                # Per-GPU details override the aggregate values above.
                if request.gpu_info:
                    worker.gpu_memory_used = sum(
                        gpu.memory_used_mb for gpu in request.gpu_info
                    )
                    worker.gpu_usage_percent = sum(
                        gpu.utilization_percent for gpu in request.gpu_info
                    ) / len(request.gpu_info)

                worker.status = self._resolve_status(
                    current_status=worker.status,
                    reported_status=request.status,
                    gpu_usage_percent=worker.gpu_usage_percent,
                    active_jobs=worker.active_jobs,
                    max_concurrent_jobs=worker.max_concurrent_jobs,
                )

                # Capture what the response needs before committing: the
                # commit may expire the instance, and expired attributes
                # cannot be lazily reloaded on an async session.
                status_value = worker.status
                active_jobs = worker.active_jobs
                gpu_usage = worker.gpu_usage_percent

                await session.commit()

            # Update cache
            self._cache[worker_id] = worker

            logger.debug(
                "Heartbeat processed",
                worker_id=worker_id,
                active_jobs=active_jobs,
                gpu_usage=gpu_usage,
            )
            return HeartbeatResponse(
                worker_id=worker_id,
                status=WorkerStatus(status_value),
                timestamp=now,
                assigned_jobs=active_jobs,
            )
        except ValueError:
            # Unknown worker: already logged as a warning above.
            raise
        except Exception as e:
            logger.error(
                "Failed to process heartbeat",
                worker_id=request.worker_id,
                error=str(e),
            )
            raise

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    async def get_worker_status(
        self,
        worker_id: str,
    ) -> Optional[WorkerStatusResponse]:
        """
        Get worker status by ID.

        Args:
            worker_id: Worker ID

        Returns:
            Worker status response, or None if the worker is not found or
            the database lookup fails (the failure is logged).
        """
        # Check cache first
        cached = self._cache.get(worker_id)
        if cached is not None:
            return self._build_worker_status_response(cached)

        # Fetch from database
        try:
            async with get_db_context() as session:
                stmt = select(Worker).where(Worker.worker_id == worker_id)
                result = await session.execute(stmt)
                worker = result.scalar_one_or_none()
                if worker is None:
                    return None
                return self._build_worker_status_response(worker)
        except Exception as e:
            logger.error(
                "Failed to get worker status",
                worker_id=worker_id,
                error=str(e),
            )
            return None

    async def list_workers(
        self,
        status_filter: Optional[WorkerStatus] = None,
    ) -> WorkerListResponse:
        """
        List all workers in the cluster, most recently registered first.

        Args:
            status_filter: Optional status filter

        Returns:
            Worker list response; an empty response if the query fails
            (the failure is logged).
        """
        try:
            async with get_db_context() as session:
                stmt = select(Worker).order_by(Worker.registered_at.desc())
                if status_filter is not None:
                    stmt = stmt.where(Worker.status == status_filter.value)
                result = await session.execute(stmt)
                workers = result.scalars().all()

                # Update cache
                for worker in workers:
                    self._cache[worker.worker_id] = worker

                # Build responses
                worker_responses = [
                    self._build_worker_status_response(w) for w in workers
                ]
                active_count = sum(
                    1 for w in workers if w.status == WorkerStatus.ACTIVE.value
                )
                offline_count = sum(
                    1 for w in workers if w.status == WorkerStatus.OFFLINE.value
                )
                return WorkerListResponse(
                    workers=worker_responses,
                    total=len(workers),
                    active=active_count,
                    offline=offline_count,
                )
        except Exception as e:
            logger.error(
                "Failed to list workers",
                error=str(e),
            )
            return WorkerListResponse(workers=[], total=0, active=0, offline=0)

    async def get_worker_metrics(self) -> WorkerMetricsResponse:
        """
        Get cluster-wide worker metrics.

        Returns:
            Worker metrics response; an all-zero response if collection
            fails (the failure is logged).
        """
        try:
            async with get_db_context() as session:
                result = await session.execute(select(Worker))
                workers = result.scalars().all()

                statuses = [w.status for w in workers]
                active_workers = statuses.count(WorkerStatus.ACTIVE.value)
                offline_workers = statuses.count(WorkerStatus.OFFLINE.value)
                degraded_workers = statuses.count(WorkerStatus.DEGRADED.value)

                # "or 0": a single unset numeric column must not turn the
                # whole report into the all-zero fallback below.
                total_gpus = sum(w.gpu_count or 0 for w in workers)
                total_gpu_memory = sum(w.gpu_memory_total or 0 for w in workers)
                used_gpu_memory = sum(w.gpu_memory_used or 0 for w in workers)
                total_active_jobs = sum(w.active_jobs or 0 for w in workers)

                # Average load factor over workers that have capacity;
                # workers without a positive capacity are left out.
                load_factors = [
                    self._load_factor(w.active_jobs, w.max_concurrent_jobs)
                    for w in workers
                    if w.max_concurrent_jobs and w.max_concurrent_jobs > 0
                ]
                avg_load = (
                    sum(load_factors) / len(load_factors) if load_factors else 0.0
                )

            # Get queue length (pending jobs)
            from backend.queue.producer import _job_queue
            queue_length = sum(
                1 for j in _job_queue
                if j.status == JobStatus.QUEUED
            )

            return WorkerMetricsResponse(
                total_workers=len(workers),
                active_workers=active_workers,
                offline_workers=offline_workers,
                degraded_workers=degraded_workers,
                total_gpus=total_gpus,
                total_gpu_memory_mb=total_gpu_memory,
                used_gpu_memory_mb=used_gpu_memory,
                total_active_jobs=total_active_jobs,
                average_load_factor=avg_load,
                queue_length=queue_length,
            )
        except Exception as e:
            logger.error(
                "Failed to get worker metrics",
                error=str(e),
            )
            return WorkerMetricsResponse(
                total_workers=0,
                active_workers=0,
                offline_workers=0,
                degraded_workers=0,
                total_gpus=0,
                total_gpu_memory_mb=0,
                used_gpu_memory_mb=0,
                total_active_jobs=0,
                average_load_factor=0.0,
                queue_length=0,
            )

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------
    async def cleanup_offline_workers(self) -> int:
        """
        Mark workers as offline if they haven't sent heartbeat within timeout.

        Returns:
            Number of workers marked offline (0 if the update fails; the
            failure is logged).
        """
        try:
            now = datetime.utcnow()
            timeout_threshold = now - timedelta(seconds=DEFAULT_HEARTBEAT_TIMEOUT)
            async with get_db_context() as session:
                # Flip every stale, not-yet-offline worker in one statement.
                stmt = (
                    update(Worker)
                    .where(Worker.last_heartbeat < timeout_threshold)
                    .where(Worker.status != WorkerStatus.OFFLINE.value)
                    .values(status=WorkerStatus.OFFLINE.value)
                )
                result = await session.execute(stmt)
                await session.commit()
                count = result.rowcount or 0

            self._last_cleanup = now

            if count:
                # The bulk UPDATE bypasses the cached Worker objects, which
                # would otherwise keep reporting their old status. Drop
                # them; lookups fall back to the database.
                self._cache.clear()
            if count > 0:
                logger.info(
                    "Marked workers as offline",
                    count=count,
                )
            return max(count, 0)
        except Exception as e:
            logger.error(
                "Failed to cleanup offline workers",
                error=str(e),
            )
            return 0

    async def requeue_worker_jobs(self, worker_id: str) -> int:
        """
        Requeue jobs from an offline worker.

        Every RUNNING job in the producer's queue that is assigned to the
        worker is reset to QUEUED with its worker and start time cleared.

        Args:
            worker_id: Worker ID

        Returns:
            Number of jobs requeued (0 on failure; the failure is logged)
        """
        try:
            from backend.queue.producer import _job_queue
            requeued_count = 0
            for job in _job_queue:
                if job.worker_id != worker_id or job.status != JobStatus.RUNNING:
                    continue
                job.status = JobStatus.QUEUED
                job.worker_id = None
                job.started_at = None
                requeued_count += 1
                logger.info(
                    "Requeued job from offline worker",
                    job_id=str(job.job_id),
                    worker_id=worker_id,
                )
            return requeued_count
        except Exception as e:
            logger.error(
                "Failed to requeue worker jobs",
                worker_id=worker_id,
                error=str(e),
            )
            return 0

    def _build_worker_status_response(self, worker: Worker) -> WorkerStatusResponse:
        """Build worker status response from worker model."""
        load_factor = self._load_factor(
            worker.active_jobs, worker.max_concurrent_jobs
        )
        # Uptime is measured from registration, not from the last heartbeat.
        uptime_seconds = int(
            (datetime.utcnow() - worker.registered_at).total_seconds()
        )
        return WorkerStatusResponse(
            worker_id=worker.worker_id,
            hostname=worker.hostname,
            status=WorkerStatus(worker.status),
            gpu_count=worker.gpu_count,
            gpu_memory_total=worker.gpu_memory_total,
            gpu_memory_used=worker.gpu_memory_used,
            gpu_usage_percent=worker.gpu_usage_percent,
            active_jobs=worker.active_jobs,
            max_concurrent_jobs=worker.max_concurrent_jobs,
            load_factor=load_factor,
            last_heartbeat=worker.last_heartbeat,
            registered_at=worker.registered_at,
            capabilities=worker.capabilities,
            uptime_seconds=uptime_seconds,
        )

    async def get_available_workers(
        self,
        gpu_required: int = 0,
    ) -> List[Worker]:
        """
        Get available workers that can accept new jobs.

        ACTIVE and DEGRADED workers with spare job capacity qualify.

        Args:
            gpu_required: Number of GPUs required (0 for CPU-only)

        Returns:
            List of available workers sorted by load factor (least loaded
            first); an empty list if the query fails (the failure is logged).
        """
        try:
            async with get_db_context() as session:
                # Get active workers with capacity
                stmt = (
                    select(Worker)
                    .where(Worker.status.in_([
                        WorkerStatus.ACTIVE.value,
                        WorkerStatus.DEGRADED.value,
                    ]))
                    .where(Worker.active_jobs < Worker.max_concurrent_jobs)
                )
                if gpu_required > 0:
                    stmt = stmt.where(Worker.gpu_count >= gpu_required)
                result = await session.execute(stmt)
                workers = result.scalars().all()

                # Sort by load factor (least loaded first)
                return sorted(
                    workers,
                    key=lambda w: self._load_factor(
                        w.active_jobs, w.max_concurrent_jobs
                    ),
                )
        except Exception as e:
            logger.error(
                "Failed to get available workers",
                error=str(e),
            )
            return []
# Process-wide registry singleton; built lazily by get_worker_registry().
_worker_registry: Optional[WorkerRegistry] = None


def get_worker_registry() -> WorkerRegistry:
    """Return the shared WorkerRegistry, creating it on the first call."""
    global _worker_registry
    registry = _worker_registry
    if registry is not None:
        return registry
    # First access: construct the instance and remember it for later calls.
    registry = WorkerRegistry()
    _worker_registry = registry
    return registry
# Public API of this module: the registry class, its singleton accessor,
# and the two heartbeat timing constants.
__all__ = [
"WorkerRegistry",
"get_worker_registry",
"DEFAULT_HEARTBEAT_INTERVAL",
"DEFAULT_HEARTBEAT_TIMEOUT",
]