"""
Modular Priority-Tier Worker Pool

A self-contained, plug-and-play worker pool for processing async jobs
with priority-tier scheduling. Can be used in any Python application.

Usage:
    from services.priority_worker_pool import (
        JobProcessor, PriorityWorkerPool, WorkerConfig,
    )

    # Implement a processor. The pool calls ``process`` once for a newly
    # claimed queued job and ``check_status`` on later passes while the
    # job is still "processing".
    class MyJobProcessor(JobProcessor):
        async def process(self, job, session):
            # Process job and return updated job
            job.status = "completed"
            job.output_data = {"result": "done"}
            return job

        async def check_status(self, job, session):
            # Check an in-progress job; set job.next_process_at to
            # reschedule if it is not finished yet.
            return job

    # Configure and start pool
    pool = PriorityWorkerPool(
        database_url="sqlite+aiosqlite:///./my_db.db",
        job_model=MyJobModel,
        job_processor=MyJobProcessor(),
        config=WorkerConfig(fast_workers=5, medium_workers=5, slow_workers=5),
    )
    await pool.start()

    # After inserting a new job, wake idle workers of its tier:
    pool.notify_new_job("fast")

    # On shutdown:
    await pool.stop()

Environment Variables (optional; read by WorkerConfig.from_env(), which is
used when no config is passed):
    FAST_WORKERS: Number of fast workers (default: 5)
    MEDIUM_WORKERS: Number of medium workers (default: 5)
    SLOW_WORKERS: Number of slow workers (default: 5)
    FAST_INTERVAL: Fast tier polling interval in seconds (default: 5)
    MEDIUM_INTERVAL: Medium tier polling interval in seconds (default: 30)
    SLOW_INTERVAL: Slow tier polling interval in seconds (default: 60)
    JOB_PER_API_KEY: Max concurrent processing jobs per API key (default: 1)

    Note: constructing WorkerConfig() directly (without from_env) uses
    different interval defaults: fast=2, medium=10, slow=15.

Dependencies:
    sqlalchemy[asyncio]>=2.0.0
    aiosqlite (for SQLite) or asyncpg (for PostgreSQL)

Job Model Requirements:
    Your job model must have these columns:
    - job_id: str (unique identifier)
    - user_id: owner of the job (only one job per user is processed at a time)
    - status: str (queued, processing, completed, failed, cancelled)
    - priority: str (fast, medium, slow)
    - next_process_at: datetime (nullable, for rescheduling)
    - retry_count: int (default 0; not incremented by the pool itself - a
      processing job is marked failed once it exceeds max_retries)
    - created_at: datetime
    - started_at: datetime (nullable)
    - completed_at: datetime (nullable)
    - error_message: str (nullable)
    Optional:
    - credits_reserved: number; when > 0, the optional credit service is
      called once the job reaches a terminal state
"""
import asyncio
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional, List, Callable, Any, TypeVar, Generic

from sqlalchemy import select, or_, and_, func, update
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# Generic placeholder for the caller-supplied ORM job model type.
JobType = TypeVar("JobType")
@dataclass
class WorkerConfig:
    """Configuration for the worker pool.

    Attributes:
        fast_workers / medium_workers / slow_workers: Number of workers
            started for each priority tier.
        fast_interval / medium_interval / slow_interval: Idle polling
            interval, in seconds, for each tier.
        max_retries: A job still ``processing`` whose ``retry_count``
            exceeds this value is marked failed.
        job_per_api_key: Max concurrent processing jobs per API key; the
            global processing cap is this value times the number of keys.
    """

    fast_workers: int = 5
    medium_workers: int = 5
    slow_workers: int = 5
    # NOTE(review): these interval defaults (2/10/15) differ from the
    # from_env() fallbacks (5/30/60); confirm which set is intended.
    fast_interval: int = 2  # seconds
    medium_interval: int = 10  # seconds
    slow_interval: int = 15  # seconds
    max_retries: int = 60  # Max retry attempts before failing
    job_per_api_key: int = 1  # Max concurrent jobs per API key

    @classmethod
    def from_env(cls) -> "WorkerConfig":
        """Create a config from environment variables.

        Unset variables fall back to the defaults shown below. Values are
        parsed with ``int``, so a malformed value raises ``ValueError``.
        ``max_retries`` is not read from the environment and keeps its
        class default.
        """
        return cls(
            fast_workers=int(os.getenv("FAST_WORKERS", "5")),
            medium_workers=int(os.getenv("MEDIUM_WORKERS", "5")),
            slow_workers=int(os.getenv("SLOW_WORKERS", "5")),
            fast_interval=int(os.getenv("FAST_INTERVAL", "5")),
            medium_interval=int(os.getenv("MEDIUM_INTERVAL", "30")),
            slow_interval=int(os.getenv("SLOW_INTERVAL", "60")),
            job_per_api_key=int(os.getenv("JOB_PER_API_KEY", "1")),
        )
@dataclass
class PriorityMapping:
    """Maps job types to priority tiers ("fast", "medium", "slow")."""

    # job type -> priority tier; default_factory gives each instance its own dict.
    mappings: dict = field(default_factory=dict)

    def get_priority(self, job_type: str, default: str = "fast") -> str:
        """Return the tier for ``job_type``, or ``default`` if it is unmapped."""
        return self.mappings.get(job_type, default)

    def get_interval(self, priority: str, config: "WorkerConfig") -> int:
        """Return the polling interval (seconds) configured for ``priority``.

        Any value other than "fast" or "medium" is treated as "slow".
        """
        if priority == "fast":
            return config.fast_interval
        if priority == "medium":
            return config.medium_interval
        return config.slow_interval
class JobProcessor(ABC, Generic[JobType]):
    """Abstract base class for job processors.

    Subclasses must implement both hooks. The worker calls ``process`` once
    after claiming a queued job (already switched to "processing"), and
    ``check_status`` on later passes while the job is still "processing".
    The returned job is what the worker inspects and commits afterwards.
    """

    @abstractmethod
    async def process(self, job: JobType, session: AsyncSession) -> JobType:
        """
        Process a job and return the updated job.

        Args:
            job: The job to process
            session: Database session for updates

        Returns:
            The updated job with new status/output
        """

    @abstractmethod
    async def check_status(self, job: JobType, session: AsyncSession) -> JobType:
        """
        Check status of an in-progress job (for async third-party operations).

        Args:
            job: The job to check
            session: Database session for updates

        Returns:
            The updated job. Set next_process_at to reschedule if not done.
        """
class PriorityWorker(Generic[JobType]):
    """Worker that processes jobs of a specific priority tier.

    The worker runs a polling loop in a background asyncio task. Each pass
    opens a session, selects the oldest due job of its tier ("queued", or
    "processing" with a NULL/past ``next_process_at``), claims it with a
    conditional UPDATE so two workers cannot both take it, and hands it to
    the job processor.
    """

    def __init__(
        self,
        worker_id: int,
        priority: str,
        poll_interval: int,
        session_maker: async_sessionmaker,
        job_model: type,
        job_processor: JobProcessor[JobType],
        max_retries: int = 60,
        wake_event: Optional[asyncio.Event] = None,
        config: Optional[WorkerConfig] = None
    ):
        """
        Args:
            worker_id: Identifier used in log messages.
            priority: Tier this worker serves ("fast", "medium" or "slow").
            poll_interval: Seconds to wait between polls when idle; the claim
                lease on a job is twice this value.
            session_maker: Factory for ``AsyncSession`` objects.
            job_model: ORM model class for jobs.
            job_processor: Processor whose ``process``/``check_status`` hooks
                do the actual work.
            max_retries: A processing job whose ``retry_count`` exceeds this
                is marked failed.
            wake_event: Optional event that cuts the idle wait short.
            config: Pool configuration; read from the environment if omitted.
        """
        self.worker_id = worker_id
        self.priority = priority
        self.poll_interval = poll_interval
        self.session_maker = session_maker
        self.job_model = job_model
        self.job_processor = job_processor
        self.max_retries = max_retries
        self._running = False
        self._current_job_id: Optional[str] = None
        self._wake_event = wake_event  # Event to wake worker immediately when new jobs arrive
        self._config = config or WorkerConfig.from_env()
        # Strong reference to the polling task: the event loop only keeps
        # weak references to tasks, so an unreferenced task may be collected.
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the worker polling loop in a background task.

        If the previous loop is still alive (e.g. restarted right after
        ``stop``), it simply keeps running instead of a second loop starting.
        """
        self._running = True
        if self._task is not None and not self._task.done():
            return
        logger.debug(f"Worker {self.worker_id} ({self.priority}) started, polling every {self.poll_interval}s")
        self._task = asyncio.create_task(self._poll_loop())

    async def stop(self):
        """Ask the worker to stop.

        The loop exits at its next iteration; a job already in flight is
        allowed to finish. This does not wait for the loop to exit.
        """
        self._running = False
        logger.info(f"Worker {self.worker_id} ({self.priority}) stopped")

    async def _poll_loop(self):
        """Main polling loop.

        - After a job was handled, look for the next one immediately.
        - When nothing was found (or the pass failed), wait up to
          ``poll_interval`` seconds, returning early if the wake event fires.
        """
        while self._running:
            job_found = False
            try:
                job_found = await self._process_one_job()
            except Exception as e:
                # logger.exception keeps the traceback for diagnosis.
                logger.exception(f"Worker {self.worker_id}: Error in poll loop: {e}")
            if job_found:
                continue
            if self._wake_event is None:
                await asyncio.sleep(self.poll_interval)
                continue
            try:
                # Wait for event or timeout (whichever comes first)
                await asyncio.wait_for(self._wake_event.wait(), timeout=self.poll_interval)
                # Reset so the next idle wait blocks again.
                self._wake_event.clear()
            except asyncio.TimeoutError:
                pass  # Normal timeout, check for jobs

    async def _process_one_job(self) -> bool:
        """Find and process one job.

        Before a *queued* job is started, two constraints are enforced (jobs
        already processing are always eligible for a status check):
        1. Only one job per user can be in processing state at a time.
        2. Total processing jobs are limited to
           ``job_per_api_key * get_key_count()`` (10 if the API key manager
           cannot be imported).

        Returns:
            True if a job was found and handed off (even if it then failed),
            False if no eligible job was found.
        """
        async with self.session_maker() as session:
            now = datetime.utcnow()
            # Oldest due job of this tier, either new or awaiting a status check.
            # NOTE(review): if this row is a queued job blocked by a constraint,
            # the worker idles for a full interval even when younger jobs of
            # other users are eligible (head-of-line blocking).
            query = select(self.job_model).where(
                and_(
                    self.job_model.priority == self.priority,
                    self.job_model.status.in_(["queued", "processing"]),
                    or_(
                        self.job_model.next_process_at.is_(None),
                        self.job_model.next_process_at <= now
                    )
                )
            ).order_by(self.job_model.created_at).limit(1)
            result = await session.execute(query)
            job = result.scalar_one_or_none()
            if not job:
                return False

            if job.status == "queued" and not await self._can_start(session, job):
                return False

            # Captured up front: after a rollback the ORM attributes are expired.
            job_id = job.job_id
            self._current_job_id = job_id
            try:
                await self._process_job(session, job)
            except Exception as e:
                logger.exception(f"Worker {self.worker_id}: Error processing job {job_id}: {e}")
                await self._mark_failed(session, job, str(e))
            finally:
                self._current_job_id = None
            return True  # Job was found, even if it failed

    async def _can_start(self, session: AsyncSession, job: JobType) -> bool:
        """Return True if the per-user and global limits allow starting ``job``."""
        # Constraint 1: this user must not already have a job in processing.
        user_processing_query = select(func.count()).where(
            and_(
                self.job_model.user_id == job.user_id,
                self.job_model.status == "processing"
            )
        )
        user_result = await session.execute(user_processing_query)
        if (user_result.scalar() or 0) > 0:
            logger.debug(f"Worker {self.worker_id}: User {job.user_id} already has a job processing, skipping")
            return False

        # Constraint 2: total processing jobs must be below capacity.
        max_processing = self._max_processing()
        count_query = select(func.count()).where(self.job_model.status == "processing")
        count_result = await session.execute(count_query)
        current_processing = count_result.scalar() or 0
        if current_processing >= max_processing:
            logger.debug(f"Worker {self.worker_id}: At max capacity ({current_processing}/{max_processing}), skipping new job")
            return False
        return True

    def _max_processing(self) -> int:
        """Return the global cap on processing jobs (per-key limit times key count)."""
        try:
            from services.api_key_manager import get_key_count
        except ImportError:
            return 10  # Default fallback when the API key manager is unavailable
        return self._config.job_per_api_key * get_key_count()

    async def _claim(self, session: AsyncSession, job: JobType, condition, **values) -> bool:
        """Atomically apply ``values`` to ``job`` if ``condition`` still holds.

        Commits the UPDATE. On success, refreshes ``job`` from the database and
        returns True; returns False if no row matched (another worker won).
        """
        stmt = (
            update(self.job_model)
            .where(self.job_model.job_id == job.job_id, condition)
            .values(**values)
        )
        result = await session.execute(stmt)
        await session.commit()
        if result.rowcount == 0:
            return False
        await session.refresh(job)
        return True

    async def _process_job(self, session: AsyncSession, job: JobType):
        """Claim ``job``, run the matching processor hook, then commit.

        Queued jobs are claimed by switching them to "processing" and passed
        to ``process``; processing jobs are claimed by pushing
        ``next_process_at`` forward and passed to ``check_status``. If the
        claim fails, nothing else happens.
        """
        logger.info(f"Worker {self.worker_id}: Processing job {job.job_id} (status: {job.status})")
        # Claim lease: the SELECT in _process_one_job skips the job until then.
        # NOTE(review): if the processor call outlasts this lease, another
        # worker may pick the job up for a concurrent status check.
        next_check = datetime.utcnow() + timedelta(seconds=self.poll_interval * 2)

        if job.status == "queued":
            claimed = await self._claim(
                session,
                job,
                self.job_model.status == "queued",
                status="processing",
                started_at=datetime.utcnow(),
                next_process_at=next_check,
            )
            if not claimed:
                logger.info(f"Worker {self.worker_id}: Failed to claim job {job.job_id} (already taken)")
                return
            job = await self.job_processor.process(job, session)
        else:
            now = datetime.utcnow()
            claimed = await self._claim(
                session,
                job,
                or_(
                    self.job_model.next_process_at.is_(None),
                    self.job_model.next_process_at <= now
                ),
                next_process_at=next_check,
            )
            if not claimed:
                logger.info(f"Worker {self.worker_id}: Failed to claim job {job.job_id} for check (already taken)")
                return
            job = await self.job_processor.check_status(job, session)

        # Handle retry limit (retry_count is not incremented by the worker).
        if job.status == "processing" and job.retry_count > self.max_retries:
            job.status = "failed"
            job.error_message = f"Max retries ({self.max_retries}) exceeded"
            job.completed_at = datetime.utcnow()

        # Handle credit finalization for jobs with reserved credits
        if job.status in ("completed", "failed", "cancelled"):
            await self._handle_job_credits(session, job)
        await session.commit()

    async def _mark_failed(self, session: AsyncSession, job: JobType, message: str):
        """Mark ``job`` failed after an exception and finalize its credits.

        If the exception left the session's transaction unusable (e.g. a
        failed flush), roll back and reload the job first so that this commit
        does not fail as well.
        """
        if not session.is_active:
            await session.rollback()
            await session.refresh(job)
        job.status = "failed"
        job.error_message = message
        job.completed_at = datetime.utcnow()
        # Same credit finalization as every other terminal path.
        await self._handle_job_credits(session, job)
        await session.commit()

    async def _handle_job_credits(self, session: AsyncSession, job: JobType):
        """Handle credit finalization when a job reaches a terminal state.

        Only jobs with a positive ``credits_reserved`` attribute are handled.
        Errors are logged rather than raised, so the caller's commit still runs.
        """
        reserved = getattr(job, "credits_reserved", None)
        # Missing, NULL, or non-positive: nothing to finalize.
        if reserved is None or reserved <= 0:
            return
        try:
            from services.credit_service.credit_manager import handle_job_completion
        except ImportError:
            # Credit service not available - skip
            logger.debug(f"Credit service not available for job {job.job_id}")
            return
        try:
            await handle_job_completion(session, job)
        except Exception as e:
            logger.exception(f"Error handling credits for job {job.job_id}: {e}")
class PriorityWorkerPool(Generic[JobType]):
    """
    Modular priority-tier worker pool.

    Can be used with any job model that follows the required schema. The
    pool owns its async engine and session factory, plus one wake event per
    priority tier that is shared by that tier's workers.
    """

    def __init__(
        self,
        database_url: str,
        job_model: type,
        job_processor: JobProcessor[JobType],
        config: Optional[WorkerConfig] = None
    ):
        """
        Initialize the worker pool.

        Args:
            database_url: SQLAlchemy async database URL
            job_model: Your ORM model class for jobs
            job_processor: Instance of JobProcessor to handle jobs
            config: Worker configuration (uses env vars if not provided)
        """
        self.database_url = database_url
        self.job_model = job_model
        self.job_processor = job_processor
        self.config = config or WorkerConfig.from_env()
        self.engine = create_async_engine(database_url, echo=False)
        # expire_on_commit=False keeps loaded job attributes readable after
        # a commit without another (async) database round trip.
        self.session_maker = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )
        self.workers: List[PriorityWorker] = []
        self._running = False
        # Wake events for each priority tier - allows immediate job notification
        self._wake_events: dict[str, asyncio.Event] = {
            "fast": asyncio.Event(),
            "medium": asyncio.Event(),
            "slow": asyncio.Event()
        }

    async def start(self):
        """Create and start the configured number of workers for each tier.

        Calling ``start`` while the pool is already running is a no-op, so
        workers are never duplicated.
        """
        if self._running:
            logger.warning("PriorityWorkerPool.start() called while already running; ignoring")
            return
        self._running = True

        tiers = (
            ("fast", self.config.fast_workers, self.config.fast_interval),
            ("medium", self.config.medium_workers, self.config.medium_interval),
            ("slow", self.config.slow_workers, self.config.slow_interval),
        )
        # Continue numbering after any workers from a previous run so IDs stay unique.
        worker_id = len(self.workers)
        for priority, count, interval in tiers:
            for _ in range(count):
                worker = PriorityWorker(
                    worker_id=worker_id,
                    priority=priority,
                    poll_interval=interval,
                    session_maker=self.session_maker,
                    job_model=self.job_model,
                    job_processor=self.job_processor,
                    max_retries=self.config.max_retries,
                    wake_event=self._wake_events[priority],
                    config=self.config
                )
                self.workers.append(worker)
                await worker.start()
                worker_id += 1

        total = self.config.fast_workers + self.config.medium_workers + self.config.slow_workers
        logger.info(
            f"PriorityWorkerPool started with {total} workers: "
            f"{self.config.fast_workers} fast, {self.config.medium_workers} medium, {self.config.slow_workers} slow"
        )

    def notify_new_job(self, priority: str):
        """
        Wake sleeping workers of the specified priority tier.

        Call this when a new job is created to start processing immediately.

        Args:
            priority: Priority tier ("fast", "medium", or "slow"). Unknown
                tiers are logged and ignored.
        """
        event = self._wake_events.get(priority)
        if event is None:
            logger.warning(f"notify_new_job: unknown priority {priority!r}; no workers notified")
            return
        event.set()
        logger.debug(f"Notified {priority} workers of new job")

    async def stop(self):
        """Stop all workers, refund orphaned jobs, and release DB connections.

        Workers are told to stop *before* the refund pass, so they stop
        picking up new jobs that the refund would miss. Idle workers are
        woken so their loops exit promptly.
        NOTE(review): a job already in flight is not awaited and may still
        finish after the refund pass.
        """
        self._running = False
        for worker in self.workers:
            await worker.stop()
        # Wake idle workers so they see the stop flag now, not after a full interval.
        for event in self._wake_events.values():
            event.set()
        # Refund credits for any jobs that were processing when server stopped
        await self._refund_orphaned_jobs()
        # Close pooled connections; the engine recreates its pool if reused.
        await self.engine.dispose()
        logger.info("PriorityWorkerPool stopped")

    async def _refund_orphaned_jobs(self):
        """Refund credits for jobs abandoned during shutdown (best effort)."""
        try:
            from services.credit_service.credit_manager import refund_orphaned_jobs
        except ImportError:
            logger.debug("Credit service not available for orphaned job refunds")
            return
        try:
            async with self.session_maker() as session:
                refund_count = await refund_orphaned_jobs(session)
                if refund_count > 0:
                    logger.info(f"Shutdown: Refunded {refund_count} orphaned job(s)")
        except Exception as e:
            logger.exception(f"Error refunding orphaned jobs during shutdown: {e}")
# Convenience functions for priority mapping

def get_priority_for_job_type(job_type: str, mappings: dict) -> str:
    """Look up the priority tier for ``job_type`` in ``mappings``.

    Job types that are not mapped fall back to the "fast" tier.
    """
    fallback = "fast"
    return mappings.get(job_type, fallback)
def get_interval_for_priority(priority: str, config: Optional[WorkerConfig] = None) -> int:
    """Return the polling interval, in seconds, for a priority tier.

    Uses ``config`` when given; otherwise a config is built from the
    environment via ``WorkerConfig.from_env()``. Any tier other than
    "fast" or "medium" gets the slow interval.
    """
    cfg = config or WorkerConfig.from_env()
    if priority not in ("fast", "medium"):
        return cfg.slow_interval
    return cfg.fast_interval if priority == "fast" else cfg.medium_interval