Spaces:
Running
Running
mod
Browse files- app.py +1 -1
- routers/gemini.py +1 -1
- services/gemini_job_worker.py +218 -0
- services/{job_worker.py → priority_worker_pool.py} +235 -199
- tests/test_worker_pool.py +52 -49
app.py
CHANGED
|
@@ -48,7 +48,7 @@ async def lifespan(app: FastAPI):
|
|
| 48 |
logger.info("Database initialized successfully")
|
| 49 |
|
| 50 |
# Start background job worker
|
| 51 |
-
from services.
|
| 52 |
await start_worker()
|
| 53 |
logger.info("Background job worker started")
|
| 54 |
|
|
|
|
| 48 |
logger.info("Database initialized successfully")
|
| 49 |
|
| 50 |
# Start background job worker
|
| 51 |
+
from services.gemini_job_worker import start_worker, stop_worker
|
| 52 |
await start_worker()
|
| 53 |
logger.info("Background job worker started")
|
| 54 |
|
routers/gemini.py
CHANGED
|
@@ -75,7 +75,7 @@ async def create_job(
|
|
| 75 |
input_data: dict
|
| 76 |
) -> GeminiJob:
|
| 77 |
"""Create a new job in the queue with auto-assigned priority."""
|
| 78 |
-
from services.
|
| 79 |
|
| 80 |
job_id = f"job_{uuid.uuid4().hex[:16]}"
|
| 81 |
priority = get_priority_for_job_type(job_type)
|
|
|
|
| 75 |
input_data: dict
|
| 76 |
) -> GeminiJob:
|
| 77 |
"""Create a new job in the queue with auto-assigned priority."""
|
| 78 |
+
from services.gemini_job_worker import get_priority_for_job_type
|
| 79 |
|
| 80 |
job_id = f"job_{uuid.uuid4().hex[:16]}"
|
| 81 |
priority = get_priority_for_job_type(job_type)
|
services/gemini_job_worker.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini Job Worker - Specific implementation using the modular PriorityWorkerPool.
|
| 3 |
+
|
| 4 |
+
This file shows how to use the modular PriorityWorkerPool with Gemini-specific
|
| 5 |
+
job processing logic.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 11 |
+
|
| 12 |
+
from core.database import DATABASE_URL
|
| 13 |
+
from core.models import GeminiJob
|
| 14 |
+
from services.priority_worker_pool import (
|
| 15 |
+
PriorityWorkerPool,
|
| 16 |
+
JobProcessor,
|
| 17 |
+
WorkerConfig,
|
| 18 |
+
get_interval_for_priority
|
| 19 |
+
)
|
| 20 |
+
from services.gemini_service import GeminiService
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
# Priority tier assigned to each Gemini job type.
JOB_PRIORITY_MAP = {
    "text": "fast",
    "analyze": "fast",
    "animation_prompt": "fast",
    "image": "medium",
    "edit_image": "medium",
    "video": "slow"
}


def get_priority_for_job_type(job_type: str) -> str:
    """Return the priority tier ("fast"/"medium"/"slow") for *job_type*.

    Job types not present in JOB_PRIORITY_MAP default to the "fast" tier.
    """
    try:
        return JOB_PRIORITY_MAP[job_type]
    except KeyError:
        return "fast"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class GeminiJobProcessor(JobProcessor[GeminiJob]):
    """Processes Gemini AI jobs (text, image, video generation).

    Synchronous job types (text, analyze, animation_prompt, image,
    edit_image) complete within a single call to ``process``.  Video jobs
    are asynchronous: ``process`` starts the third-party operation and the
    worker pool then calls ``check_status`` repeatedly until it finishes.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize with optional API key (uses env var if not provided)."""
        self.api_key = api_key

    def _get_service(self) -> GeminiService:
        """Get a GeminiService instance."""
        return GeminiService(api_key=self.api_key)

    @staticmethod
    def _schedule_next_check(job: GeminiJob) -> None:
        """Set ``job.next_process_at`` one tier-interval from now.

        Centralizes the env-config lookup + interval computation that was
        previously duplicated in _start_video and both retry branches of
        check_status.
        """
        config = WorkerConfig.from_env()
        interval = get_interval_for_priority(job.priority, config)
        job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)

    async def process(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
        """Start processing a new job.

        Dispatches on ``job.job_type``; unknown types are failed immediately.
        Exceptions from the Gemini service mark the job failed rather than
        propagating, so one bad job cannot take down a worker.
        """
        service = self._get_service()
        input_data = job.input_data or {}

        try:
            if job.job_type == "video":
                job = await self._start_video(job, session, service, input_data)
            elif job.job_type in ("image", "edit_image"):
                # BUGFIX: "edit_image" is registered in JOB_PRIORITY_MAP but the
                # original dispatch had no branch for it, so such jobs always
                # failed with "Unknown job type". Both types perform an image
                # edit via GeminiService.edit_image.
                job = await self._process_image(job, service, input_data)
            elif job.job_type == "text":
                job = await self._process_text(job, service, input_data)
            elif job.job_type == "analyze":
                job = await self._process_analyze(job, service, input_data)
            elif job.job_type == "animation_prompt":
                job = await self._process_animation_prompt(job, service, input_data)
            else:
                job.status = "failed"
                job.error_message = f"Unknown job type: {job.job_type}"
                job.completed_at = datetime.utcnow()
        except Exception as e:
            logger.error(f"Error processing job {job.job_id}: {e}")
            job.status = "failed"
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()

        return job

    async def check_status(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
        """Check status of an in-progress job (video generation).

        Non-video jobs should never reach this path; they are failed
        defensively. On transient errors the job is rescheduled; the pool
        enforces the max-retry limit.
        """
        if job.job_type != "video" or not job.third_party_id:
            # Non-video jobs or missing third_party_id - shouldn't happen
            job.status = "failed"
            job.error_message = "Invalid job state for status check"
            job.completed_at = datetime.utcnow()
            return job

        service = self._get_service()

        try:
            status_result = await service.check_video_status(job.third_party_id)

            if status_result.get("done"):
                if status_result.get("status") == "completed":
                    video_url = status_result.get("video_url")
                    if video_url:
                        filename = await service.download_video(video_url, job.job_id)
                        job.status = "completed"
                        job.output_data = {"filename": filename}
                    else:
                        job.status = "failed"
                        job.error_message = "No video URL returned"
                else:
                    job.status = "failed"
                    job.error_message = status_result.get("error", "Unknown error")

                job.completed_at = datetime.utcnow()
            else:
                # Not done yet - bump retries and reschedule the next poll.
                job.retry_count += 1
                self._schedule_next_check(job)
                logger.debug(f"Job {job.job_id}: retry #{job.retry_count}, next check at {job.next_process_at}")

        except Exception as e:
            logger.error(f"Error checking video status for {job.job_id}: {e}")
            job.retry_count += 1
            self._schedule_next_check(job)

        return job

    async def _start_video(self, job: GeminiJob, session: AsyncSession, service: GeminiService, input_data: dict) -> GeminiJob:
        """Start async video generation and schedule the first status check."""
        result = await service.start_video_generation(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", ""),
            aspect_ratio=input_data.get("aspect_ratio", "16:9"),
            resolution=input_data.get("resolution", "720p"),
            number_of_videos=input_data.get("number_of_videos", 1)
        )
        job.third_party_id = result.get("gemini_operation_name")

        # Schedule first status check
        self._schedule_next_check(job)

        return job

    async def _process_image(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process image edit (synchronous)."""
        result = await service.edit_image(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", "")
        )
        job.status = "completed"
        job.output_data = {"image": result}
        job.completed_at = datetime.utcnow()
        return job

    async def _process_text(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process text generation (synchronous)."""
        result = await service.generate_text(
            prompt=input_data.get("prompt", ""),
            model=input_data.get("model")
        )
        job.status = "completed"
        job.output_data = {"text": result}
        job.completed_at = datetime.utcnow()
        return job

    async def _process_analyze(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process image analysis (synchronous)."""
        result = await service.analyze_image(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", "")
        )
        job.status = "completed"
        job.output_data = {"analysis": result}
        job.completed_at = datetime.utcnow()
        return job

    async def _process_animation_prompt(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process animation prompt generation (synchronous)."""
        result = await service.generate_animation_prompt(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            custom_prompt=input_data.get("custom_prompt")
        )
        job.status = "completed"
        job.output_data = {"prompt": result}
        job.completed_at = datetime.utcnow()
        return job
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Lazily-created process-wide pool shared by start_worker/stop_worker.
_pool: Optional[PriorityWorkerPool] = None


def get_pool() -> PriorityWorkerPool:
    """Return the global Gemini worker pool, creating it on first use."""
    global _pool
    if _pool is None:
        _pool = PriorityWorkerPool(
            database_url=DATABASE_URL,
            job_model=GeminiJob,
            job_processor=GeminiJobProcessor(),
            config=WorkerConfig.from_env()
        )
    return _pool


async def start_worker():
    """Start the Gemini job worker pool."""
    await get_pool().start()


async def stop_worker():
    """Stop the Gemini job worker pool."""
    await get_pool().stop()
|
services/{job_worker.py → priority_worker_pool.py}
RENAMED
|
@@ -1,67 +1,163 @@
|
|
| 1 |
"""
|
| 2 |
-
Priority-Tier Worker Pool
|
| 3 |
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
"""
|
| 10 |
import asyncio
|
| 11 |
import logging
|
| 12 |
import os
|
|
|
|
|
|
|
| 13 |
from datetime import datetime, timedelta
|
| 14 |
-
from typing import Optional, List
|
| 15 |
from sqlalchemy import select, or_, and_
|
| 16 |
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
| 17 |
|
| 18 |
-
from core.database import DATABASE_URL
|
| 19 |
-
from core.models import GeminiJob
|
| 20 |
-
from services.gemini_service import GeminiService
|
| 21 |
-
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
MEDIUM_WORKERS = int(os.getenv("MEDIUM_WORKERS", "5"))
|
| 27 |
-
SLOW_WORKERS = int(os.getenv("SLOW_WORKERS", "5"))
|
| 28 |
|
| 29 |
-
FAST_INTERVAL = int(os.getenv("FAST_INTERVAL", "5")) # 5 seconds
|
| 30 |
-
MEDIUM_INTERVAL = int(os.getenv("MEDIUM_INTERVAL", "30")) # 30 seconds
|
| 31 |
-
SLOW_INTERVAL = int(os.getenv("SLOW_INTERVAL", "60")) # 60 seconds
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
"
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
def get_priority_for_job_type(job_type: str) -> str:
|
| 44 |
-
"""Get the priority tier for a job type."""
|
| 45 |
-
return JOB_PRIORITY_MAP.get(job_type, "fast")
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
"""Worker that processes jobs of a specific priority tier."""
|
| 59 |
|
| 60 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
self.worker_id = worker_id
|
| 62 |
self.priority = priority
|
| 63 |
self.poll_interval = poll_interval
|
| 64 |
self.session_maker = session_maker
|
|
|
|
|
|
|
|
|
|
| 65 |
self._running = False
|
| 66 |
self._current_job_id: Optional[str] = None
|
| 67 |
|
|
@@ -88,24 +184,25 @@ class PriorityWorker:
|
|
| 88 |
async def _process_one_job(self):
|
| 89 |
"""Find and process one job."""
|
| 90 |
async with self.session_maker() as session:
|
| 91 |
-
# Find a job to process
|
| 92 |
now = datetime.utcnow()
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
and_(
|
| 95 |
-
|
| 96 |
-
|
| 97 |
or_(
|
| 98 |
-
|
| 99 |
-
|
| 100 |
)
|
| 101 |
)
|
| 102 |
-
).order_by(
|
| 103 |
|
| 104 |
result = await session.execute(query)
|
| 105 |
job = result.scalar_one_or_none()
|
| 106 |
|
| 107 |
if not job:
|
| 108 |
-
return
|
| 109 |
|
| 110 |
self._current_job_id = job.job_id
|
| 111 |
|
|
@@ -120,142 +217,60 @@ class PriorityWorker:
|
|
| 120 |
finally:
|
| 121 |
self._current_job_id = None
|
| 122 |
|
| 123 |
-
async def _process_job(self, session: AsyncSession, job:
|
| 124 |
"""Process a single job."""
|
| 125 |
-
logger.info(f"Worker {self.worker_id}: Processing job {job.job_id} (
|
| 126 |
-
|
| 127 |
-
service = GeminiService()
|
| 128 |
-
input_data = job.input_data or {}
|
| 129 |
|
| 130 |
-
# If queued, start the operation
|
| 131 |
if job.status == "queued":
|
|
|
|
| 132 |
job.status = "processing"
|
| 133 |
job.started_at = datetime.utcnow()
|
| 134 |
await session.commit()
|
| 135 |
|
| 136 |
-
#
|
| 137 |
-
await self.
|
| 138 |
-
|
| 139 |
-
# Check status for operations that need polling (video)
|
| 140 |
-
if job.job_type == "video" and job.third_party_id:
|
| 141 |
-
await self._check_video_status(session, job, service)
|
| 142 |
-
# For synchronous operations (already completed in _start_operation)
|
| 143 |
-
# Nothing more to do
|
| 144 |
-
|
| 145 |
-
async def _start_operation(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
|
| 146 |
-
"""Start the third-party operation based on job type."""
|
| 147 |
-
if job.job_type == "video":
|
| 148 |
-
# Start async video generation
|
| 149 |
-
result = await service.start_video_generation(
|
| 150 |
-
base64_image=input_data.get("base64_image", ""),
|
| 151 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 152 |
-
prompt=input_data.get("prompt", ""),
|
| 153 |
-
aspect_ratio=input_data.get("aspect_ratio", "16:9"),
|
| 154 |
-
resolution=input_data.get("resolution", "720p"),
|
| 155 |
-
number_of_videos=input_data.get("number_of_videos", 1)
|
| 156 |
-
)
|
| 157 |
-
job.third_party_id = result.get("gemini_operation_name")
|
| 158 |
-
# Schedule first status check
|
| 159 |
-
job.next_process_at = datetime.utcnow() + timedelta(seconds=self.poll_interval)
|
| 160 |
-
await session.commit()
|
| 161 |
-
|
| 162 |
-
elif job.job_type == "image":
|
| 163 |
-
# Synchronous image edit
|
| 164 |
-
result = await service.edit_image(
|
| 165 |
-
base64_image=input_data.get("base64_image", ""),
|
| 166 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 167 |
-
prompt=input_data.get("prompt", "")
|
| 168 |
-
)
|
| 169 |
-
job.status = "completed"
|
| 170 |
-
job.output_data = {"image": result}
|
| 171 |
-
job.completed_at = datetime.utcnow()
|
| 172 |
-
await session.commit()
|
| 173 |
-
|
| 174 |
-
elif job.job_type == "text":
|
| 175 |
-
# Synchronous text generation
|
| 176 |
-
result = await service.generate_text(
|
| 177 |
-
prompt=input_data.get("prompt", ""),
|
| 178 |
-
model=input_data.get("model")
|
| 179 |
-
)
|
| 180 |
-
job.status = "completed"
|
| 181 |
-
job.output_data = {"text": result}
|
| 182 |
-
job.completed_at = datetime.utcnow()
|
| 183 |
-
await session.commit()
|
| 184 |
-
|
| 185 |
-
elif job.job_type == "analyze":
|
| 186 |
-
# Synchronous image analysis
|
| 187 |
-
result = await service.analyze_image(
|
| 188 |
-
base64_image=input_data.get("base64_image", ""),
|
| 189 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 190 |
-
prompt=input_data.get("prompt", "")
|
| 191 |
-
)
|
| 192 |
-
job.status = "completed"
|
| 193 |
-
job.output_data = {"analysis": result}
|
| 194 |
-
job.completed_at = datetime.utcnow()
|
| 195 |
-
await session.commit()
|
| 196 |
-
|
| 197 |
-
elif job.job_type == "animation_prompt":
|
| 198 |
-
# Synchronous animation prompt generation
|
| 199 |
-
result = await service.generate_animation_prompt(
|
| 200 |
-
base64_image=input_data.get("base64_image", ""),
|
| 201 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 202 |
-
custom_prompt=input_data.get("custom_prompt")
|
| 203 |
-
)
|
| 204 |
-
job.status = "completed"
|
| 205 |
-
job.output_data = {"prompt": result}
|
| 206 |
-
job.completed_at = datetime.utcnow()
|
| 207 |
-
await session.commit()
|
| 208 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
job.status = "failed"
|
| 210 |
-
job.error_message = f"
|
| 211 |
job.completed_at = datetime.utcnow()
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
async def _check_video_status(self, session: AsyncSession, job: GeminiJob, service: GeminiService):
|
| 215 |
-
"""Check video generation status and reschedule if not done."""
|
| 216 |
-
try:
|
| 217 |
-
status_result = await service.check_video_status(job.third_party_id)
|
| 218 |
-
|
| 219 |
-
if status_result.get("done"):
|
| 220 |
-
if status_result.get("status") == "completed":
|
| 221 |
-
# Download video
|
| 222 |
-
video_url = status_result.get("video_url")
|
| 223 |
-
if video_url:
|
| 224 |
-
filename = await service.download_video(video_url, job.job_id)
|
| 225 |
-
job.status = "completed"
|
| 226 |
-
job.output_data = {"filename": filename}
|
| 227 |
-
else:
|
| 228 |
-
job.status = "failed"
|
| 229 |
-
job.error_message = "No video URL returned"
|
| 230 |
-
else:
|
| 231 |
-
job.status = "failed"
|
| 232 |
-
job.error_message = status_result.get("error", "Unknown error")
|
| 233 |
-
|
| 234 |
-
job.completed_at = datetime.utcnow()
|
| 235 |
-
else:
|
| 236 |
-
# Not done - reschedule
|
| 237 |
-
job.retry_count += 1
|
| 238 |
-
job.next_process_at = datetime.utcnow() + timedelta(seconds=self.poll_interval)
|
| 239 |
-
logger.debug(f"Job {job.job_id}: Not done, retry #{job.retry_count}, next check at {job.next_process_at}")
|
| 240 |
-
|
| 241 |
-
await session.commit()
|
| 242 |
-
|
| 243 |
-
except Exception as e:
|
| 244 |
-
logger.error(f"Error checking video status for {job.job_id}: {e}")
|
| 245 |
-
job.retry_count += 1
|
| 246 |
-
job.next_process_at = datetime.utcnow() + timedelta(seconds=self.poll_interval)
|
| 247 |
-
if job.retry_count > 60: # ~1 hour of retries
|
| 248 |
-
job.status = "failed"
|
| 249 |
-
job.error_message = f"Max retries exceeded: {str(e)}"
|
| 250 |
-
job.completed_at = datetime.utcnow()
|
| 251 |
-
await session.commit()
|
| 252 |
|
| 253 |
|
| 254 |
-
class
|
| 255 |
-
"""
|
|
|
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
self.session_maker = async_sessionmaker(
|
| 260 |
self.engine,
|
| 261 |
class_=AsyncSession,
|
|
@@ -270,55 +285,76 @@ class WorkerPool:
|
|
| 270 |
worker_id = 0
|
| 271 |
|
| 272 |
# Create fast workers
|
| 273 |
-
for i in range(
|
| 274 |
-
worker = PriorityWorker(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
self.workers.append(worker)
|
| 276 |
await worker.start()
|
| 277 |
worker_id += 1
|
| 278 |
|
| 279 |
# Create medium workers
|
| 280 |
-
for i in range(
|
| 281 |
-
worker = PriorityWorker(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
self.workers.append(worker)
|
| 283 |
await worker.start()
|
| 284 |
worker_id += 1
|
| 285 |
|
| 286 |
# Create slow workers
|
| 287 |
-
for i in range(
|
| 288 |
-
worker = PriorityWorker(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
self.workers.append(worker)
|
| 290 |
await worker.start()
|
| 291 |
worker_id += 1
|
| 292 |
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
async def stop(self):
|
| 296 |
"""Stop all workers."""
|
| 297 |
self._running = False
|
| 298 |
for worker in self.workers:
|
| 299 |
await worker.stop()
|
| 300 |
-
logger.info("
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
# Singleton pool instance
|
| 304 |
-
_pool: Optional[WorkerPool] = None
|
| 305 |
|
| 306 |
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
_pool = WorkerPool()
|
| 312 |
-
return _pool
|
| 313 |
|
| 314 |
|
| 315 |
-
|
| 316 |
-
"""
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
await pool.stop()
|
|
|
|
| 1 |
"""
|
| 2 |
+
Modular Priority-Tier Worker Pool
|
| 3 |
|
| 4 |
+
A self-contained, plug-and-play worker pool for processing async jobs
|
| 5 |
+
with priority-tier scheduling. Can be used in any Python application.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from services.priority_worker_pool import PriorityWorkerPool, WorkerConfig
|
| 9 |
+
|
| 10 |
+
# Define your job processor function
|
| 11 |
+
async def process_my_job(job, session):
|
| 12 |
+
# Process job and return updated job
|
| 13 |
+
job.status = "completed"
|
| 14 |
+
job.output_data = {"result": "done"}
|
| 15 |
+
return job
|
| 16 |
+
|
| 17 |
+
# Configure and start pool
|
| 18 |
+
pool = PriorityWorkerPool(
|
| 19 |
+
database_url="sqlite+aiosqlite:///./my_db.db",
|
| 20 |
+
job_model=MyJobModel,
|
| 21 |
+
job_processor=process_my_job,
|
| 22 |
+
config=WorkerConfig(fast_workers=5, medium_workers=5, slow_workers=5)
|
| 23 |
+
)
|
| 24 |
+
await pool.start()
|
| 25 |
+
|
| 26 |
+
Environment Variables (optional):
|
| 27 |
+
FAST_WORKERS: Number of fast workers (default: 5)
|
| 28 |
+
MEDIUM_WORKERS: Number of medium workers (default: 5)
|
| 29 |
+
SLOW_WORKERS: Number of slow workers (default: 5)
|
| 30 |
+
FAST_INTERVAL: Fast tier polling interval in seconds (default: 5)
|
| 31 |
+
MEDIUM_INTERVAL: Medium tier polling interval in seconds (default: 30)
|
| 32 |
+
SLOW_INTERVAL: Slow tier polling interval in seconds (default: 60)
|
| 33 |
+
|
| 34 |
+
Dependencies:
|
| 35 |
+
sqlalchemy[asyncio]>=2.0.0
|
| 36 |
+
aiosqlite (for SQLite) or asyncpg (for PostgreSQL)
|
| 37 |
+
|
| 38 |
+
Job Model Requirements:
|
| 39 |
+
Your job model must have these columns:
|
| 40 |
+
- job_id: str (unique identifier)
|
| 41 |
+
- status: str (queued, processing, completed, failed, cancelled)
|
| 42 |
+
- priority: str (fast, medium, slow)
|
| 43 |
+
- next_process_at: datetime (nullable, for rescheduling)
|
| 44 |
+
- retry_count: int (default 0)
|
| 45 |
+
- created_at: datetime
|
| 46 |
+
- started_at: datetime (nullable)
|
| 47 |
+
- completed_at: datetime (nullable)
|
| 48 |
+
- error_message: str (nullable)
|
| 49 |
"""
|
| 50 |
import asyncio
|
| 51 |
import logging
|
| 52 |
import os
|
| 53 |
+
from abc import ABC, abstractmethod
|
| 54 |
+
from dataclasses import dataclass, field
|
| 55 |
from datetime import datetime, timedelta
|
| 56 |
+
from typing import Optional, List, Callable, Any, TypeVar, Generic
|
| 57 |
from sqlalchemy import select, or_, and_
|
| 58 |
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
logger = logging.getLogger(__name__)
|
| 61 |
|
| 62 |
+
# Generic type for job model
|
| 63 |
+
JobType = TypeVar('JobType')
|
|
|
|
|
|
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
@dataclass
class WorkerConfig:
    """Configuration for the worker pool.

    Worker counts and polling intervals are per priority tier; ``max_retries``
    bounds how many status checks a still-processing job gets before it is
    marked failed.
    """
    fast_workers: int = 5
    medium_workers: int = 5
    slow_workers: int = 5
    fast_interval: int = 5  # seconds
    medium_interval: int = 30  # seconds
    slow_interval: int = 60  # seconds
    max_retries: int = 60  # Max retry attempts before failing

    @classmethod
    def from_env(cls) -> 'WorkerConfig':
        """Build a config from *_WORKERS / *_INTERVAL environment variables.

        Missing variables fall back to the dataclass defaults; ``max_retries``
        is not read from the environment.
        """
        env = os.getenv
        return cls(
            fast_workers=int(env("FAST_WORKERS", "5")),
            medium_workers=int(env("MEDIUM_WORKERS", "5")),
            slow_workers=int(env("SLOW_WORKERS", "5")),
            fast_interval=int(env("FAST_INTERVAL", "5")),
            medium_interval=int(env("MEDIUM_INTERVAL", "30")),
            slow_interval=int(env("SLOW_INTERVAL", "60")),
        )
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
@dataclass
class PriorityMapping:
    """Maps job types to priority tiers."""
    # job_type -> tier name ("fast" / "medium" / "slow")
    mappings: dict = field(default_factory=dict)

    def get_priority(self, job_type: str, default: str = "fast") -> str:
        """Return the tier registered for *job_type*, or *default* if unmapped."""
        return self.mappings.get(job_type, default)

    def get_interval(self, priority: str, config: WorkerConfig) -> int:
        """Return the polling interval (seconds) configured for *priority*.

        Any tier other than "fast" or "medium" falls back to the slow interval.
        """
        known = {
            "fast": config.fast_interval,
            "medium": config.medium_interval,
        }
        return known.get(priority, config.slow_interval)
|
| 107 |
|
| 108 |
|
| 109 |
+
class JobProcessor(ABC, Generic[JobType]):
    """Abstract base class for job processors.

    Implementations supply the domain-specific logic the worker pool runs:
    ``process`` for a newly picked-up job and ``check_status`` for a job
    whose asynchronous third-party operation is still in flight.
    """

    @abstractmethod
    async def process(self, job: JobType, session: AsyncSession) -> JobType:
        """Process a job and return the updated job.

        Args:
            job: The job to process
            session: Database session for updates

        Returns:
            The updated job with new status/output
        """
        ...

    @abstractmethod
    async def check_status(self, job: JobType, session: AsyncSession) -> JobType:
        """Check status of an in-progress job (for async third-party operations).

        Args:
            job: The job to check
            session: Database session for updates

        Returns:
            The updated job. Set next_process_at to reschedule if not done.
        """
        ...
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class PriorityWorker(Generic[JobType]):
|
| 142 |
"""Worker that processes jobs of a specific priority tier."""
|
| 143 |
|
| 144 |
+
    def __init__(
        self,
        worker_id: int,
        priority: str,
        poll_interval: int,
        session_maker: async_sessionmaker,
        job_model: type,
        job_processor: JobProcessor[JobType],
        max_retries: int = 60
    ):
        """Create a worker bound to a single priority tier."""
        # Identity and the tier whose jobs this worker drains.
        self.worker_id = worker_id
        self.priority = priority
        # Seconds between polls for this tier (presumably consumed by the
        # worker's run loop - not visible in this view; confirm there).
        self.poll_interval = poll_interval
        # Infrastructure injected by the pool: DB session factory, the ORM
        # job model to query, and the domain-specific processor.
        self.session_maker = session_maker
        self.job_model = job_model
        self.job_processor = job_processor
        # Jobs still "processing" after this many retries are failed
        # (enforced in _process_job).
        self.max_retries = max_retries
        # Runtime state: loop flag and the job_id currently being handled.
        self._running = False
        self._current_job_id: Optional[str] = None
|
| 163 |
|
|
|
|
| 184 |
async def _process_one_job(self):
|
| 185 |
"""Find and process one job."""
|
| 186 |
async with self.session_maker() as session:
|
|
|
|
| 187 |
now = datetime.utcnow()
|
| 188 |
+
|
| 189 |
+
# Query for jobs matching this priority tier
|
| 190 |
+
query = select(self.job_model).where(
|
| 191 |
and_(
|
| 192 |
+
self.job_model.priority == self.priority,
|
| 193 |
+
self.job_model.status.in_(["queued", "processing"]),
|
| 194 |
or_(
|
| 195 |
+
self.job_model.next_process_at.is_(None),
|
| 196 |
+
self.job_model.next_process_at <= now
|
| 197 |
)
|
| 198 |
)
|
| 199 |
+
).order_by(self.job_model.created_at).limit(1)
|
| 200 |
|
| 201 |
result = await session.execute(query)
|
| 202 |
job = result.scalar_one_or_none()
|
| 203 |
|
| 204 |
if not job:
|
| 205 |
+
return
|
| 206 |
|
| 207 |
self._current_job_id = job.job_id
|
| 208 |
|
|
|
|
| 217 |
finally:
|
| 218 |
self._current_job_id = None
|
| 219 |
|
| 220 |
+
async def _process_job(self, session: AsyncSession, job: JobType):
|
| 221 |
"""Process a single job."""
|
| 222 |
+
logger.info(f"Worker {self.worker_id}: Processing job {job.job_id} (status: {job.status})")
|
|
|
|
|
|
|
|
|
|
| 223 |
|
|
|
|
| 224 |
if job.status == "queued":
|
| 225 |
+
# New job - start processing
|
| 226 |
job.status = "processing"
|
| 227 |
job.started_at = datetime.utcnow()
|
| 228 |
await session.commit()
|
| 229 |
|
| 230 |
+
# Process the job
|
| 231 |
+
job = await self.job_processor.process(job, session)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
else:
|
| 233 |
+
# Already processing - check status
|
| 234 |
+
job = await self.job_processor.check_status(job, session)
|
| 235 |
+
|
| 236 |
+
# Handle retry limit
|
| 237 |
+
if job.status == "processing" and job.retry_count > self.max_retries:
|
| 238 |
job.status = "failed"
|
| 239 |
+
job.error_message = f"Max retries ({self.max_retries}) exceeded"
|
| 240 |
job.completed_at = datetime.utcnow()
|
| 241 |
+
|
| 242 |
+
await session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
|
| 245 |
+
class PriorityWorkerPool(Generic[JobType]):
|
| 246 |
+
"""
|
| 247 |
+
Modular priority-tier worker pool.
|
| 248 |
|
| 249 |
+
Can be used with any job model that follows the required schema.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(
|
| 253 |
+
self,
|
| 254 |
+
database_url: str,
|
| 255 |
+
job_model: type,
|
| 256 |
+
job_processor: JobProcessor[JobType],
|
| 257 |
+
config: Optional[WorkerConfig] = None
|
| 258 |
+
):
|
| 259 |
+
"""
|
| 260 |
+
Initialize the worker pool.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
database_url: SQLAlchemy async database URL
|
| 264 |
+
job_model: Your ORM model class for jobs
|
| 265 |
+
job_processor: Instance of JobProcessor to handle jobs
|
| 266 |
+
config: Worker configuration (uses env vars if not provided)
|
| 267 |
+
"""
|
| 268 |
+
self.database_url = database_url
|
| 269 |
+
self.job_model = job_model
|
| 270 |
+
self.job_processor = job_processor
|
| 271 |
+
self.config = config or WorkerConfig.from_env()
|
| 272 |
+
|
| 273 |
+
self.engine = create_async_engine(database_url, echo=False)
|
| 274 |
self.session_maker = async_sessionmaker(
|
| 275 |
self.engine,
|
| 276 |
class_=AsyncSession,
|
|
|
|
| 285 |
worker_id = 0
|
| 286 |
|
| 287 |
# Create fast workers
|
| 288 |
+
for i in range(self.config.fast_workers):
|
| 289 |
+
worker = PriorityWorker(
|
| 290 |
+
worker_id=worker_id,
|
| 291 |
+
priority="fast",
|
| 292 |
+
poll_interval=self.config.fast_interval,
|
| 293 |
+
session_maker=self.session_maker,
|
| 294 |
+
job_model=self.job_model,
|
| 295 |
+
job_processor=self.job_processor,
|
| 296 |
+
max_retries=self.config.max_retries
|
| 297 |
+
)
|
| 298 |
self.workers.append(worker)
|
| 299 |
await worker.start()
|
| 300 |
worker_id += 1
|
| 301 |
|
| 302 |
# Create medium workers
|
| 303 |
+
for i in range(self.config.medium_workers):
|
| 304 |
+
worker = PriorityWorker(
|
| 305 |
+
worker_id=worker_id,
|
| 306 |
+
priority="medium",
|
| 307 |
+
poll_interval=self.config.medium_interval,
|
| 308 |
+
session_maker=self.session_maker,
|
| 309 |
+
job_model=self.job_model,
|
| 310 |
+
job_processor=self.job_processor,
|
| 311 |
+
max_retries=self.config.max_retries
|
| 312 |
+
)
|
| 313 |
self.workers.append(worker)
|
| 314 |
await worker.start()
|
| 315 |
worker_id += 1
|
| 316 |
|
| 317 |
# Create slow workers
|
| 318 |
+
for i in range(self.config.slow_workers):
|
| 319 |
+
worker = PriorityWorker(
|
| 320 |
+
worker_id=worker_id,
|
| 321 |
+
priority="slow",
|
| 322 |
+
poll_interval=self.config.slow_interval,
|
| 323 |
+
session_maker=self.session_maker,
|
| 324 |
+
job_model=self.job_model,
|
| 325 |
+
job_processor=self.job_processor,
|
| 326 |
+
max_retries=self.config.max_retries
|
| 327 |
+
)
|
| 328 |
self.workers.append(worker)
|
| 329 |
await worker.start()
|
| 330 |
worker_id += 1
|
| 331 |
|
| 332 |
+
total = self.config.fast_workers + self.config.medium_workers + self.config.slow_workers
|
| 333 |
+
logger.info(
|
| 334 |
+
f"PriorityWorkerPool started with {total} workers: "
|
| 335 |
+
f"{self.config.fast_workers} fast, {self.config.medium_workers} medium, {self.config.slow_workers} slow"
|
| 336 |
+
)
|
| 337 |
|
| 338 |
async def stop(self):
|
| 339 |
"""Stop all workers."""
|
| 340 |
self._running = False
|
| 341 |
for worker in self.workers:
|
| 342 |
await worker.stop()
|
| 343 |
+
logger.info("PriorityWorkerPool stopped")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
|
| 346 |
+
# Convenience functions for priority mapping
|
| 347 |
+
def get_priority_for_job_type(job_type: str, mappings: dict) -> str:
|
| 348 |
+
"""Get priority tier for a job type using provided mappings."""
|
| 349 |
+
return mappings.get(job_type, "fast")
|
|
|
|
|
|
|
| 350 |
|
| 351 |
|
| 352 |
+
def get_interval_for_priority(priority: str, config: Optional[WorkerConfig] = None) -> int:
|
| 353 |
+
"""Get polling interval for a priority tier."""
|
| 354 |
+
cfg = config or WorkerConfig.from_env()
|
| 355 |
+
if priority == "fast":
|
| 356 |
+
return cfg.fast_interval
|
| 357 |
+
elif priority == "medium":
|
| 358 |
+
return cfg.medium_interval
|
| 359 |
+
else:
|
| 360 |
+
return cfg.slow_interval
|
|
|
tests/test_worker_pool.py
CHANGED
|
@@ -7,13 +7,19 @@ import asyncio
|
|
| 7 |
from unittest.mock import patch, MagicMock, AsyncMock
|
| 8 |
from datetime import datetime, timedelta
|
| 9 |
|
| 10 |
-
# Test the priority
|
| 11 |
-
from services.
|
| 12 |
-
|
| 13 |
-
get_interval_for_priority,
|
| 14 |
PriorityWorker,
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
|
|
@@ -68,48 +74,37 @@ class TestJobPriorityMap:
|
|
| 68 |
|
| 69 |
|
| 70 |
class TestWorkerPoolConfiguration:
|
| 71 |
-
"""Test worker pool
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
"""Test
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
await pool.start()
|
| 103 |
-
|
| 104 |
-
for worker in pool.workers:
|
| 105 |
-
if worker.priority == "fast":
|
| 106 |
-
assert worker.poll_interval == 5
|
| 107 |
-
elif worker.priority == "medium":
|
| 108 |
-
assert worker.poll_interval == 30
|
| 109 |
-
elif worker.priority == "slow":
|
| 110 |
-
assert worker.poll_interval == 60
|
| 111 |
-
|
| 112 |
-
await pool.stop()
|
| 113 |
|
| 114 |
|
| 115 |
class TestPriorityWorker:
|
|
@@ -117,7 +112,15 @@ class TestPriorityWorker:
|
|
| 117 |
|
| 118 |
def test_worker_has_correct_attributes(self):
|
| 119 |
"""Test worker initialization."""
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
assert worker.worker_id == 0
|
| 123 |
assert worker.priority == "fast"
|
|
|
|
| 7 |
from unittest.mock import patch, MagicMock, AsyncMock
|
| 8 |
from datetime import datetime, timedelta
|
| 9 |
|
| 10 |
+
# Test the modular priority worker pool
|
| 11 |
+
from services.priority_worker_pool import (
|
| 12 |
+
PriorityWorkerPool,
|
|
|
|
| 13 |
PriorityWorker,
|
| 14 |
+
WorkerConfig,
|
| 15 |
+
get_interval_for_priority
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Test the Gemini-specific implementation
|
| 19 |
+
from services.gemini_job_worker import (
|
| 20 |
+
get_priority_for_job_type,
|
| 21 |
+
JOB_PRIORITY_MAP,
|
| 22 |
+
GeminiJobProcessor
|
| 23 |
)
|
| 24 |
|
| 25 |
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
class TestWorkerPoolConfiguration:
|
| 77 |
+
"""Test worker pool configuration."""
|
| 78 |
+
|
| 79 |
+
def test_default_config(self):
|
| 80 |
+
"""Test WorkerConfig defaults."""
|
| 81 |
+
config = WorkerConfig()
|
| 82 |
+
assert config.fast_workers == 5
|
| 83 |
+
assert config.medium_workers == 5
|
| 84 |
+
assert config.slow_workers == 5
|
| 85 |
+
assert config.fast_interval == 5
|
| 86 |
+
assert config.medium_interval == 30
|
| 87 |
+
assert config.slow_interval == 60
|
| 88 |
+
|
| 89 |
+
def test_custom_config(self):
|
| 90 |
+
"""Test WorkerConfig with custom values."""
|
| 91 |
+
config = WorkerConfig(
|
| 92 |
+
fast_workers=3,
|
| 93 |
+
medium_workers=2,
|
| 94 |
+
slow_workers=1,
|
| 95 |
+
fast_interval=10,
|
| 96 |
+
medium_interval=60,
|
| 97 |
+
slow_interval=120
|
| 98 |
+
)
|
| 99 |
+
assert config.fast_workers == 3
|
| 100 |
+
assert config.medium_workers == 2
|
| 101 |
+
assert config.slow_workers == 1
|
| 102 |
+
|
| 103 |
+
def test_total_workers_calculation(self):
|
| 104 |
+
"""Test total workers from config."""
|
| 105 |
+
config = WorkerConfig(fast_workers=5, medium_workers=5, slow_workers=5)
|
| 106 |
+
total = config.fast_workers + config.medium_workers + config.slow_workers
|
| 107 |
+
assert total == 15
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
class TestPriorityWorker:
|
|
|
|
| 112 |
|
| 113 |
def test_worker_has_correct_attributes(self):
|
| 114 |
"""Test worker initialization."""
|
| 115 |
+
# PriorityWorker now requires more args, test with mocks
|
| 116 |
+
worker = PriorityWorker(
|
| 117 |
+
worker_id=0,
|
| 118 |
+
priority="fast",
|
| 119 |
+
poll_interval=5,
|
| 120 |
+
session_maker=None,
|
| 121 |
+
job_model=None,
|
| 122 |
+
job_processor=None
|
| 123 |
+
)
|
| 124 |
|
| 125 |
assert worker.worker_id == 0
|
| 126 |
assert worker.priority == "fast"
|