Spaces:
Sleeping
Sleeping
File size: 16,367 Bytes
036c5c4 3c56e03 c3f1513 036c5c4 b49d66f 036c5c4 c3f1513 b49d66f 036c5c4 b49d66f 036c5c4 ec0e527 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 b49d66f 036c5c4 ec0e527 b49d66f 036c5c4 b49d66f 036c5c4 ec0e527 036c5c4 c4f61f9 c3f1513 be85b16 036c5c4 8c4055f b49d66f 036c5c4 8c4055f b49d66f 036c5c4 b49d66f 036c5c4 ec0e527 b49d66f 036c5c4 c94217d 2d2a3f1 c94217d 036c5c4 c94217d 036c5c4 ec0e527 036c5c4 ec0e527 036c5c4 c3f1513 be85b16 036c5c4 ec0e527 036c5c4 c3f1513 be85b16 036c5c4 ec0e527 036c5c4 c3f1513 be85b16 036c5c4 ec0e527 036c5c4 c3f1513 be85b16 036c5c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
"""
Gemini Job Worker - Specific implementation using the modular PriorityWorkerPool.
This file shows how to use the modular PriorityWorkerPool with Gemini-specific
job processing logic.
"""
import logging
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from core.database import DATABASE_URL
from core.models import GeminiJob
from services.priority_worker_pool import (
PriorityWorkerPool,
JobProcessor,
WorkerConfig,
get_interval_for_priority
)
from services.gemini_service.api_client import GeminiService
from services.drive_service import DriveService
import asyncio
logger = logging.getLogger(__name__)
# Priority tier lookup for Gemini job types: quick synchronous calls run on
# the "fast" tier, image work on "medium", long-running video on "slow".
JOB_PRIORITY_MAP = {
    **dict.fromkeys(("text", "analyze", "animation_prompt"), "fast"),
    **dict.fromkeys(("image", "edit_image"), "medium"),
    "video": "slow",
}
def get_priority_for_job_type(job_type: str) -> str:
    """Return the priority tier for *job_type*, defaulting to "fast"."""
    try:
        return JOB_PRIORITY_MAP[job_type]
    except KeyError:
        return "fast"
class GeminiJobProcessor(JobProcessor[GeminiJob]):
    """Processes Gemini AI jobs (text, image, video generation) with round-robin API keys.

    Synchronous job types (text, analyze, animation_prompt, image) complete
    within a single process() call; video jobs are started asynchronously and
    then polled via check_status() until the Gemini operation reports done.
    API key usage is recorded exactly once per API interaction.
    """

    # Fixed delay (seconds) before retrying after a retryable API error
    # (rate limit, auth/billing, or transient server error).
    RETRY_DELAY_SECONDS = 30

    def __init__(self):
        self.drive_service = DriveService()

    async def _get_service_with_key(self, session: AsyncSession) -> tuple:
        """Return (key_index, GeminiService) built from the least-used API key."""
        from services.api_key_manager import get_least_used_key
        key_index, api_key = await get_least_used_key(session)
        return key_index, GeminiService(api_key=api_key)

    async def _record_usage(self, session: AsyncSession, key_index: int, success: bool, error_message: Optional[str] = None):
        """Record API key usage after a request (success or failure)."""
        from services.api_key_manager import record_usage
        await record_usage(session, key_index, success, error_message)

    async def _sync_backup(self) -> None:
        """Trigger an async DB backup after a successful job completion."""
        # Imported lazily, matching the other service imports in this class.
        from services.backup_service import get_backup_service
        backup_service = get_backup_service()
        await backup_service.backup_async()

    def _handle_error(self, job: GeminiJob, error: Exception, reset_to_queued: bool = False) -> tuple[bool, str]:
        """
        Handle job errors with retry logic.

        Retryable errors (rate limits, auth/billing problems, transient server
        errors) reschedule the job after RETRY_DELAY_SECONDS; anything else
        marks the job failed permanently.

        Args:
            job: The job object
            error: The exception raised
            reset_to_queued: Whether to reset status to 'queued' on retry (for process())
        Returns:
            Tuple of (success, error_message)
            success is False (since it's an error)
            error_message is the formatted error string
        """
        error_str = str(error)
        is_retryable = False
        log_msg = ""
        # Check for Rate Limit (429)
        if "429" in error_str or "ResourceExhausted" in error_str:
            is_retryable = True
            log_msg = f"Rate limit hit for job {job.job_id}"
        # Check for Auth/Billing errors (401, 403, API key not found, API key not valid, FAILED_PRECONDITION)
        elif "401" in error_str or "403" in error_str or "Unauthenticated" in error_str or "PermissionDenied" in error_str or "API key not found" in error_str or "API key not valid" in error_str or "FAILED_PRECONDITION" in error_str:
            is_retryable = True
            log_msg = f"Auth/Billing error for job {job.job_id}: {error_str}. Rescheduling to try different key."
        # Check for Server errors (500, 503, 504)
        elif "500" in error_str or "503" in error_str or "504" in error_str or "INTERNAL" in error_str or "UNAVAILABLE" in error_str or "DEADLINE_EXCEEDED" in error_str:
            is_retryable = True
            log_msg = f"Server error for job {job.job_id}: {error_str}"
        # Store structured error details on the job when the error text
        # contains an embedded JSON payload; otherwise store the raw text.
        try:
            import json
            import re
            # Look for JSON-like structure in error string
            json_match = re.search(r"(\{.*\})", error_str)
            if json_match:
                job.api_response = json.loads(json_match.group(1))
            else:
                job.api_response = {"error": error_str}
        except Exception:
            # Malformed JSON fragment — fall back to the raw error text.
            job.api_response = {"error": error_str}
        if is_retryable:
            logger.warning(f"{log_msg}. Rescheduling.")
            job.retry_count += 1
            # Fixed delay for retryable errors rather than the priority-based
            # polling interval (previously an unused WorkerConfig was built here).
            job.next_process_at = datetime.utcnow() + timedelta(seconds=self.RETRY_DELAY_SECONDS)
            if reset_to_queued:
                job.status = "queued"
            return False, f"Retryable error: {error_str}"
        else:
            logger.error(f"Error processing job {job.job_id}: {error}")
            job.status = "failed"
            job.error_message = str(error)
            job.completed_at = datetime.utcnow()
            return False, str(error)

    async def process(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
        """Start processing a new job with round-robin API key.

        Dispatches on job.job_type; unknown types are failed immediately.
        NOTE(review): JOB_PRIORITY_MAP lists "edit_image" but no branch here
        handles it, so such jobs would fail as unknown — confirm intended.
        """
        key_index, service = await self._get_service_with_key(session)
        input_data = job.input_data or {}
        success = False
        error_msg = None
        try:
            if job.job_type == "video":
                job = await self._start_video(job, session, service, input_data)
                success = True
            elif job.job_type == "image":
                job = await self._process_image(job, service, input_data)
                success = True
            elif job.job_type == "text":
                job = await self._process_text(job, service, input_data)
                success = True
            elif job.job_type == "analyze":
                job = await self._process_analyze(job, service, input_data)
                success = True
            elif job.job_type == "animation_prompt":
                job = await self._process_animation_prompt(job, service, input_data)
                success = True
            else:
                job.status = "failed"
                job.error_message = f"Unknown job type: {job.job_type}"
                job.completed_at = datetime.utcnow()
                error_msg = job.error_message
        except Exception as e:
            # reset_to_queued=True: if we fail to start, retry from scratch.
            success, error_msg = self._handle_error(job, e, reset_to_queued=True)
        # Record usage exactly once for both the success and error paths.
        await self._record_usage(session, key_index, success, error_msg)
        return job

    async def check_status(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
        """Check status of an in-progress job (video generation).

        Completed videos store the URL in output_data; unfinished operations
        are rescheduled using the priority-based polling interval.
        """
        if job.job_type != "video" or not job.third_party_id:
            job.status = "failed"
            job.error_message = "Invalid job state for status check"
            job.completed_at = datetime.utcnow()
            return job
        # Use round-robin key for status check
        key_index, service = await self._get_service_with_key(session)
        success = False
        error_msg = None
        try:
            status_result = await service.check_video_status(job.third_party_id)
            # Save raw response
            job.api_response = status_result
            if status_result.get("done"):
                if status_result.get("status") == "completed":
                    video_url = status_result.get("video_url")
                    if video_url:
                        # Store video URL - download will happen on-demand when client requests
                        job.status = "completed"
                        job.output_data = {"video_url": video_url}
                        job.error_message = None  # Clear any previous error
                        job.completed_at = datetime.utcnow()
                        success = True
                        # Sync DB on success
                        await self._sync_backup()
                    else:
                        job.status = "failed"
                        job.error_message = "No video URL returned"
                        job.completed_at = datetime.utcnow()
                        error_msg = job.error_message
                else:
                    job.status = "failed"
                    job.error_message = status_result.get("error", "Unknown error")
                    job.completed_at = datetime.utcnow()
                    error_msg = job.error_message
            else:
                # Not done - reschedule at the priority-based polling interval
                job.retry_count += 1
                config = WorkerConfig.from_env()
                interval = get_interval_for_priority(job.priority, config)
                job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
                success = True  # Status check succeeded even if video not ready
        except Exception as e:
            # reset_to_queued=False: keep polling status rather than restarting.
            success, error_msg = self._handle_error(job, e, reset_to_queued=False)
        # Record usage exactly once; previously the except block duplicated
        # this record-and-return sequence before falling through here.
        await self._record_usage(session, key_index, success, error_msg)
        return job

    async def _start_video(self, job: GeminiJob, session: AsyncSession, service: GeminiService, input_data: dict) -> GeminiJob:
        """Start async video generation and schedule the first status check."""
        prompt = input_data.get("prompt", "")
        # If prompt is missing, generate one using the animation template
        if not prompt:
            try:
                import os
                template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "prompt", "animation.md")
                if os.path.exists(template_path):
                    with open(template_path, "r") as f:
                        template_prompt = f.read().strip()
                    logger.info(f"Generating auto-prompt for job {job.job_id} using template")
                    prompt = await service.generate_animation_prompt(
                        base64_image=input_data.get("base64_image", ""),
                        mime_type=input_data.get("mime_type", "image/jpeg"),
                        custom_prompt=template_prompt
                    )
                    logger.info(f"Generated prompt for job {job.job_id}: {prompt}")
                    # Store the generated prompt back on the job. A new dict is
                    # created so SQLAlchemy detects the JSON column change; the
                    # worker loop commits after process() returns, persisting it.
                    new_input_data = dict(input_data)
                    new_input_data["prompt"] = prompt
                    job.input_data = new_input_data
                else:
                    logger.warning(f"Animation prompt template not found at {template_path}")
            except Exception as e:
                # Best-effort: proceed with the empty prompt and let the API
                # either apply its default or reject the request.
                logger.error(f"Failed to generate auto-prompt: {e}")
        result = await service.start_video_generation(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=prompt,
            aspect_ratio=input_data.get("aspect_ratio", "16:9"),
            resolution=input_data.get("resolution", "720p"),
            number_of_videos=input_data.get("number_of_videos", 1)
        )
        job.third_party_id = result.get("gemini_operation_name")
        job.api_response = result
        # Schedule first status check
        config = WorkerConfig.from_env()
        interval = get_interval_for_priority(job.priority, config)
        job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
        return job

    async def _process_image(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process image edit (synchronous)."""
        result = await service.edit_image(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", "")
        )
        job.status = "completed"
        job.output_data = {"image": result}
        # Don't save full base64 image to api_response
        job.api_response = {"status": "success", "type": "image_edit"}
        job.completed_at = datetime.utcnow()
        # Sync DB on success
        await self._sync_backup()
        return job

    async def _process_text(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process text generation (synchronous)."""
        result = await service.generate_text(
            prompt=input_data.get("prompt", ""),
            model=input_data.get("model")
        )
        job.status = "completed"
        job.output_data = {"text": result}
        job.api_response = {"result": result}
        job.completed_at = datetime.utcnow()
        # Sync DB on success
        await self._sync_backup()
        return job

    async def _process_analyze(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process image analysis (synchronous)."""
        result = await service.analyze_image(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            prompt=input_data.get("prompt", "")
        )
        job.status = "completed"
        job.output_data = {"analysis": result}
        job.api_response = {"result": result}
        job.completed_at = datetime.utcnow()
        # Sync DB on success
        await self._sync_backup()
        return job

    async def _process_animation_prompt(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
        """Process animation prompt generation (synchronous)."""
        result = await service.generate_animation_prompt(
            base64_image=input_data.get("base64_image", ""),
            mime_type=input_data.get("mime_type", "image/jpeg"),
            custom_prompt=input_data.get("custom_prompt")
        )
        job.status = "completed"
        job.output_data = {"prompt": result}
        job.api_response = {"result": result}
        job.completed_at = datetime.utcnow()
        # Sync DB on success
        await self._sync_backup()
        return job
# Lazily-created module-level singleton holding the Gemini worker pool.
_pool: Optional[PriorityWorkerPool] = None
def get_pool() -> PriorityWorkerPool:
    """Return the global Gemini worker pool, creating it on first use."""
    global _pool
    if _pool is not None:
        return _pool
    _pool = PriorityWorkerPool(
        database_url=DATABASE_URL,
        job_model=GeminiJob,
        job_processor=GeminiJobProcessor(),
        config=WorkerConfig.from_env()
    )
    return _pool
async def start_worker():
    """Start the Gemini job worker pool (creates the singleton if needed)."""
    worker_pool = get_pool()
    await worker_pool.start()
async def stop_worker():
    """Stop the Gemini job worker pool (creates the singleton if needed)."""
    worker_pool = get_pool()
    await worker_pool.stop()
|