jebin2 committed
Commit 02588dd · 1 Parent(s): c4f61f9
routers/gemini.py CHANGED
@@ -372,8 +372,8 @@ async def get_job_status(
         response["output"] = job.output_data
 
     # For video jobs, add download URL
-    if job.job_type == "video" and job.output_data and (job.output_data.get("filename") or job.output_data.get("video_url")):
-        response["download_url"] = f"/gemini/download/{job.job_id}"
+    # if job.job_type == "video" and job.output_data and (job.output_data.get("filename") or job.output_data.get("video_url")):
+    #     response["download_url"] = f"/gemini/download/{job.job_id}"
 
     if job.status == "failed":
         response["error"] = job.error_message
services/priority_worker_pool.py CHANGED
@@ -73,6 +73,7 @@ class WorkerConfig:
     medium_interval: int = 10  # seconds
     slow_interval: int = 15  # seconds
     max_retries: int = 60  # Max retry attempts before failing
+    job_per_api_key: int = 1  # Max concurrent jobs per API key
 
     @classmethod
     def from_env(cls) -> 'WorkerConfig':
@@ -84,6 +85,7 @@ class WorkerConfig:
             fast_interval=int(os.getenv("FAST_INTERVAL", "5")),
             medium_interval=int(os.getenv("MEDIUM_INTERVAL", "30")),
             slow_interval=int(os.getenv("SLOW_INTERVAL", "60")),
+            job_per_api_key=int(os.getenv("JOB_PER_API_KEY", "1")),
         )
 
 
@@ -150,7 +152,8 @@ class PriorityWorker(Generic[JobType]):
         job_model: type,
         job_processor: JobProcessor[JobType],
         max_retries: int = 60,
-        wake_event: Optional[asyncio.Event] = None
+        wake_event: Optional[asyncio.Event] = None,
+        config: Optional[WorkerConfig] = None
     ):
         self.worker_id = worker_id
         self.priority = priority
@@ -162,6 +165,7 @@ class PriorityWorker(Generic[JobType]):
         self._running = False
         self._current_job_id: Optional[str] = None
         self._wake_event = wake_event  # Event to wake worker immediately when new jobs arrive
+        self._config = config or WorkerConfig.from_env()
 
     async def start(self):
         """Start the worker polling loop."""
@@ -210,12 +214,33 @@ class PriorityWorker(Generic[JobType]):
     async def _process_one_job(self) -> bool:
         """Find and process one job.
 
+        Enforces constraints:
+        1. Only one job per user can be in processing state at a time
+        2. Total processing jobs limited to JOB_PER_API_KEY * number of API keys
+
         Returns:
             True if a job was found and processed, False if no jobs available
         """
         async with self.session_maker() as session:
+            from sqlalchemy import func
+
             now = datetime.utcnow()
 
+            # Get number of API keys for capacity calculation
+            try:
+                from services.api_key_manager import get_key_count
+                num_api_keys = get_key_count()
+                max_processing = self._config.job_per_api_key * num_api_keys
+            except ImportError:
+                max_processing = 10  # Default fallback
+
+            # Check if we're at max processing capacity (only for new jobs being picked up)
+            count_query = select(func.count()).where(
+                self.job_model.status == "processing"
+            )
+            count_result = await session.execute(count_query)
+            current_processing = count_result.scalar() or 0
+
             # Query for jobs matching this priority tier
             query = select(self.job_model).where(
                 and_(
@@ -234,6 +259,27 @@ class PriorityWorker(Generic[JobType]):
             if not job:
                 return False
 
+            # For queued jobs, apply the constraints
+            if job.status == "queued":
+                # Constraint 1: Check if this user already has a job in processing
+                user_processing_query = select(func.count()).where(
+                    and_(
+                        self.job_model.user_id == job.user_id,
+                        self.job_model.status == "processing"
+                    )
+                )
+                user_result = await session.execute(user_processing_query)
+                user_processing_count = user_result.scalar() or 0
+
+                if user_processing_count > 0:
+                    logger.debug(f"Worker {self.worker_id}: User {job.user_id} already has a job processing, skipping")
+                    return False
+
+                # Constraint 2: Check if we're at max total processing capacity
+                if current_processing >= max_processing:
+                    logger.debug(f"Worker {self.worker_id}: At max capacity ({current_processing}/{max_processing}), skipping new job")
+                    return False
+
             self._current_job_id = job.job_id
 
             try:
@@ -400,7 +446,8 @@ class PriorityWorkerPool(Generic[JobType]):
                 job_model=self.job_model,
                 job_processor=self.job_processor,
                 max_retries=self.config.max_retries,
-                wake_event=self._wake_events["fast"]
+                wake_event=self._wake_events["fast"],
+                config=self.config
             )
             self.workers.append(worker)
             await worker.start()
@@ -416,7 +463,8 @@ class PriorityWorkerPool(Generic[JobType]):
                 job_model=self.job_model,
                 job_processor=self.job_processor,
                 max_retries=self.config.max_retries,
-                wake_event=self._wake_events["medium"]
+                wake_event=self._wake_events["medium"],
+                config=self.config
             )
             self.workers.append(worker)
             await worker.start()
@@ -432,7 +480,8 @@ class PriorityWorkerPool(Generic[JobType]):
                 job_model=self.job_model,
                 job_processor=self.job_processor,
                 max_retries=self.config.max_retries,
-                wake_event=self._wake_events["slow"]
+                wake_event=self._wake_events["slow"],
+                config=self.config
            )
             self.workers.append(worker)
             await worker.start()
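
The capacity rule introduced above is plain arithmetic: at most JOB_PER_API_KEY × (number of API keys) jobs may sit in the processing state at once, and each user may hold at most one of those slots. So with JOB_PER_API_KEY=2 and three keys, up to six jobs run concurrently. A self-contained sketch of the two admission checks (Job, get_key_count, and the GEMINI_API_KEYS variable below are stand-ins for the real models, services.api_key_manager, and key storage):

    import os
    from dataclasses import dataclass

    @dataclass
    class Job:
        job_id: str
        user_id: str
        status: str  # "queued" | "processing" | "completed" | "failed"

    def get_key_count() -> int:
        # Stand-in for services.api_key_manager.get_key_count(); here we just
        # count comma-separated keys in a hypothetical env var.
        return len([k for k in os.getenv("GEMINI_API_KEYS", "k1,k2,k3").split(",") if k])

    def can_start(job: Job, jobs: list[Job], job_per_api_key: int = 1) -> bool:
        """Mirror of the two constraints this commit adds to _process_one_job."""
        processing = [j for j in jobs if j.status == "processing"]
        # Constraint 1: only one processing job per user at a time.
        if any(j.user_id == job.user_id for j in processing):
            return False
        # Constraint 2: global cap of JOB_PER_API_KEY * number of API keys.
        return len(processing) < job_per_api_key * get_key_count()

    jobs = [Job("a", "u1", "processing"), Job("b", "u2", "processing")]
    print(can_start(Job("c", "u1", "queued"), jobs))  # False: u1 already has a slot
    print(can_start(Job("d", "u3", "queued"), jobs))  # True: 2 of 3 slots used

One caveat the sketch shares with the diff: each worker counts processing jobs in its own session before claiming a job, so two workers polling at the same instant can both pass the check. The cap is therefore approximate unless the count and the claim happen under a single transaction or lock.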