jebin2 committed on
Commit
816ccbe
·
1 Parent(s): 19e4a8c

new worker

Browse files
core/models.py CHANGED
@@ -103,6 +103,7 @@ class AuditLog(Base):
103
  class GeminiJob(Base):
104
  """
105
  Generic job queue for Gemini operations (video, image, text).
 
106
  """
107
  __tablename__ = "gemini_jobs"
108
 
@@ -111,7 +112,13 @@ class GeminiJob(Base):
111
  user_id = Column(String(50), index=True, nullable=False) # User who requested
112
  job_type = Column(String(20), index=True, nullable=False) # video, image, text, analyze
113
  third_party_id = Column(String(255), nullable=True) # Gemini operation name (for video)
114
- status = Column(String(20), default="queued", index=True) # queued, processing, completed, failed
 
 
 
 
 
 
115
  input_data = Column(JSON, nullable=True) # Request details (prompt, settings, etc.)
116
  output_data = Column(JSON, nullable=True) # Result (filename, text, etc.)
117
  error_message = Column(Text, nullable=True)
@@ -120,4 +127,5 @@ class GeminiJob(Base):
120
  completed_at = Column(DateTime(timezone=True), nullable=True)
121
 
122
  def __repr__(self):
123
- return f"<GeminiJob(job_id={self.job_id}, type={self.job_type}, status={self.status})>"
 
 
103
  class GeminiJob(Base):
104
  """
105
  Generic job queue for Gemini operations (video, image, text).
106
+ Uses priority-tier system for worker assignment.
107
  """
108
  __tablename__ = "gemini_jobs"
109
 
 
112
  user_id = Column(String(50), index=True, nullable=False) # User who requested
113
  job_type = Column(String(20), index=True, nullable=False) # video, image, text, analyze
114
  third_party_id = Column(String(255), nullable=True) # Gemini operation name (for video)
115
+ status = Column(String(20), default="queued", index=True) # queued, processing, completed, failed, cancelled
116
+
117
+ # Priority-tier worker system
118
+ priority = Column(String(10), default="fast", index=True) # fast (5s), medium (30s), slow (60s)
119
+ next_process_at = Column(DateTime(timezone=True), nullable=True, index=True) # When worker should pick up again
120
+ retry_count = Column(Integer, default=0) # Number of status check retries
121
+
122
  input_data = Column(JSON, nullable=True) # Request details (prompt, settings, etc.)
123
  output_data = Column(JSON, nullable=True) # Result (filename, text, etc.)
124
  error_message = Column(Text, nullable=True)
 
127
  completed_at = Column(DateTime(timezone=True), nullable=True)
128
 
129
  def __repr__(self):
130
+ return f"<GeminiJob(job_id={self.job_id}, type={self.job_type}, status={self.status}, priority={self.priority})>"
131
+
routers/gemini.py CHANGED
@@ -74,14 +74,18 @@ async def create_job(
74
  job_type: str,
75
  input_data: dict
76
  ) -> GeminiJob:
77
- """Create a new job in the queue."""
 
 
78
  job_id = f"job_{uuid.uuid4().hex[:16]}"
 
79
 
80
  job = GeminiJob(
81
  job_id=job_id,
82
  user_id=user.user_id,
83
  job_type=job_type,
84
  status="queued",
 
85
  input_data=input_data
86
  )
87
  db.add(job)
@@ -344,6 +348,48 @@ async def download_video(
344
  )
345
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  @router.get("/models")
348
  async def get_models():
349
  """
 
74
  job_type: str,
75
  input_data: dict
76
  ) -> GeminiJob:
77
+ """Create a new job in the queue with auto-assigned priority."""
78
+ from services.job_worker import get_priority_for_job_type
79
+
80
  job_id = f"job_{uuid.uuid4().hex[:16]}"
81
+ priority = get_priority_for_job_type(job_type)
82
 
83
  job = GeminiJob(
84
  job_id=job_id,
85
  user_id=user.user_id,
86
  job_type=job_type,
87
  status="queued",
88
+ priority=priority,
89
  input_data=input_data
90
  )
91
  db.add(job)
 
348
  )
349
 
350
 
351
@router.post("/job/{job_id}/cancel")
async def cancel_job(
    job_id: str,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel one of the caller's queued jobs.

    Only jobs still in 'queued' status may be cancelled; a job that is
    already processing, completed, failed or cancelled is rejected.

    Raises:
        HTTPException: 404 when no matching job exists for this user,
            400 when the job is not in 'queued' status.
    """
    # The user_id filter doubles as an ownership check: another user's
    # job is indistinguishable from a missing one (404).
    lookup = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == user.user_id
    )
    job = (await db.execute(lookup)).scalar_one_or_none()

    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )

    if job.status != "queued":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot cancel job with status '{job.status}'. Only queued jobs can be cancelled."
        )

    job.status = "cancelled"
    # NOTE(review): datetime.utcnow() is naive while completed_at is a
    # DateTime(timezone=True) column — this matches the rest of the
    # codebase, but consider migrating everywhere to aware UTC timestamps.
    job.completed_at = datetime.utcnow()
    await db.commit()

    return {
        "success": True,
        "job_id": job.job_id,
        "status": "cancelled",
        "message": "Job cancelled successfully"
    }
391
+
392
+
393
  @router.get("/models")
394
  async def get_models():
395
  """
services/job_worker.py CHANGED
@@ -1,184 +1,219 @@
1
  """
2
- Background worker for processing Gemini jobs.
3
- Runs continuously, picks up queued jobs, processes them.
 
 
 
 
 
4
  """
5
  import asyncio
6
  import logging
7
  import os
8
- from datetime import datetime
9
- from sqlalchemy import select
 
10
  from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
11
 
12
  from core.database import DATABASE_URL
13
  from core.models import GeminiJob
14
- from services.gemini_service import GeminiService, DOWNLOADS_DIR
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
  # Worker configuration
19
- WORKER_POLL_INTERVAL = int(os.getenv("WORKER_POLL_INTERVAL", "5")) # seconds
20
- MAX_CONCURRENT_VIDEO_JOBS = int(os.getenv("MAX_CONCURRENT_VIDEO_JOBS", "2"))
21
- MAX_CONCURRENT_IMAGE_JOBS = int(os.getenv("MAX_CONCURRENT_IMAGE_JOBS", "3"))
22
- MAX_CONCURRENT_TEXT_JOBS = int(os.getenv("MAX_CONCURRENT_TEXT_JOBS", "5"))
 
 
 
23
 
24
- # Track running jobs
25
- _running_jobs = {
26
- "video": 0,
27
- "image": 0,
28
- "text": 0,
29
- "analyze": 0,
30
- "animation_prompt": 0
 
31
  }
32
- _lock = asyncio.Lock()
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
- class JobWorker:
36
- """Background worker for processing Gemini jobs."""
37
 
38
- def __init__(self):
39
- self.engine = create_async_engine(DATABASE_URL, echo=False)
40
- self.async_session = async_sessionmaker(
41
- self.engine,
42
- class_=AsyncSession,
43
- expire_on_commit=False
44
- )
45
  self._running = False
46
- self._tasks = []
47
 
48
  async def start(self):
49
- """Start the worker."""
50
  self._running = True
51
- logger.info("Job worker started")
52
  asyncio.create_task(self._poll_loop())
53
 
54
  async def stop(self):
55
  """Stop the worker."""
56
  self._running = False
57
- # Wait for running tasks to complete
58
- if self._tasks:
59
- await asyncio.gather(*self._tasks, return_exceptions=True)
60
- logger.info("Job worker stopped")
61
 
62
  async def _poll_loop(self):
63
  """Main polling loop."""
64
  while self._running:
65
  try:
66
- await self._process_queued_jobs()
67
  except Exception as e:
68
- logger.error(f"Error in worker poll loop: {e}")
69
- await asyncio.sleep(WORKER_POLL_INTERVAL)
70
 
71
- async def _process_queued_jobs(self):
72
- """Find and process queued jobs."""
73
- async with self.async_session() as session:
74
- # Get queued jobs
 
75
  query = select(GeminiJob).where(
76
- GeminiJob.status == "queued"
77
- ).order_by(GeminiJob.created_at)
 
 
 
 
 
 
 
78
 
79
  result = await session.execute(query)
80
- jobs = result.scalars().all()
 
 
 
 
 
81
 
82
- for job in jobs:
83
- # Check if we can process this job type
84
- if not await self._can_process(job.job_type):
85
- continue
86
-
87
- # Mark as processing
88
- async with _lock:
89
- _running_jobs[job.job_type] = _running_jobs.get(job.job_type, 0) + 1
90
-
91
- try:
92
- job.status = "processing"
93
- job.started_at = datetime.utcnow()
94
- await session.commit()
95
-
96
- # Process job in background
97
- task = asyncio.create_task(self._process_job(job.job_id))
98
- self._tasks.append(task)
99
- task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
100
- except Exception as e:
101
- logger.error(f"Error starting job {job.job_id}: {e}")
102
- async with _lock:
103
- _running_jobs[job.job_type] = max(0, _running_jobs.get(job.job_type, 0) - 1)
104
-
105
- async def _can_process(self, job_type: str) -> bool:
106
- """Check if we can process another job of this type."""
107
- async with _lock:
108
- current = _running_jobs.get(job_type, 0)
109
- if job_type == "video":
110
- return current < MAX_CONCURRENT_VIDEO_JOBS
111
- elif job_type in ("image", "edit_image"):
112
- return current < MAX_CONCURRENT_IMAGE_JOBS
113
- else: # text, analyze, animation_prompt
114
- return current < MAX_CONCURRENT_TEXT_JOBS
115
-
116
- async def _process_job(self, job_id: str):
117
- """Process a single job."""
118
- async with self.async_session() as session:
119
  try:
120
- # Get the job
121
- query = select(GeminiJob).where(GeminiJob.job_id == job_id)
122
- result = await session.execute(query)
123
- job = result.scalar_one_or_none()
124
-
125
- if not job:
126
- logger.error(f"Job {job_id} not found")
127
- return
128
-
129
- logger.info(f"Processing job {job_id} (type: {job.job_type})")
130
-
131
- service = GeminiService()
132
- input_data = job.input_data or {}
133
-
134
- if job.job_type == "video":
135
- await self._process_video_job(session, job, service, input_data)
136
- elif job.job_type == "image":
137
- await self._process_image_job(session, job, service, input_data)
138
- elif job.job_type == "text":
139
- await self._process_text_job(session, job, service, input_data)
140
- elif job.job_type == "analyze":
141
- await self._process_analyze_job(session, job, service, input_data)
142
- elif job.job_type == "animation_prompt":
143
- await self._process_animation_prompt_job(session, job, service, input_data)
144
- else:
145
- job.status = "failed"
146
- job.error_message = f"Unknown job type: {job.job_type}"
147
- job.completed_at = datetime.utcnow()
148
- await session.commit()
149
-
150
  except Exception as e:
151
- logger.error(f"Error processing job {job_id}: {e}")
152
- try:
153
- job.status = "failed"
154
- job.error_message = str(e)
155
- job.completed_at = datetime.utcnow()
156
- await session.commit()
157
- except:
158
- pass
159
  finally:
160
- async with _lock:
161
- job_type = job.job_type if job else "unknown"
162
- _running_jobs[job_type] = max(0, _running_jobs.get(job_type, 0) - 1)
163
 
164
- async def _process_video_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
165
- """Process a video generation job."""
166
- # Start video generation
167
- result = await service.start_video_generation(
168
- base64_image=input_data.get("base64_image", ""),
169
- mime_type=input_data.get("mime_type", "image/jpeg"),
170
- prompt=input_data.get("prompt", ""),
171
- aspect_ratio=input_data.get("aspect_ratio", "16:9"),
172
- resolution=input_data.get("resolution", "720p"),
173
- number_of_videos=input_data.get("number_of_videos", 1)
174
- )
175
 
176
- # Save third party ID
177
- job.third_party_id = result.get("gemini_operation_name")
178
- await session.commit()
 
 
 
 
 
179
 
180
- # Poll until done
181
- while True:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  status_result = await service.check_video_status(job.third_party_id)
183
 
184
  if status_result.get("done"):
@@ -197,82 +232,93 @@ class JobWorker:
197
  job.error_message = status_result.get("error", "Unknown error")
198
 
199
  job.completed_at = datetime.utcnow()
200
- await session.commit()
201
- break
 
 
 
202
 
203
- await asyncio.sleep(10) # Poll every 10 seconds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- async def _process_image_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
206
- """Process an image edit job."""
207
- result = await service.edit_image(
208
- base64_image=input_data.get("base64_image", ""),
209
- mime_type=input_data.get("mime_type", "image/jpeg"),
210
- prompt=input_data.get("prompt", "")
211
  )
212
-
213
- job.status = "completed"
214
- job.output_data = {"image": result}
215
- job.completed_at = datetime.utcnow()
216
- await session.commit()
217
 
218
- async def _process_text_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
219
- """Process a text generation job."""
220
- result = await service.generate_text(
221
- prompt=input_data.get("prompt", ""),
222
- model=input_data.get("model")
223
- )
224
 
225
- job.status = "completed"
226
- job.output_data = {"text": result}
227
- job.completed_at = datetime.utcnow()
228
- await session.commit()
229
-
230
- async def _process_analyze_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
231
- """Process an image analysis job."""
232
- result = await service.analyze_image(
233
- base64_image=input_data.get("base64_image", ""),
234
- mime_type=input_data.get("mime_type", "image/jpeg"),
235
- prompt=input_data.get("prompt", "")
236
- )
237
 
238
- job.status = "completed"
239
- job.output_data = {"analysis": result}
240
- job.completed_at = datetime.utcnow()
241
- await session.commit()
242
-
243
- async def _process_animation_prompt_job(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
244
- """Process an animation prompt generation job."""
245
- result = await service.generate_animation_prompt(
246
- base64_image=input_data.get("base64_image", ""),
247
- mime_type=input_data.get("mime_type", "image/jpeg"),
248
- custom_prompt=input_data.get("custom_prompt")
249
- )
 
250
 
251
- job.status = "completed"
252
- job.output_data = {"prompt": result}
253
- job.completed_at = datetime.utcnow()
254
- await session.commit()
 
 
 
 
255
 
256
 
257
- # Singleton worker instance
258
- _worker: JobWorker = None
259
 
260
 
261
- def get_worker() -> JobWorker:
262
- """Get the global worker instance."""
263
- global _worker
264
- if _worker is None:
265
- _worker = JobWorker()
266
- return _worker
267
 
268
 
269
  async def start_worker():
270
- """Start the background worker."""
271
- worker = get_worker()
272
- await worker.start()
273
 
274
 
275
  async def stop_worker():
276
- """Stop the background worker."""
277
- worker = get_worker()
278
- await worker.stop()
 
1
  """
2
+ Priority-Tier Worker Pool for processing Gemini jobs.
3
+
4
+ Architecture:
5
+ - 15 workers total: 5 fast (5s), 5 medium (30s), 5 slow (60s)
6
+ - Workers pick jobs based on priority tier
7
+ - Jobs are rescheduled with next_process_at if third-party not done
8
+ - No blocking on third-party polling
9
  """
10
  import asyncio
11
  import logging
12
  import os
13
+ from datetime import datetime, timedelta
14
+ from typing import Optional, List
15
+ from sqlalchemy import select, or_, and_
16
  from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
17
 
18
  from core.database import DATABASE_URL
19
  from core.models import GeminiJob
20
+ from services.gemini_service import GeminiService
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
# Worker configuration — pool sizes and poll cadence, overridable via env vars.
FAST_WORKERS = int(os.getenv("FAST_WORKERS", "5"))
MEDIUM_WORKERS = int(os.getenv("MEDIUM_WORKERS", "5"))
SLOW_WORKERS = int(os.getenv("SLOW_WORKERS", "5"))

FAST_INTERVAL = int(os.getenv("FAST_INTERVAL", "5"))        # seconds between polls (fast tier)
MEDIUM_INTERVAL = int(os.getenv("MEDIUM_INTERVAL", "30"))   # seconds between polls (medium tier)
SLOW_INTERVAL = int(os.getenv("SLOW_INTERVAL", "60"))       # seconds between polls (slow tier)

# Which priority tier each job type is routed to.
JOB_PRIORITY_MAP = {
    # quick synchronous Gemini calls
    "text": "fast",
    "analyze": "fast",
    "animation_prompt": "fast",
    # image work takes longer
    "image": "medium",
    "edit_image": "medium",
    # video generation is an async operation polled slowly
    "video": "slow"
}


def get_priority_for_job_type(job_type: str) -> str:
    """Map a job type to its priority tier; unknown types fall back to 'fast'."""
    return JOB_PRIORITY_MAP.get(job_type, "fast")


def get_interval_for_priority(priority: str) -> int:
    """Return the poll interval (seconds) for a tier; unknown tiers use the slow interval."""
    tier_intervals = {"fast": FAST_INTERVAL, "medium": MEDIUM_INTERVAL}
    return tier_intervals.get(priority, SLOW_INTERVAL)
55
 
56
 
57
+ class PriorityWorker:
58
+ """Worker that processes jobs of a specific priority tier."""
59
 
60
+ def __init__(self, worker_id: int, priority: str, poll_interval: int, session_maker):
61
+ self.worker_id = worker_id
62
+ self.priority = priority
63
+ self.poll_interval = poll_interval
64
+ self.session_maker = session_maker
 
 
65
  self._running = False
66
+ self._current_job_id: Optional[str] = None
67
 
68
  async def start(self):
69
+ """Start the worker polling loop."""
70
  self._running = True
71
+ logger.info(f"Worker {self.worker_id} ({self.priority}) started, polling every {self.poll_interval}s")
72
  asyncio.create_task(self._poll_loop())
73
 
74
  async def stop(self):
75
  """Stop the worker."""
76
  self._running = False
77
+ logger.info(f"Worker {self.worker_id} ({self.priority}) stopped")
 
 
 
78
 
79
  async def _poll_loop(self):
80
  """Main polling loop."""
81
  while self._running:
82
  try:
83
+ await self._process_one_job()
84
  except Exception as e:
85
+ logger.error(f"Worker {self.worker_id}: Error in poll loop: {e}")
86
+ await asyncio.sleep(self.poll_interval)
87
 
88
+ async def _process_one_job(self):
89
+ """Find and process one job."""
90
+ async with self.session_maker() as session:
91
+ # Find a job to process
92
+ now = datetime.utcnow()
93
  query = select(GeminiJob).where(
94
+ and_(
95
+ GeminiJob.priority == self.priority,
96
+ GeminiJob.status.in_(["queued", "processing"]),
97
+ or_(
98
+ GeminiJob.next_process_at.is_(None),
99
+ GeminiJob.next_process_at <= now
100
+ )
101
+ )
102
+ ).order_by(GeminiJob.created_at).limit(1)
103
 
104
  result = await session.execute(query)
105
+ job = result.scalar_one_or_none()
106
+
107
+ if not job:
108
+ return # No jobs to process
109
+
110
+ self._current_job_id = job.job_id
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  try:
113
+ await self._process_job(session, job)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  except Exception as e:
115
+ logger.error(f"Worker {self.worker_id}: Error processing job {job.job_id}: {e}")
116
+ job.status = "failed"
117
+ job.error_message = str(e)
118
+ job.completed_at = datetime.utcnow()
119
+ await session.commit()
 
 
 
120
  finally:
121
+ self._current_job_id = None
 
 
122
 
123
+ async def _process_job(self, session: AsyncSession, job: GeminiJob):
124
+ """Process a single job."""
125
+ logger.info(f"Worker {self.worker_id}: Processing job {job.job_id} (type: {job.job_type}, status: {job.status})")
126
+
127
+ service = GeminiService()
128
+ input_data = job.input_data or {}
 
 
 
 
 
129
 
130
+ # If queued, start the operation
131
+ if job.status == "queued":
132
+ job.status = "processing"
133
+ job.started_at = datetime.utcnow()
134
+ await session.commit()
135
+
136
+ # Start the third-party operation
137
+ await self._start_operation(session, job, service, input_data)
138
 
139
+ # Check status for operations that need polling (video)
140
+ if job.job_type == "video" and job.third_party_id:
141
+ await self._check_video_status(session, job, service)
142
+ # For synchronous operations (already completed in _start_operation)
143
+ # Nothing more to do
144
+
145
+ async def _start_operation(self, session: AsyncSession, job: GeminiJob, service: GeminiService, input_data: dict):
146
+ """Start the third-party operation based on job type."""
147
+ if job.job_type == "video":
148
+ # Start async video generation
149
+ result = await service.start_video_generation(
150
+ base64_image=input_data.get("base64_image", ""),
151
+ mime_type=input_data.get("mime_type", "image/jpeg"),
152
+ prompt=input_data.get("prompt", ""),
153
+ aspect_ratio=input_data.get("aspect_ratio", "16:9"),
154
+ resolution=input_data.get("resolution", "720p"),
155
+ number_of_videos=input_data.get("number_of_videos", 1)
156
+ )
157
+ job.third_party_id = result.get("gemini_operation_name")
158
+ # Schedule first status check
159
+ job.next_process_at = datetime.utcnow() + timedelta(seconds=self.poll_interval)
160
+ await session.commit()
161
+
162
+ elif job.job_type == "image":
163
+ # Synchronous image edit
164
+ result = await service.edit_image(
165
+ base64_image=input_data.get("base64_image", ""),
166
+ mime_type=input_data.get("mime_type", "image/jpeg"),
167
+ prompt=input_data.get("prompt", "")
168
+ )
169
+ job.status = "completed"
170
+ job.output_data = {"image": result}
171
+ job.completed_at = datetime.utcnow()
172
+ await session.commit()
173
+
174
+ elif job.job_type == "text":
175
+ # Synchronous text generation
176
+ result = await service.generate_text(
177
+ prompt=input_data.get("prompt", ""),
178
+ model=input_data.get("model")
179
+ )
180
+ job.status = "completed"
181
+ job.output_data = {"text": result}
182
+ job.completed_at = datetime.utcnow()
183
+ await session.commit()
184
+
185
+ elif job.job_type == "analyze":
186
+ # Synchronous image analysis
187
+ result = await service.analyze_image(
188
+ base64_image=input_data.get("base64_image", ""),
189
+ mime_type=input_data.get("mime_type", "image/jpeg"),
190
+ prompt=input_data.get("prompt", "")
191
+ )
192
+ job.status = "completed"
193
+ job.output_data = {"analysis": result}
194
+ job.completed_at = datetime.utcnow()
195
+ await session.commit()
196
+
197
+ elif job.job_type == "animation_prompt":
198
+ # Synchronous animation prompt generation
199
+ result = await service.generate_animation_prompt(
200
+ base64_image=input_data.get("base64_image", ""),
201
+ mime_type=input_data.get("mime_type", "image/jpeg"),
202
+ custom_prompt=input_data.get("custom_prompt")
203
+ )
204
+ job.status = "completed"
205
+ job.output_data = {"prompt": result}
206
+ job.completed_at = datetime.utcnow()
207
+ await session.commit()
208
+ else:
209
+ job.status = "failed"
210
+ job.error_message = f"Unknown job type: {job.job_type}"
211
+ job.completed_at = datetime.utcnow()
212
+ await session.commit()
213
+
214
+ async def _check_video_status(self, session: AsyncSession, job: GeminiJob, service: GeminiService):
215
+ """Check video generation status and reschedule if not done."""
216
+ try:
217
  status_result = await service.check_video_status(job.third_party_id)
218
 
219
  if status_result.get("done"):
 
232
  job.error_message = status_result.get("error", "Unknown error")
233
 
234
  job.completed_at = datetime.utcnow()
235
+ else:
236
+ # Not done - reschedule
237
+ job.retry_count += 1
238
+ job.next_process_at = datetime.utcnow() + timedelta(seconds=self.poll_interval)
239
+ logger.debug(f"Job {job.job_id}: Not done, retry #{job.retry_count}, next check at {job.next_process_at}")
240
 
241
+ await session.commit()
242
+
243
+ except Exception as e:
244
+ logger.error(f"Error checking video status for {job.job_id}: {e}")
245
+ job.retry_count += 1
246
+ job.next_process_at = datetime.utcnow() + timedelta(seconds=self.poll_interval)
247
+ if job.retry_count > 60: # ~1 hour of retries
248
+ job.status = "failed"
249
+ job.error_message = f"Max retries exceeded: {str(e)}"
250
+ job.completed_at = datetime.utcnow()
251
+ await session.commit()
252
+
253
+
254
class WorkerPool:
    """Pool of priority-tier workers."""

    def __init__(self):
        # Dedicated engine/session factory for the background workers,
        # independent of the request-handling sessions.
        self.engine = create_async_engine(DATABASE_URL, echo=False)
        self.session_maker = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )
        self.workers: List[PriorityWorker] = []
        self._running = False

    async def start(self):
        """Create and start every worker, tier by tier (fast, medium, slow)."""
        self._running = True
        worker_id = 0

        # (tier name, worker count, poll interval) in startup order.
        tiers = (
            ("fast", FAST_WORKERS, FAST_INTERVAL),
            ("medium", MEDIUM_WORKERS, MEDIUM_INTERVAL),
            ("slow", SLOW_WORKERS, SLOW_INTERVAL),
        )
        for priority, count, interval in tiers:
            for _ in range(count):
                worker = PriorityWorker(worker_id, priority, interval, self.session_maker)
                self.workers.append(worker)
                await worker.start()
                worker_id += 1

        logger.info(f"WorkerPool started with {len(self.workers)} workers: {FAST_WORKERS} fast, {MEDIUM_WORKERS} medium, {SLOW_WORKERS} slow")

    async def stop(self):
        """Stop every worker (each finishes its current poll cycle first)."""
        self._running = False
        for worker in self.workers:
            await worker.stop()
        logger.info("WorkerPool stopped")
301
 
302
 
303
# Singleton pool instance
_pool: Optional[WorkerPool] = None


def get_pool() -> WorkerPool:
    """Lazily create and return the process-wide WorkerPool."""
    global _pool
    if _pool is None:
        # First caller constructs the pool; everyone afterwards reuses it.
        _pool = WorkerPool()
    return _pool
313
 
314
 
315
async def start_worker():
    """Module-level entry point: start the shared worker pool."""
    await get_pool().start()
319
 
320
 
321
async def stop_worker():
    """Module-level entry point: stop the shared worker pool."""
    await get_pool().stop()
tests/test_worker_pool.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for Priority-Tier Worker Pool implementation.
3
+ Tests job creation with priority, cancel endpoint, and worker assignment.
4
+ """
5
+ import pytest
6
+ import asyncio
7
+ from unittest.mock import patch, MagicMock, AsyncMock
8
+ from datetime import datetime, timedelta
9
+
10
+ # Test the priority mapping
11
+ from services.job_worker import (
12
+ get_priority_for_job_type,
13
+ get_interval_for_priority,
14
+ PriorityWorker,
15
+ WorkerPool,
16
+ JOB_PRIORITY_MAP
17
+ )
18
+
19
+
20
class TestPriorityMapping:
    """Test job type to priority mapping."""

    def test_text_job_is_fast(self):
        tier = get_priority_for_job_type("text")
        assert tier == "fast"

    def test_analyze_job_is_fast(self):
        tier = get_priority_for_job_type("analyze")
        assert tier == "fast"

    def test_animation_prompt_is_fast(self):
        tier = get_priority_for_job_type("animation_prompt")
        assert tier == "fast"

    def test_image_job_is_medium(self):
        tier = get_priority_for_job_type("image")
        assert tier == "medium"

    def test_edit_image_is_medium(self):
        tier = get_priority_for_job_type("edit_image")
        assert tier == "medium"

    def test_video_job_is_slow(self):
        tier = get_priority_for_job_type("video")
        assert tier == "slow"

    def test_unknown_job_defaults_to_fast(self):
        tier = get_priority_for_job_type("unknown_type")
        assert tier == "fast"
43
+
44
+
45
class TestIntervalMapping:
    """Test priority to interval mapping."""

    # NOTE(review): these literals assume the *_INTERVAL env vars are unset.

    def test_fast_interval(self):
        seconds = get_interval_for_priority("fast")
        assert seconds == 5

    def test_medium_interval(self):
        seconds = get_interval_for_priority("medium")
        assert seconds == 30

    def test_slow_interval(self):
        seconds = get_interval_for_priority("slow")
        assert seconds == 60

    def test_unknown_defaults_to_slow(self):
        seconds = get_interval_for_priority("unknown")
        assert seconds == 60
59
+
60
+
61
class TestJobPriorityMap:
    """Test that all expected job types are covered."""

    def test_all_job_types_have_priority(self):
        """Every job type the workers understand must appear in the map."""
        for jt in ("text", "analyze", "animation_prompt", "image", "edit_image", "video"):
            assert jt in JOB_PRIORITY_MAP, f"Job type '{jt}' not in priority map"
68
+
69
+
70
class TestWorkerPoolConfiguration:
    """Test worker pool creates correct number of workers."""

    @pytest.mark.asyncio
    async def test_creates_15_workers(self):
        """Test that WorkerPool creates 15 workers (5 fast, 5 medium, 5 slow)."""
        pool = WorkerPool()

        # Patch the polling loop so no worker ever touches the database.
        with patch.object(PriorityWorker, '_poll_loop', new_callable=AsyncMock):
            await pool.start()

            assert len(pool.workers) == 15

            # Tally workers per tier.
            by_tier = {"fast": 0, "medium": 0, "slow": 0}
            for w in pool.workers:
                by_tier[w.priority] += 1

            assert by_tier["fast"] == 5
            assert by_tier["medium"] == 5
            assert by_tier["slow"] == 5

            await pool.stop()

    @pytest.mark.asyncio
    async def test_workers_have_correct_intervals(self):
        """Test that workers have correct poll intervals."""
        pool = WorkerPool()

        with patch.object(PriorityWorker, '_poll_loop', new_callable=AsyncMock):
            await pool.start()

            # NOTE(review): assumes default *_INTERVAL env configuration.
            expected = {"fast": 5, "medium": 30, "slow": 60}
            for w in pool.workers:
                assert w.poll_interval == expected[w.priority]

            await pool.stop()
113
+
114
+
115
class TestPriorityWorker:
    """Test individual worker behavior."""

    def test_worker_has_correct_attributes(self):
        """A freshly built worker carries its config and starts idle."""
        probe = PriorityWorker(0, "fast", 5, None)

        assert probe.worker_id == 0
        assert probe.priority == "fast"
        assert probe.poll_interval == 5
        assert probe._running == False
        assert probe._current_job_id is None
127
+
128
+
129
# Integration test with the actual router
class TestCancelEndpoint:
    """Test cancel job endpoint logic."""

    def test_only_queued_jobs_can_be_cancelled(self):
        """Verify the cancellation logic - only queued status allowed."""
        cancellable = ["queued"]
        not_cancellable = ["processing", "completed", "failed", "cancelled"]

        # This is a logic validation, actual HTTP testing would need the app
        assert all(s == "queued" for s in cancellable)
        assert all(s != "queued" for s in not_cancellable)
144
+
145
+
146
if __name__ == "__main__":
    # Propagate pytest's exit status so scripts/CI can detect failures;
    # a bare pytest.main(...) call discards the returned exit code and
    # the process always exits 0.
    raise SystemExit(pytest.main([__file__, "-v"]))