Spaces:
Build error
Build error
refactor: remove worker pool, use direct fal.ai API calls
- Remove services/worker/ and priority_worker_pool.py
- Remove services/gemini_service/ (replaced by fal_service)
- Update gemini.py to call FalService directly
- Use fal.ai submit() for non-blocking job start
- Use fal.ai status() for non-blocking status check
- Remove worker startup/shutdown from app.py
- All API responses now immediate, no background processing
- .env.example +15 -0
- app.py +7 -25
- requirements.txt +17 -18
- routers/gemini.py +48 -25
- services/fal_service/__init__.py +17 -0
- services/fal_service/api_client.py +254 -0
- services/gemini_service/__init__.py +0 -55
- services/gemini_service/api_client.py +0 -401
- services/gemini_service/api_key_config.py +0 -100
- services/gemini_service/api_key_middleware.py +0 -180
- services/gemini_service/job_processor.py +0 -378
- services/priority_worker_pool.py +0 -547
- tests/debug_gemini_service.py +0 -165
- tests/test_fal_service.py +290 -0
- tests/test_gemini_service.py +0 -814
- tests/{test_worker_pool.py → test_worker_pool.py.archived} +0 -0
.env.example
CHANGED
|
@@ -97,6 +97,21 @@ JOB_PER_API_KEY=2
|
|
| 97 |
# Enable mock mode for testing without consuming API credits
|
| 98 |
# GEMINI_MOCK_MODE=true
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
# -----------------------------------------------------------------------------
|
| 101 |
# Email Configuration (Optional)
|
| 102 |
# -----------------------------------------------------------------------------
|
|
|
|
| 97 |
# Enable mock mode for testing without consuming API credits
|
| 98 |
# GEMINI_MOCK_MODE=true
|
| 99 |
|
| 100 |
+
# -----------------------------------------------------------------------------
|
| 101 |
+
# AI Provider Configuration
|
| 102 |
+
# -----------------------------------------------------------------------------
|
| 103 |
+
# Which AI provider to use for video generation
|
| 104 |
+
# Options: "google" (Gemini/Veo) or "fal" (fal.ai)
|
| 105 |
+
AI_PROVIDER=google
|
| 106 |
+
|
| 107 |
+
# Fal.ai API Key (required if AI_PROVIDER=fal)
|
| 108 |
+
# Get from: https://fal.ai/dashboard/keys
|
| 109 |
+
# Note: fal_client expects this env var to be named FAL_KEY
|
| 110 |
+
# FAL_KEY=your-fal-api-key
|
| 111 |
+
|
| 112 |
+
# Enable mock mode for fal.ai testing without consuming API credits
|
| 113 |
+
# FAL_MOCK_MODE=true
|
| 114 |
+
|
| 115 |
# -----------------------------------------------------------------------------
|
| 116 |
# Email Configuration (Optional)
|
| 117 |
# -----------------------------------------------------------------------------
|
app.py
CHANGED
|
@@ -159,24 +159,12 @@ async def lifespan(app: FastAPI):
|
|
| 159 |
)
|
| 160 |
logger.info("✅ Audit Service configured")
|
| 161 |
|
| 162 |
-
#
|
| 163 |
-
from services.gemini_service import APIKeyServiceConfig
|
| 164 |
-
APIKeyServiceConfig.register(
|
| 165 |
-
rotation_strategy="least_used", # or "round_robin"
|
| 166 |
-
cooldown_seconds=60, # Wait 1 min after quota error
|
| 167 |
-
max_requests_per_minute=60,
|
| 168 |
-
retry_on_quota_error=True # Auto-retry with different key
|
| 169 |
-
)
|
| 170 |
-
logger.info("✅ API Key Service configured")
|
| 171 |
-
|
| 172 |
-
# Worker Pool Section
|
| 173 |
logger.info("")
|
| 174 |
-
logger.info("
|
|
|
|
| 175 |
|
| 176 |
-
|
| 177 |
-
from services.gemini_service import start_worker, stop_worker
|
| 178 |
-
await start_worker()
|
| 179 |
-
logger.info("✅ Worker pool started")
|
| 180 |
|
| 181 |
# Log CORS configuration
|
| 182 |
allowed_origins = os.getenv("CORS_ORIGINS").split(",")
|
|
@@ -189,17 +177,15 @@ async def lifespan(app: FastAPI):
|
|
| 189 |
logger.info("═" * 60)
|
| 190 |
logger.info(" 🚀 API Gateway Ready")
|
| 191 |
logger.info(" • Database: ✅ Ready")
|
| 192 |
-
logger.info(" • Services:
|
| 193 |
-
logger.info(" •
|
| 194 |
logger.info(" • Endpoint: http://0.0.0.0:8000")
|
| 195 |
logger.info("═" * 60)
|
| 196 |
logger.info("")
|
| 197 |
|
| 198 |
yield
|
| 199 |
|
| 200 |
-
#
|
| 201 |
-
await stop_worker()
|
| 202 |
-
logger.info("Background job worker stopped")
|
| 203 |
|
| 204 |
# Shutdown: Upload DB to Drive
|
| 205 |
logger.info("Shutdown: Uploading database to Google Drive...")
|
|
@@ -230,10 +216,6 @@ from services.audit_service import AuditMiddleware
|
|
| 230 |
app.add_middleware(AuditMiddleware)
|
| 231 |
|
| 232 |
|
| 233 |
-
from services.gemini_service import APIKeyMiddleware
|
| 234 |
-
app.add_middleware(APIKeyMiddleware)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
from services.auth_service import AuthMiddleware
|
| 238 |
app.add_middleware(AuthMiddleware)
|
| 239 |
|
|
|
|
| 159 |
)
|
| 160 |
logger.info("✅ Audit Service configured")
|
| 161 |
|
| 162 |
+
# Job Processing Info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
logger.info("")
|
| 164 |
+
logger.info("⚡ [JOB PROCESSING]")
|
| 165 |
+
logger.info("✅ Using inline processor (fire-and-forget async)")
|
| 166 |
|
| 167 |
+
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
# Log CORS configuration
|
| 170 |
allowed_origins = os.getenv("CORS_ORIGINS").split(",")
|
|
|
|
| 177 |
logger.info("═" * 60)
|
| 178 |
logger.info(" 🚀 API Gateway Ready")
|
| 179 |
logger.info(" • Database: ✅ Ready")
|
| 180 |
+
logger.info(" • Services: 4 initialized (DB, Auth, Credit, Audit)")
|
| 181 |
+
logger.info(" • Processing: Inline (no workers)")
|
| 182 |
logger.info(" • Endpoint: http://0.0.0.0:8000")
|
| 183 |
logger.info("═" * 60)
|
| 184 |
logger.info("")
|
| 185 |
|
| 186 |
yield
|
| 187 |
|
| 188 |
+
# No worker cleanup needed - inline processor uses fire-and-forget tasks
|
|
|
|
|
|
|
| 189 |
|
| 190 |
# Shutdown: Upload DB to Drive
|
| 191 |
logger.info("Shutdown: Uploading database to Google Drive...")
|
|
|
|
| 216 |
app.add_middleware(AuditMiddleware)
|
| 217 |
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
from services.auth_service import AuthMiddleware
|
| 220 |
app.add_middleware(AuthMiddleware)
|
| 221 |
|
requirements.txt
CHANGED
|
@@ -1,19 +1,18 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
httpx>=0.25.0
|
| 9 |
-
|
| 10 |
-
passlib[bcrypt]>=1.7.4
|
| 11 |
-
email-validator>=2.0.0
|
| 12 |
-
python-dotenv>=1.0.0
|
| 13 |
-
google-api-python-client>=2.0.0
|
| 14 |
-
google-auth-oauthlib>=1.0.0
|
| 15 |
-
google-auth-httplib2>=0.1.0
|
| 16 |
-
google-genai>=1.0.0
|
| 17 |
-
PyJWT>=2.8.0
|
| 18 |
-
razorpay>=1.4.0
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.12
|
| 2 |
+
uvicorn[standard]==0.34.3
|
| 3 |
+
sqlalchemy==2.0.41
|
| 4 |
+
aiosqlite==0.21.0
|
| 5 |
+
cryptography==45.0.5
|
| 6 |
+
pydantic==2.11.7
|
| 7 |
+
httpx==0.28.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
passlib[bcrypt]==1.7.4
|
| 10 |
+
email-validator==2.2.0
|
| 11 |
+
python-dotenv==1.1.1
|
| 12 |
+
google-api-python-client==2.187.0
|
| 13 |
+
google-auth-oauthlib==1.2.1
|
| 14 |
+
google-auth-httplib2==0.2.0
|
| 15 |
+
google-genai==1.57.0
|
| 16 |
+
PyJWT==2.10.1
|
| 17 |
+
razorpay==2.0.0
|
| 18 |
+
fal-client==0.5.9
|
routers/gemini.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
|
|
|
|
| 3 |
"""
|
| 4 |
import os
|
| 5 |
import uuid
|
|
@@ -12,10 +13,10 @@ from sqlalchemy import select, func
|
|
| 12 |
|
| 13 |
from core.database import get_db
|
| 14 |
from core.models import User, GeminiJob
|
| 15 |
-
from services.
|
| 16 |
from datetime import datetime
|
| 17 |
|
| 18 |
-
router = APIRouter(prefix="/gemini", tags=["
|
| 19 |
|
| 20 |
|
| 21 |
|
|
@@ -71,18 +72,15 @@ async def create_job(
|
|
| 71 |
input_data: dict,
|
| 72 |
credits_reserved: int = 0
|
| 73 |
) -> GeminiJob:
|
| 74 |
-
"""Create
|
| 75 |
-
from services.gemini_service.job_processor import get_priority_for_job_type, get_pool
|
| 76 |
-
|
| 77 |
job_id = f"job_{uuid.uuid4().hex[:16]}"
|
| 78 |
-
priority = get_priority_for_job_type(job_type)
|
| 79 |
|
| 80 |
job = GeminiJob(
|
| 81 |
job_id=job_id,
|
| 82 |
user_id=user.id,
|
| 83 |
job_type=job_type,
|
| 84 |
status="queued",
|
| 85 |
-
priority=
|
| 86 |
input_data=input_data,
|
| 87 |
credits_reserved=credits_reserved
|
| 88 |
)
|
|
@@ -90,8 +88,22 @@ async def create_job(
|
|
| 90 |
await db.commit()
|
| 91 |
await db.refresh(job)
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
return job
|
| 97 |
|
|
@@ -322,7 +334,7 @@ async def get_job_status(
|
|
| 322 |
req: Request,
|
| 323 |
db: AsyncSession = Depends(get_db)
|
| 324 |
):
|
| 325 |
-
"""Get job status
|
| 326 |
user = req.state.user
|
| 327 |
query = select(GeminiJob).where(
|
| 328 |
GeminiJob.job_id == job_id,
|
|
@@ -337,13 +349,30 @@ async def get_job_status(
|
|
| 337 |
detail="Job not found"
|
| 338 |
)
|
| 339 |
|
| 340 |
-
|
| 341 |
if job.status == "processing" and job.job_type == "video" and job.third_party_id:
|
| 342 |
-
from services.
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
response = {
|
| 349 |
"success": True,
|
|
@@ -353,23 +382,17 @@ async def get_job_status(
|
|
| 353 |
"created_at": job.created_at.isoformat() if job.created_at else None,
|
| 354 |
"credits_remaining": user.credits
|
| 355 |
}
|
| 356 |
-
|
| 357 |
|
| 358 |
if job.job_type == "video" and job.input_data:
|
| 359 |
response["prompt"] = job.input_data.get("prompt")
|
| 360 |
|
| 361 |
-
if job.status == "queued":
|
| 362 |
-
response["position"] = await get_queue_position(db, job.job_id)
|
| 363 |
-
|
| 364 |
if job.status == "processing":
|
| 365 |
response["started_at"] = job.started_at.isoformat() if job.started_at else None
|
| 366 |
|
| 367 |
if job.status == "completed":
|
| 368 |
response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
if job.output_data and "prompt" in job.output_data:
|
| 372 |
-
response["prompt"] = job.output_data["prompt"]
|
| 373 |
|
| 374 |
if job.status == "failed":
|
| 375 |
response["error"] = job.error_message
|
|
|
|
| 1 |
"""
|
| 2 |
+
Video Router - API endpoints for AI video generation services.
|
| 3 |
+
Uses fal.ai for video generation.
|
| 4 |
"""
|
| 5 |
import os
|
| 6 |
import uuid
|
|
|
|
| 13 |
|
| 14 |
from core.database import get_db
|
| 15 |
from core.models import User, GeminiJob
|
| 16 |
+
from services.fal_service import MODELS
|
| 17 |
from datetime import datetime
|
| 18 |
|
| 19 |
+
router = APIRouter(prefix="/gemini", tags=["video"])
|
| 20 |
|
| 21 |
|
| 22 |
|
|
|
|
| 72 |
input_data: dict,
|
| 73 |
credits_reserved: int = 0
|
| 74 |
) -> GeminiJob:
|
| 75 |
+
"""Create job and start processing on fal.ai (non-blocking)."""
|
|
|
|
|
|
|
| 76 |
job_id = f"job_{uuid.uuid4().hex[:16]}"
|
|
|
|
| 77 |
|
| 78 |
job = GeminiJob(
|
| 79 |
job_id=job_id,
|
| 80 |
user_id=user.id,
|
| 81 |
job_type=job_type,
|
| 82 |
status="queued",
|
| 83 |
+
priority="slow",
|
| 84 |
input_data=input_data,
|
| 85 |
credits_reserved=credits_reserved
|
| 86 |
)
|
|
|
|
| 88 |
await db.commit()
|
| 89 |
await db.refresh(job)
|
| 90 |
|
| 91 |
+
# Start fal.ai job immediately (non-blocking)
|
| 92 |
+
if job_type == "video":
|
| 93 |
+
from services.fal_service import FalService
|
| 94 |
+
fal = FalService()
|
| 95 |
+
result = await fal.start_video_generation(
|
| 96 |
+
base64_image=input_data.get("base64_image", ""),
|
| 97 |
+
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 98 |
+
prompt=input_data.get("prompt", "Animate this image with subtle movement."),
|
| 99 |
+
aspect_ratio=input_data.get("aspect_ratio", "16:9"),
|
| 100 |
+
resolution=input_data.get("resolution", "720p"),
|
| 101 |
+
)
|
| 102 |
+
job.status = "processing"
|
| 103 |
+
job.started_at = datetime.utcnow()
|
| 104 |
+
job.third_party_id = result.get("fal_request_id")
|
| 105 |
+
job.api_response = result
|
| 106 |
+
await db.commit()
|
| 107 |
|
| 108 |
return job
|
| 109 |
|
|
|
|
| 334 |
req: Request,
|
| 335 |
db: AsyncSession = Depends(get_db)
|
| 336 |
):
|
| 337 |
+
"""Get job status. Checks fal.ai if processing (non-blocking)."""
|
| 338 |
user = req.state.user
|
| 339 |
query = select(GeminiJob).where(
|
| 340 |
GeminiJob.job_id == job_id,
|
|
|
|
| 349 |
detail="Job not found"
|
| 350 |
)
|
| 351 |
|
| 352 |
+
# Check fal.ai status if processing (non-blocking)
|
| 353 |
if job.status == "processing" and job.job_type == "video" and job.third_party_id:
|
| 354 |
+
from services.fal_service import FalService
|
| 355 |
+
fal = FalService()
|
| 356 |
+
result = await fal.check_video_status(job.third_party_id)
|
| 357 |
+
|
| 358 |
+
if result.get("done"):
|
| 359 |
+
if result.get("status") == "completed":
|
| 360 |
+
job.status = "completed"
|
| 361 |
+
job.output_data = {"video_url": result.get("video_url")}
|
| 362 |
+
job.completed_at = datetime.utcnow()
|
| 363 |
+
else:
|
| 364 |
+
job.status = "failed"
|
| 365 |
+
job.error_message = result.get("error", "Unknown error")
|
| 366 |
+
job.completed_at = datetime.utcnow()
|
| 367 |
+
|
| 368 |
+
# Handle credits on completion
|
| 369 |
+
if job.credits_reserved > 0:
|
| 370 |
+
try:
|
| 371 |
+
from services.credit_service.credit_manager import handle_job_completion
|
| 372 |
+
await handle_job_completion(db, job)
|
| 373 |
+
except Exception:
|
| 374 |
+
pass
|
| 375 |
+
await db.commit()
|
| 376 |
|
| 377 |
response = {
|
| 378 |
"success": True,
|
|
|
|
| 382 |
"created_at": job.created_at.isoformat() if job.created_at else None,
|
| 383 |
"credits_remaining": user.credits
|
| 384 |
}
|
|
|
|
| 385 |
|
| 386 |
if job.job_type == "video" and job.input_data:
|
| 387 |
response["prompt"] = job.input_data.get("prompt")
|
| 388 |
|
|
|
|
|
|
|
|
|
|
| 389 |
if job.status == "processing":
|
| 390 |
response["started_at"] = job.started_at.isoformat() if job.started_at else None
|
| 391 |
|
| 392 |
if job.status == "completed":
|
| 393 |
response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
|
| 394 |
+
if job.output_data and job.output_data.get("video_url"):
|
| 395 |
+
response["video_url"] = job.output_data.get("video_url")
|
|
|
|
|
|
|
| 396 |
|
| 397 |
if job.status == "failed":
|
| 398 |
response["error"] = job.error_message
|
services/fal_service/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Fal.ai Service for video generation.
Provides access to Veo 3.1 and other models through fal.ai's unified API.
"""
# Public surface of the package: re-export the client class plus the
# module-level configuration callers need (model table, mock flag, key lookup).
from services.fal_service.api_client import (
    FalService,
    MODELS,
    MOCK_MODE,
    get_fal_api_key,
)

__all__ = ["FalService", "MODELS", "MOCK_MODE", "get_fal_api_key"]
services/fal_service/api_client.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Fal.ai Service for video generation.
Python implementation using fal-client SDK.
Uses server-side API key from environment.
"""
import asyncio
import logging
import os
from typing import Optional, Literal

logger = logging.getLogger(__name__)

# Model identifiers live in one table so swapping models is a one-line change.
MODELS = {
    "video_generation": "fal-ai/veo3.1/fast/image-to-video"
}

# Narrow literal types for the values fal.ai accepts.
AspectRatio = Literal["16:9", "9:16", "auto"]
Resolution = Literal["720p", "1080p"]

# Local-testing switch: FAL_MOCK_MODE=true short-circuits real API calls
# so no fal.ai credits are consumed.
MOCK_MODE = os.getenv("FAL_MOCK_MODE", "false").lower() == "true"

# Canned video URL handed back while mocking.
MOCK_VIDEO_URL = "https://v3b.fal.media/files/mock/mock-video.mp4"
def get_fal_api_key() -> str:
    """Return the Fal.ai API key read from the FAL_KEY environment variable.

    Raises:
        ValueError: if FAL_KEY is unset (or empty).
    """
    key = os.getenv("FAL_KEY")
    if key:
        return key
    raise ValueError("Server Authentication Error: FAL_KEY not configured")
class FalService:
    """
    Fal.ai Service for video generation.

    Uses the server-side API key from the environment (FAL_KEY). The
    fal_client SDK is synchronous, so every SDK call is pushed onto a
    worker thread via asyncio.to_thread to keep the event loop free.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the Fal.ai client with API key from env or provided.

        NOTE: fal_client reads its key from the FAL_KEY env var, so an
        explicitly supplied key is written back to os.environ — a
        process-wide side effect shared by all FalService instances.
        """
        self.api_key = api_key or get_fal_api_key()
        if api_key:
            os.environ["FAL_KEY"] = api_key

    def _handle_api_error(self, error: Exception, context: str):
        """Translate common HTTP failures into descriptive ValueErrors.

        Re-raises the original exception unchanged when it matches none of
        the known status codes. This method never returns normally.
        """
        msg = str(error)
        if "401" in msg or "Unauthorized" in msg:
            raise ValueError(
                f"Authentication failed ({context}). Check your FAL_KEY is valid."
            )
        if "402" in msg or "Payment Required" in msg:
            raise ValueError(
                f"Insufficient credits ({context}). Add credits at fal.ai."
            )
        if "429" in msg or "Rate limit" in msg.lower():
            raise ValueError(
                f"Rate limit exceeded ({context}). Wait and retry."
            )
        raise error

    async def start_video_generation(
        self,
        base64_image: str,
        mime_type: str,
        prompt: str,
        aspect_ratio: AspectRatio = "16:9",
        resolution: Resolution = "720p",
        number_of_videos: int = 1
    ) -> dict:
        """
        Start video generation using Fal.ai Veo 3.1 model.

        Uses fal_client.submit(), which enqueues the job and returns
        immediately WITHOUT waiting for the video — poll with
        check_video_status() using the returned request id.

        Returns dict with:
            - fal_request_id: Request ID for status polling
            - done: False (True only in mock mode, which completes instantly)
            - status: "processing" ("completed" in mock mode)
            - video_url: present in mock mode only

        Raises:
            ValueError: when fal-client is missing, or on known API errors
                (auth / credits / rate limit via _handle_api_error).
        """
        # Mock mode for testing without API credits
        if MOCK_MODE:
            import uuid
            mock_request_id = f"mock_fal_{uuid.uuid4().hex[:16]}"
            logger.info(f"[MOCK MODE] Video generation: {mock_request_id}")
            await asyncio.sleep(2)  # Simulate API delay
            return {
                "fal_request_id": mock_request_id,
                "done": True,
                "status": "completed",
                "video_url": MOCK_VIDEO_URL
            }

        try:
            import fal_client

            # submit() (not subscribe()) — starts the job and returns a
            # handle immediately; no blocking wait for the video.
            handle = await asyncio.to_thread(
                fal_client.submit,
                MODELS["video_generation"],
                arguments={
                    "prompt": prompt,
                    # fal.ai accepts data URIs, so the upload step is skipped.
                    "image_url": f"data:{mime_type};base64,{base64_image}",
                    "aspect_ratio": aspect_ratio,
                    "resolution": resolution,
                    "generate_audio": True,
                },
            )

            # Newer fal_client handles expose .request_id; fall back to the
            # string form for older SDK versions.
            request_id = handle.request_id if hasattr(handle, 'request_id') else str(handle)

            return {
                "fal_request_id": request_id,
                "done": False,
                "status": "processing",
            }

        except ImportError:
            raise ValueError(
                "fal-client package not installed. Run: pip install fal-client"
            )
        except Exception as error:
            self._handle_api_error(error, MODELS["video_generation"])

    async def check_video_status(self, fal_request_id: str) -> dict:
        """
        Check the status of a video generation request.
        Returns immediately with current status (does not wait).

        Returns dict with fal_request_id, done, status, and (when done)
        either video_url or error. Unexpected exceptions are swallowed and
        reported as still-processing so a transient failure never flips a
        job to failed.
        """
        # Mock mode for testing: randomly "complete" ~30% of checks.
        if MOCK_MODE:
            import random
            if random.random() > 0.7:
                return {
                    "fal_request_id": fal_request_id,
                    "done": True,
                    "status": "completed",
                    "video_url": MOCK_VIDEO_URL
                }
            return {
                "fal_request_id": fal_request_id,
                "done": False,
                "status": "processing"
            }

        try:
            import fal_client

            # Get status without waiting
            status = await asyncio.to_thread(
                fal_client.status,
                MODELS["video_generation"],
                fal_request_id,
                with_logs=False
            )

            # BUGFIX: fal_client.status() returns status *objects*
            # (fal_client.Queued / InProgress / Completed dataclasses),
            # which have no `.status` string attribute. The old
            # `hasattr(status, 'status')` check was therefore always False
            # and every job stayed "processing" forever. Normalize both
            # shapes: a `.status` attribute if present, else the class name.
            status_name = getattr(status, "status", None) or type(status).__name__
            status_name = str(status_name).upper()

            if status_name == "COMPLETED":
                # Job finished — fetch the result payload.
                result = await asyncio.to_thread(
                    fal_client.result,
                    MODELS["video_generation"],
                    fal_request_id
                )

                # Extract video URL from either dict or object result shape.
                video_url = None
                if isinstance(result, dict) and "video" in result:
                    video_url = result["video"].get("url")
                elif hasattr(result, "video") and hasattr(result.video, "url"):
                    video_url = result.video.url

                return {
                    "fal_request_id": fal_request_id,
                    "done": True,
                    "status": "completed",
                    "video_url": video_url
                }

            if status_name == "FAILED":
                return {
                    "fal_request_id": fal_request_id,
                    "done": True,
                    "status": "failed",
                    "error": getattr(status, 'error', 'Unknown error')
                }

            # QUEUED / IN_QUEUE / INPROGRESS / IN_PROGRESS — still running.
            return {
                "fal_request_id": fal_request_id,
                "done": False,
                "status": "processing"
            }

        except ImportError:
            raise ValueError(
                "fal-client package not installed. Run: pip install fal-client"
            )
        except Exception as error:
            # Treat transient status-check failures as "still processing"
            # rather than failing the job; surface the message for logging.
            logger.error(f"Error checking status for {fal_request_id}: {error}")
            return {
                "fal_request_id": fal_request_id,
                "done": False,
                "status": "processing",
                "error": str(error)
            }

    async def download_video(self, video_url: str, request_id: str) -> str:
        """
        Download video from fal.ai to local storage.
        Returns the local filename (not the full path).

        Raises:
            ValueError: if the HTTP download fails for any reason.
        """
        import httpx

        # Use same downloads directory as Gemini service
        downloads_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "downloads"
        )
        os.makedirs(downloads_dir, exist_ok=True)

        filename = f"{request_id}.mp4"
        filepath = os.path.join(downloads_dir, filename)

        try:
            # follow_redirects: fal.ai media URLs may redirect to CDN storage.
            async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client:
                response = await client.get(video_url)
                response.raise_for_status()

                with open(filepath, 'wb') as f:
                    f.write(response.content)

            logger.info(f"Downloaded video to {filepath}")
            return filename
        except Exception as e:
            logger.error(f"Failed to download video: {e}")
            raise ValueError(f"Failed to download video: {e}")
services/gemini_service/__init__.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Gemini Service - AI-powered image and video generation
|
| 3 |
-
|
| 4 |
-
Provides:
|
| 5 |
-
- Text generation
|
| 6 |
-
- Image editing
|
| 7 |
-
- Video generation
|
| 8 |
-
- Job processing and background workers
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
# API Client exports
|
| 12 |
-
from services.gemini_service.api_client import (
|
| 13 |
-
GeminiService,
|
| 14 |
-
MODELS,
|
| 15 |
-
DOWNLOADS_DIR,
|
| 16 |
-
get_gemini_api_key,
|
| 17 |
-
MOCK_MODE,
|
| 18 |
-
MOCK_VIDEO_URL,
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
# Job Processor exports
|
| 22 |
-
from services.gemini_service.job_processor import (
|
| 23 |
-
GeminiJobProcessor,
|
| 24 |
-
PriorityWorkerPool,
|
| 25 |
-
get_pool,
|
| 26 |
-
get_priority_for_job_type,
|
| 27 |
-
start_worker,
|
| 28 |
-
stop_worker,
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
# API Key Middleware exports # Added
|
| 32 |
-
from services.gemini_service.api_key_config import APIKeyServiceConfig # Added
|
| 33 |
-
from services.gemini_service.api_key_middleware import APIKeyMiddleware # Added
|
| 34 |
-
|
| 35 |
-
__all__ = [
|
| 36 |
-
# API Client
|
| 37 |
-
'GeminiService',
|
| 38 |
-
'MODELS',
|
| 39 |
-
'DOWNLOADS_DIR',
|
| 40 |
-
'get_gemini_api_key',
|
| 41 |
-
'MOCK_MODE',
|
| 42 |
-
'MOCK_VIDEO_URL',
|
| 43 |
-
|
| 44 |
-
# Job Processor
|
| 45 |
-
'GeminiJobProcessor',
|
| 46 |
-
'PriorityWorkerPool',
|
| 47 |
-
'get_pool',
|
| 48 |
-
'get_priority_for_job_type',
|
| 49 |
-
'start_worker',
|
| 50 |
-
'stop_worker',
|
| 51 |
-
|
| 52 |
-
# API Key Middleware
|
| 53 |
-
'APIKeyServiceConfig',
|
| 54 |
-
'APIKeyMiddleware',
|
| 55 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/gemini_service/api_client.py
DELETED
|
@@ -1,401 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Gemini AI Service for image and video generation.
|
| 3 |
-
Python port of the TypeScript geminiService.ts
|
| 4 |
-
Uses server-side API key from environment.
|
| 5 |
-
"""
|
| 6 |
-
import asyncio
|
| 7 |
-
import logging
|
| 8 |
-
import os
|
| 9 |
-
import uuid
|
| 10 |
-
import httpx
|
| 11 |
-
from typing import Optional, Literal
|
| 12 |
-
from google import genai
|
| 13 |
-
from google.genai import types
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
# Model names - easily configurable
|
| 18 |
-
MODELS = {
|
| 19 |
-
"text_generation": "gemini-2.5-flash",
|
| 20 |
-
"image_edit": "gemini-2.5-flash-image",
|
| 21 |
-
"video_generation": "veo-3.1-generate-preview"
|
| 22 |
-
}
|
| 23 |
-
|
| 24 |
-
# Type aliases
|
| 25 |
-
AspectRatio = Literal["16:9", "9:16"]
|
| 26 |
-
Resolution = Literal["720p", "1080p"]
|
| 27 |
-
|
| 28 |
-
# Video downloads directory
|
| 29 |
-
DOWNLOADS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "downloads")
|
| 30 |
-
|
| 31 |
-
# Ensure downloads directory exists
|
| 32 |
-
os.makedirs(DOWNLOADS_DIR, exist_ok=True)
|
| 33 |
-
|
| 34 |
-
# Mock mode for local testing (set GEMINI_MOCK_MODE=true to skip real API calls)
|
| 35 |
-
MOCK_MODE = os.getenv("GEMINI_MOCK_MODE", "false").lower() == "true"
|
| 36 |
-
MOCK_MODE_SLEEP_TIME = os.getenv("GEMINI_MOCK_MODE_SLEEP_TIME", "0.5")
|
| 37 |
-
|
| 38 |
-
# Sample video URL for mock mode (a public test video)
|
| 39 |
-
MOCK_VIDEO_URL = "https://video.twimg.com/amplify_video/1994083297756848128/vid/avc1/576x576/ue31qU0xts8L9tXD.mp4?tag=21"
|
| 40 |
-
|
| 41 |
-
# Concurrency limits from environment (defaults)
|
| 42 |
-
MAX_CONCURRENT_VIDEOS = int(os.getenv("MAX_CONCURRENT_VIDEOS", "2"))
|
| 43 |
-
MAX_CONCURRENT_IMAGES = int(os.getenv("MAX_CONCURRENT_IMAGES", "5"))
|
| 44 |
-
MAX_CONCURRENT_TEXT = int(os.getenv("MAX_CONCURRENT_TEXT", "10"))
|
| 45 |
-
|
| 46 |
-
# Semaphores for concurrency control
|
| 47 |
-
_video_semaphore: Optional[asyncio.Semaphore] = None
|
| 48 |
-
_image_semaphore: Optional[asyncio.Semaphore] = None
|
| 49 |
-
_text_semaphore: Optional[asyncio.Semaphore] = None
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def get_video_semaphore() -> asyncio.Semaphore:
|
| 53 |
-
"""Get or create video semaphore."""
|
| 54 |
-
global _video_semaphore
|
| 55 |
-
if _video_semaphore is None:
|
| 56 |
-
_video_semaphore = asyncio.Semaphore(MAX_CONCURRENT_VIDEOS)
|
| 57 |
-
logger.info(f"Video semaphore initialized with limit: {MAX_CONCURRENT_VIDEOS}")
|
| 58 |
-
return _video_semaphore
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def get_image_semaphore() -> asyncio.Semaphore:
|
| 62 |
-
"""Get or create image semaphore."""
|
| 63 |
-
global _image_semaphore
|
| 64 |
-
if _image_semaphore is None:
|
| 65 |
-
_image_semaphore = asyncio.Semaphore(MAX_CONCURRENT_IMAGES)
|
| 66 |
-
logger.info(f"Image semaphore initialized with limit: {MAX_CONCURRENT_IMAGES}")
|
| 67 |
-
return _image_semaphore
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def get_text_semaphore() -> asyncio.Semaphore:
|
| 71 |
-
"""Get or create text semaphore."""
|
| 72 |
-
global _text_semaphore
|
| 73 |
-
if _text_semaphore is None:
|
| 74 |
-
_text_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TEXT)
|
| 75 |
-
logger.info(f"Text semaphore initialized with limit: {MAX_CONCURRENT_TEXT}")
|
| 76 |
-
return _text_semaphore
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def get_gemini_api_key() -> str:
|
| 80 |
-
"""Get Gemini API key from environment."""
|
| 81 |
-
api_key = os.getenv("GEMINI_API_KEY")
|
| 82 |
-
if not api_key:
|
| 83 |
-
raise ValueError("Server Authentication Error with GEMINI")
|
| 84 |
-
return api_key
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
class GeminiService:
|
| 88 |
-
"""
|
| 89 |
-
Gemini AI Service for text, image, and video generation.
|
| 90 |
-
Uses server-side API key from environment.
|
| 91 |
-
"""
|
| 92 |
-
|
| 93 |
-
def __init__(self, api_key: Optional[str] = None):
|
| 94 |
-
"""Initialize the Gemini client with API key from env or provided."""
|
| 95 |
-
self.api_key = api_key or get_gemini_api_key()
|
| 96 |
-
self.client = genai.Client(api_key=self.api_key)
|
| 97 |
-
|
| 98 |
-
def _handle_api_error(self, error: Exception, context: str):
|
| 99 |
-
"""Handle API errors with descriptive messages."""
|
| 100 |
-
msg = str(error)
|
| 101 |
-
if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg or "[5," in msg:
|
| 102 |
-
raise ValueError(
|
| 103 |
-
f"Model not found ({context}). Ensure your API key project has access to this model. "
|
| 104 |
-
"Veo requires a paid account."
|
| 105 |
-
)
|
| 106 |
-
raise error
|
| 107 |
-
|
| 108 |
-
async def generate_animation_prompt(
|
| 109 |
-
self,
|
| 110 |
-
base64_image: str,
|
| 111 |
-
mime_type: str,
|
| 112 |
-
custom_prompt: Optional[str] = None
|
| 113 |
-
) -> str:
|
| 114 |
-
"""
|
| 115 |
-
Analyzes the image to generate a suitable animation prompt.
|
| 116 |
-
"""
|
| 117 |
-
# Mock mode for testing
|
| 118 |
-
if MOCK_MODE:
|
| 119 |
-
logger.info("[MOCK MODE] Generating animation prompt")
|
| 120 |
-
await asyncio.sleep(GEMINI_MOCK_MODE_SLEEP_TIME) # Simulate API delay
|
| 121 |
-
return "A gentle breeze rustles through the scene as soft light dances across the surface. The camera slowly zooms in with a subtle parallax effect, creating depth and movement."
|
| 122 |
-
|
| 123 |
-
default_prompt = custom_prompt or "Describe how this image could be subtly animated with cinematic movement."
|
| 124 |
-
async with get_text_semaphore():
|
| 125 |
-
try:
|
| 126 |
-
response = await asyncio.to_thread(
|
| 127 |
-
self.client.models.generate_content,
|
| 128 |
-
model=MODELS["text_generation"],
|
| 129 |
-
contents=types.Content(
|
| 130 |
-
parts=[
|
| 131 |
-
types.Part.from_bytes(
|
| 132 |
-
data=base64_image,
|
| 133 |
-
mime_type=mime_type
|
| 134 |
-
),
|
| 135 |
-
types.Part.from_text(text=default_prompt)
|
| 136 |
-
]
|
| 137 |
-
)
|
| 138 |
-
)
|
| 139 |
-
return response.text or "Cinematic subtle movement"
|
| 140 |
-
except Exception as error:
|
| 141 |
-
self._handle_api_error(error, MODELS["text_generation"])
|
| 142 |
-
|
| 143 |
-
async def edit_image(
|
| 144 |
-
self,
|
| 145 |
-
base64_image: str,
|
| 146 |
-
mime_type: str,
|
| 147 |
-
prompt: str
|
| 148 |
-
) -> str:
|
| 149 |
-
"""
|
| 150 |
-
Edit an image using Gemini image model.
|
| 151 |
-
Returns base64 data URI of the edited image.
|
| 152 |
-
"""
|
| 153 |
-
# Mock mode for testing - return a sample image
|
| 154 |
-
if MOCK_MODE:
|
| 155 |
-
logger.info(f"[MOCK MODE] Editing image with prompt: {prompt}")
|
| 156 |
-
await asyncio.sleep(1) # Simulate API delay
|
| 157 |
-
# Return a small red placeholder image (1x1 pixel)
|
| 158 |
-
return "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
|
| 159 |
-
|
| 160 |
-
async with get_image_semaphore():
|
| 161 |
-
try:
|
| 162 |
-
response = await asyncio.to_thread(
|
| 163 |
-
self.client.models.generate_content,
|
| 164 |
-
model=MODELS["image_edit"],
|
| 165 |
-
contents=types.Content(
|
| 166 |
-
parts=[
|
| 167 |
-
types.Part.from_bytes(
|
| 168 |
-
data=base64_image,
|
| 169 |
-
mime_type=mime_type
|
| 170 |
-
),
|
| 171 |
-
types.Part.from_text(text=prompt or "Enhance this image")
|
| 172 |
-
]
|
| 173 |
-
)
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
candidates = response.candidates
|
| 177 |
-
if not candidates:
|
| 178 |
-
raise ValueError("No candidates returned from Gemini.")
|
| 179 |
-
|
| 180 |
-
for part in candidates[0].content.parts:
|
| 181 |
-
if hasattr(part, 'inline_data') and part.inline_data and part.inline_data.data:
|
| 182 |
-
result_mime = part.inline_data.mime_type or 'image/png'
|
| 183 |
-
return f"data:{result_mime};base64,{part.inline_data.data}"
|
| 184 |
-
|
| 185 |
-
raise ValueError("No image data found in the response.")
|
| 186 |
-
except Exception as error:
|
| 187 |
-
self._handle_api_error(error, MODELS["image_edit"])
|
| 188 |
-
|
| 189 |
-
async def start_video_generation(
|
| 190 |
-
self,
|
| 191 |
-
base64_image: str,
|
| 192 |
-
mime_type: str,
|
| 193 |
-
prompt: str,
|
| 194 |
-
aspect_ratio: AspectRatio = "16:9",
|
| 195 |
-
resolution: Resolution = "720p",
|
| 196 |
-
number_of_videos: int = 1
|
| 197 |
-
) -> dict:
|
| 198 |
-
"""
|
| 199 |
-
Start video generation using Veo model.
|
| 200 |
-
Returns operation details for polling.
|
| 201 |
-
"""
|
| 202 |
-
# Mock mode for testing without API credits
|
| 203 |
-
if MOCK_MODE:
|
| 204 |
-
import uuid
|
| 205 |
-
mock_operation_name = f"mock_operation_{uuid.uuid4().hex[:16]}"
|
| 206 |
-
logger.info(f"[MOCK MODE] Starting video generation: {mock_operation_name}")
|
| 207 |
-
return {
|
| 208 |
-
"gemini_operation_name": mock_operation_name,
|
| 209 |
-
"done": False,
|
| 210 |
-
"status": "pending"
|
| 211 |
-
}
|
| 212 |
-
|
| 213 |
-
async with get_video_semaphore():
|
| 214 |
-
try:
|
| 215 |
-
# Start video generation
|
| 216 |
-
operation = await asyncio.to_thread(
|
| 217 |
-
self.client.models.generate_videos,
|
| 218 |
-
model=MODELS["video_generation"],
|
| 219 |
-
prompt=prompt,
|
| 220 |
-
image=types.Image(
|
| 221 |
-
image_bytes=base64_image,
|
| 222 |
-
mime_type=mime_type
|
| 223 |
-
),
|
| 224 |
-
config=types.GenerateVideosConfig(
|
| 225 |
-
number_of_videos=number_of_videos,
|
| 226 |
-
resolution=resolution,
|
| 227 |
-
aspect_ratio=aspect_ratio
|
| 228 |
-
)
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
# Return operation details
|
| 232 |
-
return {
|
| 233 |
-
"gemini_operation_name": operation.name,
|
| 234 |
-
"done": operation.done,
|
| 235 |
-
"status": "completed" if operation.done else "pending"
|
| 236 |
-
}
|
| 237 |
-
except Exception as error:
|
| 238 |
-
self._handle_api_error(error, MODELS["video_generation"])
|
| 239 |
-
|
| 240 |
-
async def check_video_status(self, gemini_operation_name: str) -> dict:
|
| 241 |
-
"""
|
| 242 |
-
Check the status of a video generation operation.
|
| 243 |
-
Returns status and video URL if complete.
|
| 244 |
-
"""
|
| 245 |
-
# Mock mode for testing without API credits
|
| 246 |
-
if MOCK_MODE:
|
| 247 |
-
# Simulate processing time: complete after 2 checks (track via a simple mechanism)
|
| 248 |
-
# For simplicity, always return completed with mock video URL
|
| 249 |
-
logger.info(f"[MOCK MODE] Checking video status: {gemini_operation_name}")
|
| 250 |
-
await asyncio.sleep(2) # Simulate API delay
|
| 251 |
-
return {
|
| 252 |
-
"gemini_operation_name": gemini_operation_name,
|
| 253 |
-
"done": True,
|
| 254 |
-
"status": "completed",
|
| 255 |
-
"video_url": MOCK_VIDEO_URL
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
try:
|
| 259 |
-
# Get operation status using the operation object
|
| 260 |
-
# First, we need to recreate the operation from the name
|
| 261 |
-
from google.genai.types import GenerateVideosOperation
|
| 262 |
-
|
| 263 |
-
operation = await asyncio.to_thread(
|
| 264 |
-
self.client.operations.get,
|
| 265 |
-
GenerateVideosOperation(name=gemini_operation_name, done=False)
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
if not operation.done:
|
| 269 |
-
return {
|
| 270 |
-
"gemini_operation_name": gemini_operation_name,
|
| 271 |
-
"done": False,
|
| 272 |
-
"status": "pending"
|
| 273 |
-
}
|
| 274 |
-
|
| 275 |
-
# Check for error - handle both string and object types
|
| 276 |
-
if operation.error:
|
| 277 |
-
error_msg = operation.error
|
| 278 |
-
if hasattr(operation.error, 'message'):
|
| 279 |
-
error_msg = operation.error.message
|
| 280 |
-
return {
|
| 281 |
-
"gemini_operation_name": gemini_operation_name,
|
| 282 |
-
"done": True,
|
| 283 |
-
"status": "failed",
|
| 284 |
-
"error": str(error_msg) or "Unknown error"
|
| 285 |
-
}
|
| 286 |
-
|
| 287 |
-
# Extract video URI from result
|
| 288 |
-
result = operation.result
|
| 289 |
-
if result and hasattr(result, 'generated_videos') and result.generated_videos:
|
| 290 |
-
video = result.generated_videos[0]
|
| 291 |
-
if hasattr(video, 'video') and video.video and hasattr(video.video, 'uri'):
|
| 292 |
-
video_uri = video.video.uri
|
| 293 |
-
return {
|
| 294 |
-
"gemini_operation_name": gemini_operation_name,
|
| 295 |
-
"done": True,
|
| 296 |
-
"status": "completed",
|
| 297 |
-
"video_url": f"{video_uri}&key={self.api_key}"
|
| 298 |
-
}
|
| 299 |
-
|
| 300 |
-
return {
|
| 301 |
-
"gemini_operation_name": gemini_operation_name,
|
| 302 |
-
"done": True,
|
| 303 |
-
"status": "failed",
|
| 304 |
-
"error": "No video URI returned. May be due to safety filters."
|
| 305 |
-
}
|
| 306 |
-
|
| 307 |
-
except Exception as error:
|
| 308 |
-
msg = str(error)
|
| 309 |
-
if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg:
|
| 310 |
-
return {
|
| 311 |
-
"gemini_operation_name": gemini_operation_name,
|
| 312 |
-
"done": True,
|
| 313 |
-
"status": "failed",
|
| 314 |
-
"error": "Operation not found (404). It may have expired."
|
| 315 |
-
}
|
| 316 |
-
raise error
|
| 317 |
-
|
| 318 |
-
async def download_video(self, video_url: str, operation_id: str) -> str:
|
| 319 |
-
"""
|
| 320 |
-
Download video from Gemini to local storage.
|
| 321 |
-
Returns the local filename.
|
| 322 |
-
"""
|
| 323 |
-
filename = f"{operation_id}.mp4"
|
| 324 |
-
filepath = os.path.join(DOWNLOADS_DIR, filename)
|
| 325 |
-
|
| 326 |
-
try:
|
| 327 |
-
# follow_redirects=True is required as Gemini returns 302 redirects
|
| 328 |
-
async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client:
|
| 329 |
-
response = await client.get(video_url)
|
| 330 |
-
response.raise_for_status()
|
| 331 |
-
|
| 332 |
-
with open(filepath, 'wb') as f:
|
| 333 |
-
f.write(response.content)
|
| 334 |
-
|
| 335 |
-
logger.info(f"Downloaded video to {filepath}")
|
| 336 |
-
return filename
|
| 337 |
-
except Exception as e:
|
| 338 |
-
logger.error(f"Failed to download video: {e}")
|
| 339 |
-
raise ValueError(f"Failed to download video: {e}")
|
| 340 |
-
|
| 341 |
-
async def generate_text(
|
| 342 |
-
self,
|
| 343 |
-
prompt: str,
|
| 344 |
-
model: Optional[str] = None
|
| 345 |
-
) -> str:
|
| 346 |
-
"""
|
| 347 |
-
Simple text generation with Gemini.
|
| 348 |
-
"""
|
| 349 |
-
# Mock mode for testing
|
| 350 |
-
if MOCK_MODE:
|
| 351 |
-
logger.info(f"[MOCK MODE] Generating text for prompt: {prompt[:50]}...")
|
| 352 |
-
await asyncio.sleep(MOCK_MODE_SLEEP_TIME) # Simulate API delay
|
| 353 |
-
return f"This is a mock response for your prompt: '{prompt[:100]}...'. In production, this would be generated by Gemini AI."
|
| 354 |
-
|
| 355 |
-
model_name = model or MODELS["text_generation"]
|
| 356 |
-
async with get_text_semaphore():
|
| 357 |
-
try:
|
| 358 |
-
response = await asyncio.to_thread(
|
| 359 |
-
self.client.models.generate_content,
|
| 360 |
-
model=model_name,
|
| 361 |
-
contents=types.Content(
|
| 362 |
-
parts=[types.Part.from_text(text=prompt)]
|
| 363 |
-
)
|
| 364 |
-
)
|
| 365 |
-
return response.text or ""
|
| 366 |
-
except Exception as error:
|
| 367 |
-
self._handle_api_error(error, model_name)
|
| 368 |
-
|
| 369 |
-
async def analyze_image(
|
| 370 |
-
self,
|
| 371 |
-
base64_image: str,
|
| 372 |
-
mime_type: str,
|
| 373 |
-
prompt: str
|
| 374 |
-
) -> str:
|
| 375 |
-
"""
|
| 376 |
-
Analyze image with custom prompt.
|
| 377 |
-
"""
|
| 378 |
-
# Mock mode for testing
|
| 379 |
-
if MOCK_MODE:
|
| 380 |
-
logger.info(f"[MOCK MODE] Analyzing image with prompt: {prompt[:50]}...")
|
| 381 |
-
await asyncio.sleep(MOCK_MODE_SLEEP_TIME) # Simulate API delay
|
| 382 |
-
return f"Mock analysis result: The image appears to show a scene that matches your query '{prompt[:50]}...'. This is placeholder content for testing."
|
| 383 |
-
|
| 384 |
-
async with get_text_semaphore():
|
| 385 |
-
try:
|
| 386 |
-
response = await asyncio.to_thread(
|
| 387 |
-
self.client.models.generate_content,
|
| 388 |
-
model=MODELS["text_generation"],
|
| 389 |
-
contents=types.Content(
|
| 390 |
-
parts=[
|
| 391 |
-
types.Part.from_bytes(
|
| 392 |
-
data=base64_image,
|
| 393 |
-
mime_type=mime_type
|
| 394 |
-
),
|
| 395 |
-
types.Part.from_text(text=prompt)
|
| 396 |
-
]
|
| 397 |
-
)
|
| 398 |
-
)
|
| 399 |
-
return response.text or ""
|
| 400 |
-
except Exception as error:
|
| 401 |
-
self._handle_api_error(error, MODELS["text_generation"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/gemini_service/api_key_config.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
API Key Service Configuration
|
| 3 |
-
|
| 4 |
-
Configures automatic API key selection and rotation via middleware.
|
| 5 |
-
"""
|
| 6 |
-
from typing import List, Optional
|
| 7 |
-
import os
|
| 8 |
-
import logging
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class APIKeyServiceConfig:
|
| 14 |
-
"""Configuration for API key middleware."""
|
| 15 |
-
|
| 16 |
-
_rotation_strategy: str = "least_used" # or "round_robin"
|
| 17 |
-
_cooldown_seconds: int = 60
|
| 18 |
-
_max_requests_per_minute: int = 60
|
| 19 |
-
_retry_on_quota_error: bool = True
|
| 20 |
-
_api_keys: Optional[List[str]] = None
|
| 21 |
-
|
| 22 |
-
@classmethod
|
| 23 |
-
def register(
|
| 24 |
-
cls,
|
| 25 |
-
rotation_strategy: str = "least_used",
|
| 26 |
-
cooldown_seconds: int = 60,
|
| 27 |
-
max_requests_per_minute: int = 60,
|
| 28 |
-
retry_on_quota_error: bool = True
|
| 29 |
-
) -> None:
|
| 30 |
-
"""
|
| 31 |
-
Register API key service configuration.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
rotation_strategy: "least_used" or "round_robin"
|
| 35 |
-
cooldown_seconds: Time to wait before reusing a key after quota error
|
| 36 |
-
max_requests_per_minute: Rate limit per key
|
| 37 |
-
retry_on_quota_error: Auto-retry with different key on 429
|
| 38 |
-
|
| 39 |
-
Example:
|
| 40 |
-
APIKeyServiceConfig.register(
|
| 41 |
-
rotation_strategy="least_used",
|
| 42 |
-
cooldown_seconds=60,
|
| 43 |
-
retry_on_quota_error=True
|
| 44 |
-
)
|
| 45 |
-
"""
|
| 46 |
-
cls._rotation_strategy = rotation_strategy
|
| 47 |
-
cls._cooldown_seconds = cooldown_seconds
|
| 48 |
-
cls._max_requests_per_minute = max_requests_per_minute
|
| 49 |
-
cls._retry_on_quota_error = retry_on_quota_error
|
| 50 |
-
|
| 51 |
-
# Load API keys from env
|
| 52 |
-
cls._load_api_keys()
|
| 53 |
-
|
| 54 |
-
logger.info(
|
| 55 |
-
f"API Key Service configured: "
|
| 56 |
-
f"keys={len(cls._api_keys or [])}, "
|
| 57 |
-
f"strategy={rotation_strategy}, "
|
| 58 |
-
f"retry={retry_on_quota_error}"
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
@classmethod
|
| 62 |
-
def _load_api_keys(cls):
|
| 63 |
-
"""Load API keys from environment variables."""
|
| 64 |
-
keys_str = os.getenv("GEMINI_API_KEYS", "")
|
| 65 |
-
if not keys_str:
|
| 66 |
-
# Fallback to single key
|
| 67 |
-
single_key = os.getenv("GEMINI_API_KEY", "")
|
| 68 |
-
if single_key:
|
| 69 |
-
cls._api_keys = [single_key]
|
| 70 |
-
else:
|
| 71 |
-
cls._api_keys = []
|
| 72 |
-
logger.warning("No Gemini API keys configured!")
|
| 73 |
-
else:
|
| 74 |
-
cls._api_keys = [k.strip() for k in keys_str.split(",") if k.strip()]
|
| 75 |
-
|
| 76 |
-
if cls._api_keys:
|
| 77 |
-
logger.info(f"Loaded {len(cls._api_keys)} Gemini API key(s)")
|
| 78 |
-
|
| 79 |
-
@classmethod
|
| 80 |
-
def get_api_keys(cls) -> List[str]:
|
| 81 |
-
"""Get loaded API keys."""
|
| 82 |
-
if cls._api_keys is None:
|
| 83 |
-
cls._load_api_keys()
|
| 84 |
-
return cls._api_keys or []
|
| 85 |
-
|
| 86 |
-
@classmethod
|
| 87 |
-
def get_key_count(cls) -> int:
|
| 88 |
-
"""Get number of available keys."""
|
| 89 |
-
return len(cls.get_api_keys())
|
| 90 |
-
|
| 91 |
-
@classmethod
|
| 92 |
-
def get_config(cls) -> dict:
|
| 93 |
-
"""Get current configuration."""
|
| 94 |
-
return {
|
| 95 |
-
"key_count": cls.get_key_count(),
|
| 96 |
-
"rotation_strategy": cls._rotation_strategy,
|
| 97 |
-
"cooldown_seconds": cls._cooldown_seconds,
|
| 98 |
-
"max_requests_per_minute": cls._max_requests_per_minute,
|
| 99 |
-
"retry_on_quota_error": cls._retry_on_quota_error
|
| 100 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/gemini_service/api_key_middleware.py
DELETED
|
@@ -1,180 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
API Key Middleware - Automatic key selection and rotation
|
| 3 |
-
|
| 4 |
-
Automatically selects and injects Gemini API keys for requests.
|
| 5 |
-
Handles quota errors with automatic key rotation and retry.
|
| 6 |
-
"""
|
| 7 |
-
import time
|
| 8 |
-
import logging
|
| 9 |
-
from datetime import datetime, timedelta
|
| 10 |
-
from typing import Optional, Dict
|
| 11 |
-
from fastapi import Request, Response
|
| 12 |
-
from starlette.middleware.base import BaseHTTPMiddleware
|
| 13 |
-
from starlette.types import ASGIApp
|
| 14 |
-
|
| 15 |
-
from core.database import async_session_maker
|
| 16 |
-
from services.gemini_service.api_key_config import APIKeyServiceConfig
|
| 17 |
-
|
| 18 |
-
logger = logging.getLogger(__name__)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
# Track key cooldowns in memory
|
| 22 |
-
_key_cooldowns: Dict[int, datetime] = {}
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class APIKeyMiddleware(BaseHTTPMiddleware):
|
| 26 |
-
"""
|
| 27 |
-
Middleware for automatic API key management.
|
| 28 |
-
|
| 29 |
-
Features:
|
| 30 |
-
- Automatic key selection based on strategy
|
| 31 |
-
- Quota error detection and recovery
|
| 32 |
-
- Key cooldown management
|
| 33 |
-
- Usage tracking
|
| 34 |
-
"""
|
| 35 |
-
|
| 36 |
-
def __init__(self, app: ASGIApp):
|
| 37 |
-
super().__init__(app)
|
| 38 |
-
|
| 39 |
-
async def dispatch(self, request: Request, call_next):
|
| 40 |
-
"""
|
| 41 |
-
Process request with automatic API key injection.
|
| 42 |
-
|
| 43 |
-
Flow:
|
| 44 |
-
1. Check if Gemini request
|
| 45 |
-
2. Select best available key
|
| 46 |
-
3. Inject into request state
|
| 47 |
-
4. Handle response (quota errors)
|
| 48 |
-
"""
|
| 49 |
-
# Only handle Gemini requests
|
| 50 |
-
if not self._is_gemini_request(request):
|
| 51 |
-
return await call_next(request)
|
| 52 |
-
|
| 53 |
-
# Select API key
|
| 54 |
-
try:
|
| 55 |
-
key_index, api_key = await self._select_api_key()
|
| 56 |
-
request.state.gemini_api_key = api_key
|
| 57 |
-
request.state.gemini_key_index = key_index
|
| 58 |
-
except ValueError as e:
|
| 59 |
-
# No keys available
|
| 60 |
-
logger.error(f"No API keys available: {e}")
|
| 61 |
-
return Response(
|
| 62 |
-
content=f'{{"detail": "{str(e)}"}}',
|
| 63 |
-
status_code=503,
|
| 64 |
-
media_type="application/json"
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
# Process request
|
| 68 |
-
response = await call_next(request)
|
| 69 |
-
|
| 70 |
-
# Handle quota errors
|
| 71 |
-
if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
|
| 72 |
-
logger.warning(f"Quota error on key {key_index}, attempting retry")
|
| 73 |
-
|
| 74 |
-
# Mark key in cooldown
|
| 75 |
-
self._mark_cooldown(key_index)
|
| 76 |
-
|
| 77 |
-
# Try to select different key
|
| 78 |
-
try:
|
| 79 |
-
key_index, api_key = await self._select_api_key(exclude_index=key_index)
|
| 80 |
-
request.state.gemini_api_key = api_key
|
| 81 |
-
request.state.gemini_key_index = key_index
|
| 82 |
-
|
| 83 |
-
# Retry request
|
| 84 |
-
logger.info(f"Retrying with key {key_index}")
|
| 85 |
-
response = await call_next(request)
|
| 86 |
-
except ValueError:
|
| 87 |
-
# No other keys available
|
| 88 |
-
logger.error("All API keys in cooldown or exhausted")
|
| 89 |
-
|
| 90 |
-
# Track usage
|
| 91 |
-
success = response.status_code < 400
|
| 92 |
-
await self._track_usage(key_index, success, response.status_code)
|
| 93 |
-
|
| 94 |
-
return response
|
| 95 |
-
|
| 96 |
-
def _is_gemini_request(self, request: Request) -> bool:
|
| 97 |
-
"""Check if request is for Gemini service."""
|
| 98 |
-
path = request.url.path
|
| 99 |
-
gemini_paths = ["/gemini/", "/api/gemini"]
|
| 100 |
-
return any(path.startswith(p) for p in gemini_paths)
|
| 101 |
-
|
| 102 |
-
async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
|
| 103 |
-
"""
|
| 104 |
-
Select best available API key.
|
| 105 |
-
|
| 106 |
-
Args:
|
| 107 |
-
exclude_index: Key index to exclude (e.g., after quota error)
|
| 108 |
-
|
| 109 |
-
Returns:
|
| 110 |
-
Tuple of (key_index, api_key)
|
| 111 |
-
|
| 112 |
-
Raises:
|
| 113 |
-
ValueError: If no keys available
|
| 114 |
-
"""
|
| 115 |
-
keys = APIKeyServiceConfig.get_api_keys()
|
| 116 |
-
if not keys:
|
| 117 |
-
raise ValueError("No API keys configured")
|
| 118 |
-
|
| 119 |
-
# Filter out excluded and cooldown keys
|
| 120 |
-
available_indices = []
|
| 121 |
-
for i in range(len(keys)):
|
| 122 |
-
if i == exclude_index:
|
| 123 |
-
continue
|
| 124 |
-
if self._is_in_cooldown(i):
|
| 125 |
-
continue
|
| 126 |
-
available_indices.append(i)
|
| 127 |
-
|
| 128 |
-
if not available_indices:
|
| 129 |
-
raise ValueError("All API keys in cooldown")
|
| 130 |
-
|
| 131 |
-
# Select based on strategy
|
| 132 |
-
if APIKeyServiceConfig._rotation_strategy == "round_robin":
|
| 133 |
-
# Simple round-robin
|
| 134 |
-
selected_index = available_indices[0]
|
| 135 |
-
else: # least_used
|
| 136 |
-
# Get usage stats from DB
|
| 137 |
-
async with async_session_maker() as db:
|
| 138 |
-
from services.api_key_manager import get_least_used_key
|
| 139 |
-
try:
|
| 140 |
-
selected_index, _ = await get_least_used_key(db)
|
| 141 |
-
if selected_index not in available_indices:
|
| 142 |
-
# Fallback to first available
|
| 143 |
-
selected_index = available_indices[0]
|
| 144 |
-
except Exception as e:
|
| 145 |
-
logger.error(f"Error getting least used key: {e}")
|
| 146 |
-
selected_index = available_indices[0]
|
| 147 |
-
|
| 148 |
-
logger.debug(f"Selected API key index {selected_index}")
|
| 149 |
-
return selected_index, keys[selected_index]
|
| 150 |
-
|
| 151 |
-
def _is_in_cooldown(self, key_index: int) -> bool:
|
| 152 |
-
"""Check if key is in cooldown period."""
|
| 153 |
-
if key_index not in _key_cooldowns:
|
| 154 |
-
return False
|
| 155 |
-
|
| 156 |
-
cooldown_until = _key_cooldowns[key_index]
|
| 157 |
-
if datetime.utcnow() > cooldown_until:
|
| 158 |
-
# Cooldown expired
|
| 159 |
-
del _key_cooldowns[key_index]
|
| 160 |
-
return False
|
| 161 |
-
|
| 162 |
-
return True
|
| 163 |
-
|
| 164 |
-
def _mark_cooldown(self, key_index: int):
|
| 165 |
-
"""Mark key as in cooldown."""
|
| 166 |
-
cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
|
| 167 |
-
cooldown_until = datetime.utcnow() + timedelta(seconds=cooldown_seconds)
|
| 168 |
-
_key_cooldowns[key_index] = cooldown_until
|
| 169 |
-
logger.info(f"Key {key_index} in cooldown until {cooldown_until}")
|
| 170 |
-
|
| 171 |
-
async def _track_usage(self, key_index: int, success: bool, status_code: int):
|
| 172 |
-
"""Track API key usage."""
|
| 173 |
-
try:
|
| 174 |
-
async with async_session_maker() as db:
|
| 175 |
-
from services.api_key_manager import record_usage
|
| 176 |
-
error_message = f"HTTP {status_code}" if not success else None
|
| 177 |
-
await record_usage(db, key_index, success, error_message)
|
| 178 |
-
await db.commit()
|
| 179 |
-
except Exception as e:
|
| 180 |
-
logger.error(f"Failed to track usage: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/gemini_service/job_processor.py
DELETED
|
@@ -1,378 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Gemini Job Worker - Specific implementation using the modular PriorityWorkerPool.
|
| 3 |
-
|
| 4 |
-
This file shows how to use the modular PriorityWorkerPool with Gemini-specific
|
| 5 |
-
job processing logic.
|
| 6 |
-
"""
|
| 7 |
-
import logging
|
| 8 |
-
from datetime import datetime, timedelta
|
| 9 |
-
from typing import Optional
|
| 10 |
-
from sqlalchemy.ext.asyncio import AsyncSession
|
| 11 |
-
|
| 12 |
-
from core.database import DATABASE_URL
|
| 13 |
-
from core.models import GeminiJob
|
| 14 |
-
from services.priority_worker_pool import (
|
| 15 |
-
PriorityWorkerPool,
|
| 16 |
-
JobProcessor,
|
| 17 |
-
WorkerConfig,
|
| 18 |
-
get_interval_for_priority
|
| 19 |
-
)
|
| 20 |
-
from services.gemini_service.api_client import GeminiService
|
| 21 |
-
from services.drive_service import DriveService
|
| 22 |
-
import asyncio
|
| 23 |
-
|
| 24 |
-
logger = logging.getLogger(__name__)
|
| 25 |
-
|
| 26 |
-
# Job type to priority mapping for Gemini jobs
|
| 27 |
-
JOB_PRIORITY_MAP = {
|
| 28 |
-
"text": "fast",
|
| 29 |
-
"analyze": "fast",
|
| 30 |
-
"animation_prompt": "fast",
|
| 31 |
-
"image": "medium",
|
| 32 |
-
"edit_image": "medium",
|
| 33 |
-
"video": "slow"
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def get_priority_for_job_type(job_type: str) -> str:
|
| 38 |
-
"""Get the priority tier for a Gemini job type."""
|
| 39 |
-
return JOB_PRIORITY_MAP.get(job_type, "fast")
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class GeminiJobProcessor(JobProcessor[GeminiJob]):
|
| 43 |
-
"""Processes Gemini AI jobs (text, image, video generation) with round-robin API keys."""
|
| 44 |
-
|
| 45 |
-
def __init__(self):
|
| 46 |
-
self.drive_service = DriveService()
|
| 47 |
-
|
| 48 |
-
async def _get_service_with_key(self, session: AsyncSession) -> tuple:
|
| 49 |
-
"""Get a GeminiService with the least-used API key."""
|
| 50 |
-
from services.api_key_manager import get_least_used_key
|
| 51 |
-
key_index, api_key = await get_least_used_key(session)
|
| 52 |
-
return key_index, GeminiService(api_key=api_key)
|
| 53 |
-
|
| 54 |
-
async def _record_usage(self, session: AsyncSession, key_index: int, success: bool, error_message: Optional[str] = None):
|
| 55 |
-
"""Record API key usage after request."""
|
| 56 |
-
from services.api_key_manager import record_usage
|
| 57 |
-
await record_usage(session, key_index, success, error_message)
|
| 58 |
-
|
| 59 |
-
def _handle_error(self, job: GeminiJob, error: Exception, reset_to_queued: bool = False) -> tuple[bool, str]:
|
| 60 |
-
"""
|
| 61 |
-
Handle job errors with retry logic.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
job: The job object
|
| 65 |
-
error: The exception raised
|
| 66 |
-
reset_to_queued: Whether to reset status to 'queued' on retry (for process())
|
| 67 |
-
|
| 68 |
-
Returns:
|
| 69 |
-
Tuple of (success, error_message)
|
| 70 |
-
success is False (since it's an error)
|
| 71 |
-
error_message is the formatted error string
|
| 72 |
-
"""
|
| 73 |
-
error_str = str(error)
|
| 74 |
-
is_retryable = False
|
| 75 |
-
log_msg = ""
|
| 76 |
-
|
| 77 |
-
# Check for Rate Limit (429)
|
| 78 |
-
if "429" in error_str or "ResourceExhausted" in error_str:
|
| 79 |
-
is_retryable = True
|
| 80 |
-
log_msg = f"Rate limit hit for job {job.job_id}"
|
| 81 |
-
|
| 82 |
-
# Check for Auth/Billing errors (401, 403, API key not found, API key not valid, FAILED_PRECONDITION)
|
| 83 |
-
elif "401" in error_str or "403" in error_str or "Unauthenticated" in error_str or "PermissionDenied" in error_str or "API key not found" in error_str or "API key not valid" in error_str or "FAILED_PRECONDITION" in error_str:
|
| 84 |
-
is_retryable = True
|
| 85 |
-
log_msg = f"Auth/Billing error for job {job.job_id}: {error_str}. Rescheduling to try different key."
|
| 86 |
-
|
| 87 |
-
# Check for Server errors (500, 503, 504)
|
| 88 |
-
elif "500" in error_str or "503" in error_str or "504" in error_str or "INTERNAL" in error_str or "UNAVAILABLE" in error_str or "DEADLINE_EXCEEDED" in error_str:
|
| 89 |
-
is_retryable = True
|
| 90 |
-
log_msg = f"Server error for job {job.job_id}: {error_str}"
|
| 91 |
-
|
| 92 |
-
# Try to parse JSON error details if present
|
| 93 |
-
try:
|
| 94 |
-
import json
|
| 95 |
-
import re
|
| 96 |
-
# Look for JSON-like structure in error string
|
| 97 |
-
json_match = re.search(r"(\{.*\})", error_str)
|
| 98 |
-
if json_match:
|
| 99 |
-
job.api_response = json.loads(json_match.group(1))
|
| 100 |
-
else:
|
| 101 |
-
job.api_response = {"error": error_str}
|
| 102 |
-
except Exception:
|
| 103 |
-
job.api_response = {"error": error_str}
|
| 104 |
-
|
| 105 |
-
if is_retryable:
|
| 106 |
-
logger.warning(f"{log_msg}. Rescheduling.")
|
| 107 |
-
job.retry_count += 1
|
| 108 |
-
config = WorkerConfig.from_env()
|
| 109 |
-
# Use a longer delay for these errors (e.g., 30s)
|
| 110 |
-
interval = 30
|
| 111 |
-
job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
|
| 112 |
-
|
| 113 |
-
if reset_to_queued:
|
| 114 |
-
job.status = "queued"
|
| 115 |
-
|
| 116 |
-
return False, f"Retryable error: {error_str}"
|
| 117 |
-
else:
|
| 118 |
-
logger.error(f"Error processing job {job.job_id}: {error}")
|
| 119 |
-
job.status = "failed"
|
| 120 |
-
job.error_message = str(error)
|
| 121 |
-
job.completed_at = datetime.utcnow()
|
| 122 |
-
return False, str(error)
|
| 123 |
-
|
| 124 |
-
async def process(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
|
| 125 |
-
"""Start processing a new job with round-robin API key."""
|
| 126 |
-
key_index, service = await self._get_service_with_key(session)
|
| 127 |
-
input_data = job.input_data or {}
|
| 128 |
-
success = False
|
| 129 |
-
error_msg = None
|
| 130 |
-
|
| 131 |
-
try:
|
| 132 |
-
if job.job_type == "video":
|
| 133 |
-
job = await self._start_video(job, session, service, input_data)
|
| 134 |
-
success = True
|
| 135 |
-
elif job.job_type == "image":
|
| 136 |
-
job = await self._process_image(job, service, input_data)
|
| 137 |
-
success = True
|
| 138 |
-
elif job.job_type == "text":
|
| 139 |
-
job = await self._process_text(job, service, input_data)
|
| 140 |
-
success = True
|
| 141 |
-
elif job.job_type == "analyze":
|
| 142 |
-
job = await self._process_analyze(job, service, input_data)
|
| 143 |
-
success = True
|
| 144 |
-
elif job.job_type == "animation_prompt":
|
| 145 |
-
job = await self._process_animation_prompt(job, service, input_data)
|
| 146 |
-
success = True
|
| 147 |
-
else:
|
| 148 |
-
job.status = "failed"
|
| 149 |
-
job.error_message = f"Unknown job type: {job.job_type}"
|
| 150 |
-
job.completed_at = datetime.utcnow()
|
| 151 |
-
error_msg = job.error_message
|
| 152 |
-
except Exception as e:
|
| 153 |
-
# Use helper for error handling
|
| 154 |
-
# reset_to_queued=True because if we fail to start, we want to try starting again from scratch
|
| 155 |
-
success, error_msg = self._handle_error(job, e, reset_to_queued=True)
|
| 156 |
-
|
| 157 |
-
# Record usage
|
| 158 |
-
await self._record_usage(session, key_index, success, error_msg)
|
| 159 |
-
|
| 160 |
-
return job
|
| 161 |
-
|
| 162 |
-
async def check_status(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
|
| 163 |
-
"""Check status of an in-progress job (video generation)."""
|
| 164 |
-
if job.job_type != "video" or not job.third_party_id:
|
| 165 |
-
job.status = "failed"
|
| 166 |
-
job.error_message = "Invalid job state for status check"
|
| 167 |
-
job.completed_at = datetime.utcnow()
|
| 168 |
-
return job
|
| 169 |
-
|
| 170 |
-
# Use round-robin key for status check
|
| 171 |
-
key_index, service = await self._get_service_with_key(session)
|
| 172 |
-
success = False
|
| 173 |
-
error_msg = None
|
| 174 |
-
|
| 175 |
-
try:
|
| 176 |
-
status_result = await service.check_video_status(job.third_party_id)
|
| 177 |
-
# Save raw response
|
| 178 |
-
job.api_response = status_result
|
| 179 |
-
|
| 180 |
-
if status_result.get("done"):
|
| 181 |
-
if status_result.get("status") == "completed":
|
| 182 |
-
video_url = status_result.get("video_url")
|
| 183 |
-
if video_url:
|
| 184 |
-
# Store video URL - download will happen on-demand when client requests
|
| 185 |
-
job.status = "completed"
|
| 186 |
-
job.output_data = {"video_url": video_url}
|
| 187 |
-
job.error_message = None # Clear any previous error
|
| 188 |
-
job.completed_at = datetime.utcnow()
|
| 189 |
-
success = True
|
| 190 |
-
# Sync DB on success
|
| 191 |
-
from services.backup_service import get_backup_service
|
| 192 |
-
backup_service = get_backup_service()
|
| 193 |
-
await backup_service.backup_async()
|
| 194 |
-
else:
|
| 195 |
-
job.status = "failed"
|
| 196 |
-
job.error_message = "No video URL returned"
|
| 197 |
-
job.completed_at = datetime.utcnow()
|
| 198 |
-
error_msg = job.error_message
|
| 199 |
-
else:
|
| 200 |
-
job.status = "failed"
|
| 201 |
-
job.error_message = status_result.get("error", "Unknown error")
|
| 202 |
-
job.completed_at = datetime.utcnow()
|
| 203 |
-
error_msg = job.error_message
|
| 204 |
-
else:
|
| 205 |
-
# Not done - reschedule
|
| 206 |
-
job.retry_count += 1
|
| 207 |
-
config = WorkerConfig.from_env()
|
| 208 |
-
interval = get_interval_for_priority(job.priority, config)
|
| 209 |
-
job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
|
| 210 |
-
success = True # Status check succeeded even if video not ready
|
| 211 |
-
|
| 212 |
-
except Exception as e:
|
| 213 |
-
# Use helper for error handling
|
| 214 |
-
# reset_to_queued=False because we want to continue checking status, not restart
|
| 215 |
-
success, error_msg = self._handle_error(job, e, reset_to_queued=False)
|
| 216 |
-
|
| 217 |
-
# Record usage
|
| 218 |
-
await self._record_usage(session, key_index, success, error_msg)
|
| 219 |
-
|
| 220 |
-
return job
|
| 221 |
-
|
| 222 |
-
# Record usage
|
| 223 |
-
await self._record_usage(session, key_index, success, error_msg)
|
| 224 |
-
|
| 225 |
-
return job
|
| 226 |
-
|
| 227 |
-
async def _start_video(self, job: GeminiJob, session: AsyncSession, service: GeminiService, input_data: dict) -> GeminiJob:
|
| 228 |
-
"""Start async video generation."""
|
| 229 |
-
prompt = input_data.get("prompt", "")
|
| 230 |
-
|
| 231 |
-
# If prompt is missing, generate one using the animation template
|
| 232 |
-
if not prompt:
|
| 233 |
-
try:
|
| 234 |
-
import os
|
| 235 |
-
template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "prompt", "animation.md")
|
| 236 |
-
if os.path.exists(template_path):
|
| 237 |
-
with open(template_path, "r") as f:
|
| 238 |
-
template_prompt = f.read().strip()
|
| 239 |
-
|
| 240 |
-
logger.info(f"Generating auto-prompt for job {job.job_id} using template")
|
| 241 |
-
prompt = await service.generate_animation_prompt(
|
| 242 |
-
base64_image=input_data.get("base64_image", ""),
|
| 243 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 244 |
-
custom_prompt=template_prompt
|
| 245 |
-
)
|
| 246 |
-
logger.info(f"Generated prompt for job {job.job_id}: {prompt}")
|
| 247 |
-
|
| 248 |
-
# Update input data with generated prompt for reference
|
| 249 |
-
# Create a new dictionary to ensure SQLAlchemy detects the change
|
| 250 |
-
new_input_data = dict(input_data)
|
| 251 |
-
new_input_data["prompt"] = prompt
|
| 252 |
-
job.input_data = new_input_data
|
| 253 |
-
# We need to commit this change to DB so it persists
|
| 254 |
-
# But session commit happens outside this method usually?
|
| 255 |
-
# Actually process() calls this, and process() returns job,
|
| 256 |
-
# but doesn't explicitly commit job changes until later?
|
| 257 |
-
# The worker loop commits after process() returns.
|
| 258 |
-
else:
|
| 259 |
-
logger.warning(f"Animation prompt template not found at {template_path}")
|
| 260 |
-
except Exception as e:
|
| 261 |
-
logger.error(f"Failed to generate auto-prompt: {e}")
|
| 262 |
-
# Fallback to empty prompt or error?
|
| 263 |
-
# Let's let it proceed with empty prompt which might fail at API level or use API default
|
| 264 |
-
|
| 265 |
-
result = await service.start_video_generation(
|
| 266 |
-
base64_image=input_data.get("base64_image", ""),
|
| 267 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 268 |
-
prompt=prompt,
|
| 269 |
-
aspect_ratio=input_data.get("aspect_ratio", "16:9"),
|
| 270 |
-
resolution=input_data.get("resolution", "720p"),
|
| 271 |
-
number_of_videos=input_data.get("number_of_videos", 1)
|
| 272 |
-
)
|
| 273 |
-
job.third_party_id = result.get("gemini_operation_name")
|
| 274 |
-
job.api_response = result
|
| 275 |
-
|
| 276 |
-
# Schedule first status check
|
| 277 |
-
config = WorkerConfig.from_env()
|
| 278 |
-
interval = get_interval_for_priority(job.priority, config)
|
| 279 |
-
job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
|
| 280 |
-
|
| 281 |
-
return job
|
| 282 |
-
|
| 283 |
-
async def _process_image(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
|
| 284 |
-
"""Process image edit (synchronous)."""
|
| 285 |
-
result = await service.edit_image(
|
| 286 |
-
base64_image=input_data.get("base64_image", ""),
|
| 287 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 288 |
-
prompt=input_data.get("prompt", "")
|
| 289 |
-
)
|
| 290 |
-
job.status = "completed"
|
| 291 |
-
job.output_data = {"image": result}
|
| 292 |
-
# Don't save full base64 image to api_response
|
| 293 |
-
job.api_response = {"status": "success", "type": "image_edit"}
|
| 294 |
-
job.completed_at = datetime.utcnow()
|
| 295 |
-
# Sync DB on success
|
| 296 |
-
from services.backup_service import get_backup_service
|
| 297 |
-
backup_service = get_backup_service()
|
| 298 |
-
await backup_service.backup_async()
|
| 299 |
-
return job
|
| 300 |
-
|
| 301 |
-
async def _process_text(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
|
| 302 |
-
"""Process text generation (synchronous)."""
|
| 303 |
-
result = await service.generate_text(
|
| 304 |
-
prompt=input_data.get("prompt", ""),
|
| 305 |
-
model=input_data.get("model")
|
| 306 |
-
)
|
| 307 |
-
job.status = "completed"
|
| 308 |
-
job.output_data = {"text": result}
|
| 309 |
-
job.api_response = {"result": result}
|
| 310 |
-
job.completed_at = datetime.utcnow()
|
| 311 |
-
# Sync DB on success
|
| 312 |
-
from services.backup_service import get_backup_service
|
| 313 |
-
backup_service = get_backup_service()
|
| 314 |
-
await backup_service.backup_async()
|
| 315 |
-
return job
|
| 316 |
-
|
| 317 |
-
async def _process_analyze(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
|
| 318 |
-
"""Process image analysis (synchronous)."""
|
| 319 |
-
result = await service.analyze_image(
|
| 320 |
-
base64_image=input_data.get("base64_image", ""),
|
| 321 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 322 |
-
prompt=input_data.get("prompt", "")
|
| 323 |
-
)
|
| 324 |
-
job.status = "completed"
|
| 325 |
-
job.output_data = {"analysis": result}
|
| 326 |
-
job.api_response = {"result": result}
|
| 327 |
-
job.completed_at = datetime.utcnow()
|
| 328 |
-
# Sync DB on success
|
| 329 |
-
from services.backup_service import get_backup_service
|
| 330 |
-
backup_service = get_backup_service()
|
| 331 |
-
await backup_service.backup_async()
|
| 332 |
-
return job
|
| 333 |
-
|
| 334 |
-
async def _process_animation_prompt(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
|
| 335 |
-
"""Process animation prompt generation (synchronous)."""
|
| 336 |
-
result = await service.generate_animation_prompt(
|
| 337 |
-
base64_image=input_data.get("base64_image", ""),
|
| 338 |
-
mime_type=input_data.get("mime_type", "image/jpeg"),
|
| 339 |
-
custom_prompt=input_data.get("custom_prompt")
|
| 340 |
-
)
|
| 341 |
-
job.status = "completed"
|
| 342 |
-
job.output_data = {"prompt": result}
|
| 343 |
-
job.api_response = {"result": result}
|
| 344 |
-
job.completed_at = datetime.utcnow()
|
| 345 |
-
# Sync DB on success
|
| 346 |
-
from services.backup_service import get_backup_service
|
| 347 |
-
backup_service = get_backup_service()
|
| 348 |
-
await backup_service.backup_async()
|
| 349 |
-
return job
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
# Singleton pool instance
|
| 353 |
-
_pool: Optional[PriorityWorkerPool] = None
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
def get_pool() -> PriorityWorkerPool:
|
| 357 |
-
"""Get the global Gemini worker pool instance."""
|
| 358 |
-
global _pool
|
| 359 |
-
if _pool is None:
|
| 360 |
-
_pool = PriorityWorkerPool(
|
| 361 |
-
database_url=DATABASE_URL,
|
| 362 |
-
job_model=GeminiJob,
|
| 363 |
-
job_processor=GeminiJobProcessor(),
|
| 364 |
-
config=WorkerConfig.from_env()
|
| 365 |
-
)
|
| 366 |
-
return _pool
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
async def start_worker():
|
| 370 |
-
"""Start the Gemini job worker pool."""
|
| 371 |
-
pool = get_pool()
|
| 372 |
-
await pool.start()
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
async def stop_worker():
|
| 376 |
-
"""Stop the Gemini job worker pool."""
|
| 377 |
-
pool = get_pool()
|
| 378 |
-
await pool.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/priority_worker_pool.py
DELETED
|
@@ -1,547 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Modular Priority-Tier Worker Pool
|
| 3 |
-
|
| 4 |
-
A self-contained, plug-and-play worker pool for processing async jobs
|
| 5 |
-
with priority-tier scheduling. Can be used in any Python application.
|
| 6 |
-
|
| 7 |
-
Usage:
|
| 8 |
-
from services.priority_worker_pool import PriorityWorkerPool, WorkerConfig
|
| 9 |
-
|
| 10 |
-
# Define your job processor function
|
| 11 |
-
async def process_my_job(job, session):
|
| 12 |
-
# Process job and return updated job
|
| 13 |
-
job.status = "completed"
|
| 14 |
-
job.output_data = {"result": "done"}
|
| 15 |
-
return job
|
| 16 |
-
|
| 17 |
-
# Configure and start pool
|
| 18 |
-
pool = PriorityWorkerPool(
|
| 19 |
-
database_url="sqlite+aiosqlite:///./my_db.db",
|
| 20 |
-
job_model=MyJobModel,
|
| 21 |
-
job_processor=process_my_job,
|
| 22 |
-
config=WorkerConfig(fast_workers=5, medium_workers=5, slow_workers=5)
|
| 23 |
-
)
|
| 24 |
-
await pool.start()
|
| 25 |
-
|
| 26 |
-
Environment Variables (optional):
|
| 27 |
-
FAST_WORKERS: Number of fast workers (default: 5)
|
| 28 |
-
MEDIUM_WORKERS: Number of medium workers (default: 5)
|
| 29 |
-
SLOW_WORKERS: Number of slow workers (default: 5)
|
| 30 |
-
FAST_INTERVAL: Fast tier polling interval in seconds (default: 5)
|
| 31 |
-
MEDIUM_INTERVAL: Medium tier polling interval in seconds (default: 30)
|
| 32 |
-
SLOW_INTERVAL: Slow tier polling interval in seconds (default: 60)
|
| 33 |
-
|
| 34 |
-
Dependencies:
|
| 35 |
-
sqlalchemy[asyncio]>=2.0.0
|
| 36 |
-
aiosqlite (for SQLite) or asyncpg (for PostgreSQL)
|
| 37 |
-
|
| 38 |
-
Job Model Requirements:
|
| 39 |
-
Your job model must have these columns:
|
| 40 |
-
- job_id: str (unique identifier)
|
| 41 |
-
- status: str (queued, processing, completed, failed, cancelled)
|
| 42 |
-
- priority: str (fast, medium, slow)
|
| 43 |
-
- next_process_at: datetime (nullable, for rescheduling)
|
| 44 |
-
- retry_count: int (default 0)
|
| 45 |
-
- created_at: datetime
|
| 46 |
-
- started_at: datetime (nullable)
|
| 47 |
-
- completed_at: datetime (nullable)
|
| 48 |
-
- error_message: str (nullable)
|
| 49 |
-
"""
|
| 50 |
-
import asyncio
|
| 51 |
-
import logging
|
| 52 |
-
import os
|
| 53 |
-
from abc import ABC, abstractmethod
|
| 54 |
-
from dataclasses import dataclass, field
|
| 55 |
-
from datetime import datetime, timedelta
|
| 56 |
-
from typing import Optional, List, Callable, Any, TypeVar, Generic
|
| 57 |
-
from sqlalchemy import select, or_, and_
|
| 58 |
-
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
| 59 |
-
|
| 60 |
-
logger = logging.getLogger(__name__)
|
| 61 |
-
|
| 62 |
-
# Generic type for job model
|
| 63 |
-
JobType = TypeVar('JobType')
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
@dataclass
|
| 67 |
-
class WorkerConfig:
|
| 68 |
-
"""Configuration for the worker pool."""
|
| 69 |
-
fast_workers: int = 5
|
| 70 |
-
medium_workers: int = 5
|
| 71 |
-
slow_workers: int = 5
|
| 72 |
-
fast_interval: int = 2 # seconds
|
| 73 |
-
medium_interval: int = 10 # seconds
|
| 74 |
-
slow_interval: int = 15 # seconds
|
| 75 |
-
max_retries: int = 60 # Max retry attempts before failing
|
| 76 |
-
job_per_api_key: int = 1 # Max concurrent jobs per API key
|
| 77 |
-
|
| 78 |
-
@classmethod
|
| 79 |
-
def from_env(cls) -> 'WorkerConfig':
|
| 80 |
-
"""Create config from environment variables."""
|
| 81 |
-
return cls(
|
| 82 |
-
fast_workers=int(os.getenv("FAST_WORKERS", "5")),
|
| 83 |
-
medium_workers=int(os.getenv("MEDIUM_WORKERS", "5")),
|
| 84 |
-
slow_workers=int(os.getenv("SLOW_WORKERS", "5")),
|
| 85 |
-
fast_interval=int(os.getenv("FAST_INTERVAL", "5")),
|
| 86 |
-
medium_interval=int(os.getenv("MEDIUM_INTERVAL", "30")),
|
| 87 |
-
slow_interval=int(os.getenv("SLOW_INTERVAL", "60")),
|
| 88 |
-
job_per_api_key=int(os.getenv("JOB_PER_API_KEY", "1")),
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
@dataclass
|
| 93 |
-
class PriorityMapping:
|
| 94 |
-
"""Maps job types to priority tiers."""
|
| 95 |
-
mappings: dict = field(default_factory=dict)
|
| 96 |
-
|
| 97 |
-
def get_priority(self, job_type: str, default: str = "fast") -> str:
|
| 98 |
-
"""Get priority for a job type."""
|
| 99 |
-
return self.mappings.get(job_type, default)
|
| 100 |
-
|
| 101 |
-
def get_interval(self, priority: str, config: WorkerConfig) -> int:
|
| 102 |
-
"""Get polling interval for a priority tier."""
|
| 103 |
-
if priority == "fast":
|
| 104 |
-
return config.fast_interval
|
| 105 |
-
elif priority == "medium":
|
| 106 |
-
return config.medium_interval
|
| 107 |
-
else:
|
| 108 |
-
return config.slow_interval
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
class JobProcessor(ABC, Generic[JobType]):
|
| 112 |
-
"""Abstract base class for job processors."""
|
| 113 |
-
|
| 114 |
-
@abstractmethod
|
| 115 |
-
async def process(self, job: JobType, session: AsyncSession) -> JobType:
|
| 116 |
-
"""
|
| 117 |
-
Process a job and return the updated job.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
job: The job to process
|
| 121 |
-
session: Database session for updates
|
| 122 |
-
|
| 123 |
-
Returns:
|
| 124 |
-
The updated job with new status/output
|
| 125 |
-
"""
|
| 126 |
-
pass
|
| 127 |
-
|
| 128 |
-
@abstractmethod
|
| 129 |
-
async def check_status(self, job: JobType, session: AsyncSession) -> JobType:
|
| 130 |
-
"""
|
| 131 |
-
Check status of an in-progress job (for async third-party operations).
|
| 132 |
-
|
| 133 |
-
Args:
|
| 134 |
-
job: The job to check
|
| 135 |
-
session: Database session for updates
|
| 136 |
-
|
| 137 |
-
Returns:
|
| 138 |
-
The updated job. Set next_process_at to reschedule if not done.
|
| 139 |
-
"""
|
| 140 |
-
pass
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
class PriorityWorker(Generic[JobType]):
|
| 144 |
-
"""Worker that processes jobs of a specific priority tier."""
|
| 145 |
-
|
| 146 |
-
def __init__(
|
| 147 |
-
self,
|
| 148 |
-
worker_id: int,
|
| 149 |
-
priority: str,
|
| 150 |
-
poll_interval: int,
|
| 151 |
-
session_maker: async_sessionmaker,
|
| 152 |
-
job_model: type,
|
| 153 |
-
job_processor: JobProcessor[JobType],
|
| 154 |
-
max_retries: int = 60,
|
| 155 |
-
wake_event: Optional[asyncio.Event] = None,
|
| 156 |
-
config: Optional[WorkerConfig] = None
|
| 157 |
-
):
|
| 158 |
-
self.worker_id = worker_id
|
| 159 |
-
self.priority = priority
|
| 160 |
-
self.poll_interval = poll_interval
|
| 161 |
-
self.session_maker = session_maker
|
| 162 |
-
self.job_model = job_model
|
| 163 |
-
self.job_processor = job_processor
|
| 164 |
-
self.max_retries = max_retries
|
| 165 |
-
self._running = False
|
| 166 |
-
self._current_job_id: Optional[str] = None
|
| 167 |
-
self._wake_event = wake_event # Event to wake worker immediately when new jobs arrive
|
| 168 |
-
self._config = config or WorkerConfig.from_env()
|
| 169 |
-
|
| 170 |
-
async def start(self):
|
| 171 |
-
"""Start the worker polling loop."""
|
| 172 |
-
self._running = True
|
| 173 |
-
logger.debug(f"Worker {self.worker_id} ({self.priority}) started, polling every {self.poll_interval}s")
|
| 174 |
-
asyncio.create_task(self._poll_loop())
|
| 175 |
-
|
| 176 |
-
async def stop(self):
|
| 177 |
-
"""Stop the worker."""
|
| 178 |
-
self._running = False
|
| 179 |
-
logger.info(f"Worker {self.worker_id} ({self.priority}) stopped")
|
| 180 |
-
|
| 181 |
-
async def _poll_loop(self):
|
| 182 |
-
"""Main polling loop with optimized scheduling.
|
| 183 |
-
|
| 184 |
-
Optimizations:
|
| 185 |
-
- When no jobs are found, sleep for poll_interval before checking again
|
| 186 |
-
- When a job is processed, immediately check for the next job (no waiting)
|
| 187 |
-
- This ensures first job starts immediately when queue was empty
|
| 188 |
-
- This ensures next job starts immediately after current job finishes
|
| 189 |
-
"""
|
| 190 |
-
while self._running:
|
| 191 |
-
job_found = False
|
| 192 |
-
try:
|
| 193 |
-
job_found = await self._process_one_job()
|
| 194 |
-
except Exception as e:
|
| 195 |
-
logger.error(f"Worker {self.worker_id}: Error in poll loop: {e}")
|
| 196 |
-
|
| 197 |
-
# Only sleep if no job was found - otherwise immediately look for next job
|
| 198 |
-
if not job_found:
|
| 199 |
-
# Wait on event with timeout - allows immediate wake-up when new job arrives
|
| 200 |
-
if self._wake_event:
|
| 201 |
-
try:
|
| 202 |
-
# Wait for event or timeout (whichever comes first)
|
| 203 |
-
await asyncio.wait_for(
|
| 204 |
-
self._wake_event.wait(),
|
| 205 |
-
timeout=self.poll_interval
|
| 206 |
-
)
|
| 207 |
-
# Clear event after waking (we'll check for jobs)
|
| 208 |
-
self._wake_event.clear()
|
| 209 |
-
except asyncio.TimeoutError:
|
| 210 |
-
pass # Normal timeout, check for jobs
|
| 211 |
-
else:
|
| 212 |
-
await asyncio.sleep(self.poll_interval)
|
| 213 |
-
|
| 214 |
-
async def _process_one_job(self) -> bool:
|
| 215 |
-
"""Find and process one job.
|
| 216 |
-
|
| 217 |
-
Enforces constraints:
|
| 218 |
-
1. Only one job per user can be in processing state at a time
|
| 219 |
-
2. Total processing jobs limited to JOB_PER_API_KEY * number of API keys
|
| 220 |
-
|
| 221 |
-
Returns:
|
| 222 |
-
True if a job was found and processed, False if no jobs available
|
| 223 |
-
"""
|
| 224 |
-
async with self.session_maker() as session:
|
| 225 |
-
from sqlalchemy import func
|
| 226 |
-
|
| 227 |
-
now = datetime.utcnow()
|
| 228 |
-
|
| 229 |
-
# Get number of API keys for capacity calculation
|
| 230 |
-
try:
|
| 231 |
-
from services.api_key_manager import get_key_count
|
| 232 |
-
num_api_keys = get_key_count()
|
| 233 |
-
max_processing = self._config.job_per_api_key * num_api_keys
|
| 234 |
-
except ImportError:
|
| 235 |
-
max_processing = 10 # Default fallback
|
| 236 |
-
|
| 237 |
-
# Check if we're at max processing capacity (only for new jobs being picked up)
|
| 238 |
-
count_query = select(func.count()).where(
|
| 239 |
-
self.job_model.status == "processing"
|
| 240 |
-
)
|
| 241 |
-
count_result = await session.execute(count_query)
|
| 242 |
-
current_processing = count_result.scalar() or 0
|
| 243 |
-
|
| 244 |
-
# Query for jobs matching this priority tier
|
| 245 |
-
query = select(self.job_model).where(
|
| 246 |
-
and_(
|
| 247 |
-
self.job_model.priority == self.priority,
|
| 248 |
-
self.job_model.status.in_(["queued", "processing"]),
|
| 249 |
-
or_(
|
| 250 |
-
self.job_model.next_process_at.is_(None),
|
| 251 |
-
self.job_model.next_process_at <= now
|
| 252 |
-
)
|
| 253 |
-
)
|
| 254 |
-
).order_by(self.job_model.created_at).limit(1)
|
| 255 |
-
|
| 256 |
-
result = await session.execute(query)
|
| 257 |
-
job = result.scalar_one_or_none()
|
| 258 |
-
|
| 259 |
-
if not job:
|
| 260 |
-
return False
|
| 261 |
-
|
| 262 |
-
# For queued jobs, apply the constraints
|
| 263 |
-
if job.status == "queued":
|
| 264 |
-
# Constraint 1: Check if this user already has a job in processing
|
| 265 |
-
user_processing_query = select(func.count()).where(
|
| 266 |
-
and_(
|
| 267 |
-
self.job_model.user_id == job.user_id,
|
| 268 |
-
self.job_model.status == "processing"
|
| 269 |
-
)
|
| 270 |
-
)
|
| 271 |
-
user_result = await session.execute(user_processing_query)
|
| 272 |
-
user_processing_count = user_result.scalar() or 0
|
| 273 |
-
|
| 274 |
-
if user_processing_count > 0:
|
| 275 |
-
logger.debug(f"Worker {self.worker_id}: User {job.user_id} already has a job processing, skipping")
|
| 276 |
-
return False
|
| 277 |
-
|
| 278 |
-
# Constraint 2: Check if we're at max total processing capacity
|
| 279 |
-
if current_processing >= max_processing:
|
| 280 |
-
logger.debug(f"Worker {self.worker_id}: At max capacity ({current_processing}/{max_processing}), skipping new job")
|
| 281 |
-
return False
|
| 282 |
-
|
| 283 |
-
self._current_job_id = job.job_id
|
| 284 |
-
|
| 285 |
-
try:
|
| 286 |
-
await self._process_job(session, job)
|
| 287 |
-
return True
|
| 288 |
-
except Exception as e:
|
| 289 |
-
logger.error(f"Worker {self.worker_id}: Error processing job {job.job_id}: {e}")
|
| 290 |
-
job.status = "failed"
|
| 291 |
-
job.error_message = str(e)
|
| 292 |
-
job.completed_at = datetime.utcnow()
|
| 293 |
-
await session.commit()
|
| 294 |
-
return True # Job was found, even though it failed
|
| 295 |
-
finally:
|
| 296 |
-
self._current_job_id = None
|
| 297 |
-
|
| 298 |
-
async def _process_job(self, session: AsyncSession, job: JobType):
|
| 299 |
-
"""Process a single job."""
|
| 300 |
-
logger.info(f"Worker {self.worker_id}: Processing job {job.job_id} (status: {job.status})")
|
| 301 |
-
|
| 302 |
-
from sqlalchemy import update
|
| 303 |
-
|
| 304 |
-
if job.status == "queued":
|
| 305 |
-
# New job - try to claim it atomically
|
| 306 |
-
# Set next_process_at to future to prevent others from picking it up while we process
|
| 307 |
-
next_check = datetime.utcnow() + timedelta(seconds=self.poll_interval * 2)
|
| 308 |
-
|
| 309 |
-
stmt = (
|
| 310 |
-
update(self.job_model)
|
| 311 |
-
.where(
|
| 312 |
-
self.job_model.job_id == job.job_id,
|
| 313 |
-
self.job_model.status == "queued"
|
| 314 |
-
)
|
| 315 |
-
.values(
|
| 316 |
-
status="processing",
|
| 317 |
-
started_at=datetime.utcnow(),
|
| 318 |
-
next_process_at=next_check
|
| 319 |
-
)
|
| 320 |
-
)
|
| 321 |
-
result = await session.execute(stmt)
|
| 322 |
-
await session.commit()
|
| 323 |
-
|
| 324 |
-
if result.rowcount == 0:
|
| 325 |
-
logger.info(f"Worker {self.worker_id}: Failed to claim job {job.job_id} (already taken)")
|
| 326 |
-
return
|
| 327 |
-
|
| 328 |
-
# We claimed it. Refresh and process.
|
| 329 |
-
await session.refresh(job)
|
| 330 |
-
job = await self.job_processor.process(job, session)
|
| 331 |
-
|
| 332 |
-
else:
|
| 333 |
-
# Already processing - try to claim for status check
|
| 334 |
-
# Ensure we only pick it up if next_process_at matches (or is null/past)
|
| 335 |
-
# But the SELECT already filtered for that.
|
| 336 |
-
# We just need to ensure no one else grabbed it between SELECT and UPDATE.
|
| 337 |
-
|
| 338 |
-
# Update next_process_at to future to lock it for this check
|
| 339 |
-
next_check = datetime.utcnow() + timedelta(seconds=self.poll_interval * 2)
|
| 340 |
-
|
| 341 |
-
stmt = (
|
| 342 |
-
update(self.job_model)
|
| 343 |
-
.where(
|
| 344 |
-
self.job_model.job_id == job.job_id,
|
| 345 |
-
or_(
|
| 346 |
-
self.job_model.next_process_at.is_(None),
|
| 347 |
-
self.job_model.next_process_at <= datetime.utcnow()
|
| 348 |
-
)
|
| 349 |
-
)
|
| 350 |
-
.values(next_process_at=next_check)
|
| 351 |
-
)
|
| 352 |
-
result = await session.execute(stmt)
|
| 353 |
-
await session.commit()
|
| 354 |
-
|
| 355 |
-
if result.rowcount == 0:
|
| 356 |
-
logger.info(f"Worker {self.worker_id}: Failed to claim job {job.job_id} for check (already taken)")
|
| 357 |
-
return
|
| 358 |
-
|
| 359 |
-
await session.refresh(job)
|
| 360 |
-
job = await self.job_processor.check_status(job, session)
|
| 361 |
-
|
| 362 |
-
# Handle retry limit
|
| 363 |
-
if job.status == "processing" and job.retry_count > self.max_retries:
|
| 364 |
-
job.status = "failed"
|
| 365 |
-
job.error_message = f"Max retries ({self.max_retries}) exceeded"
|
| 366 |
-
job.completed_at = datetime.utcnow()
|
| 367 |
-
|
| 368 |
-
# Handle credit finalization for jobs with reserved credits
|
| 369 |
-
if job.status in ("completed", "failed", "cancelled"):
|
| 370 |
-
await self._handle_job_credits(session, job)
|
| 371 |
-
|
| 372 |
-
await session.commit()
|
| 373 |
-
|
| 374 |
-
async def _handle_job_credits(self, session: AsyncSession, job: JobType):
|
| 375 |
-
"""Handle credit finalization when job reaches terminal state."""
|
| 376 |
-
# Check if job has credits_reserved attribute (credit-enabled jobs)
|
| 377 |
-
if not hasattr(job, 'credits_reserved') or job.credits_reserved <= 0:
|
| 378 |
-
return
|
| 379 |
-
|
| 380 |
-
try:
|
| 381 |
-
from services.credit_service.credit_manager import handle_job_completion
|
| 382 |
-
await handle_job_completion(session, job)
|
| 383 |
-
except ImportError:
|
| 384 |
-
# Credit service not available - skip
|
| 385 |
-
logger.debug(f"Credit service not available for job {job.job_id}")
|
| 386 |
-
except Exception as e:
|
| 387 |
-
logger.error(f"Error handling credits for job {job.job_id}: {e}")
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
class PriorityWorkerPool(Generic[JobType]):
|
| 391 |
-
"""
|
| 392 |
-
Modular priority-tier worker pool.
|
| 393 |
-
|
| 394 |
-
Can be used with any job model that follows the required schema.
|
| 395 |
-
"""
|
| 396 |
-
|
| 397 |
-
def __init__(
|
| 398 |
-
self,
|
| 399 |
-
database_url: str,
|
| 400 |
-
job_model: type,
|
| 401 |
-
job_processor: JobProcessor[JobType],
|
| 402 |
-
config: Optional[WorkerConfig] = None
|
| 403 |
-
):
|
| 404 |
-
"""
|
| 405 |
-
Initialize the worker pool.
|
| 406 |
-
|
| 407 |
-
Args:
|
| 408 |
-
database_url: SQLAlchemy async database URL
|
| 409 |
-
job_model: Your ORM model class for jobs
|
| 410 |
-
job_processor: Instance of JobProcessor to handle jobs
|
| 411 |
-
config: Worker configuration (uses env vars if not provided)
|
| 412 |
-
"""
|
| 413 |
-
self.database_url = database_url
|
| 414 |
-
self.job_model = job_model
|
| 415 |
-
self.job_processor = job_processor
|
| 416 |
-
self.config = config or WorkerConfig.from_env()
|
| 417 |
-
|
| 418 |
-
self.engine = create_async_engine(database_url, echo=False)
|
| 419 |
-
self.session_maker = async_sessionmaker(
|
| 420 |
-
self.engine,
|
| 421 |
-
class_=AsyncSession,
|
| 422 |
-
expire_on_commit=False
|
| 423 |
-
)
|
| 424 |
-
self.workers: List[PriorityWorker] = []
|
| 425 |
-
self._running = False
|
| 426 |
-
|
| 427 |
-
# Wake events for each priority tier - allows immediate job notification
|
| 428 |
-
self._wake_events: dict[str, asyncio.Event] = {
|
| 429 |
-
"fast": asyncio.Event(),
|
| 430 |
-
"medium": asyncio.Event(),
|
| 431 |
-
"slow": asyncio.Event()
|
| 432 |
-
}
|
| 433 |
-
|
| 434 |
-
async def start(self):
|
| 435 |
-
"""Start all workers."""
|
| 436 |
-
self._running = True
|
| 437 |
-
worker_id = 0
|
| 438 |
-
|
| 439 |
-
# Create fast workers
|
| 440 |
-
for i in range(self.config.fast_workers):
|
| 441 |
-
worker = PriorityWorker(
|
| 442 |
-
worker_id=worker_id,
|
| 443 |
-
priority="fast",
|
| 444 |
-
poll_interval=self.config.fast_interval,
|
| 445 |
-
session_maker=self.session_maker,
|
| 446 |
-
job_model=self.job_model,
|
| 447 |
-
job_processor=self.job_processor,
|
| 448 |
-
max_retries=self.config.max_retries,
|
| 449 |
-
wake_event=self._wake_events["fast"],
|
| 450 |
-
config=self.config
|
| 451 |
-
)
|
| 452 |
-
self.workers.append(worker)
|
| 453 |
-
await worker.start()
|
| 454 |
-
worker_id += 1
|
| 455 |
-
|
| 456 |
-
# Create medium workers
|
| 457 |
-
for i in range(self.config.medium_workers):
|
| 458 |
-
worker = PriorityWorker(
|
| 459 |
-
worker_id=worker_id,
|
| 460 |
-
priority="medium",
|
| 461 |
-
poll_interval=self.config.medium_interval,
|
| 462 |
-
session_maker=self.session_maker,
|
| 463 |
-
job_model=self.job_model,
|
| 464 |
-
job_processor=self.job_processor,
|
| 465 |
-
max_retries=self.config.max_retries,
|
| 466 |
-
wake_event=self._wake_events["medium"],
|
| 467 |
-
config=self.config
|
| 468 |
-
)
|
| 469 |
-
self.workers.append(worker)
|
| 470 |
-
await worker.start()
|
| 471 |
-
worker_id += 1
|
| 472 |
-
|
| 473 |
-
# Create slow workers
|
| 474 |
-
for i in range(self.config.slow_workers):
|
| 475 |
-
worker = PriorityWorker(
|
| 476 |
-
worker_id=worker_id,
|
| 477 |
-
priority="slow",
|
| 478 |
-
poll_interval=self.config.slow_interval,
|
| 479 |
-
session_maker=self.session_maker,
|
| 480 |
-
job_model=self.job_model,
|
| 481 |
-
job_processor=self.job_processor,
|
| 482 |
-
max_retries=self.config.max_retries,
|
| 483 |
-
wake_event=self._wake_events["slow"],
|
| 484 |
-
config=self.config
|
| 485 |
-
)
|
| 486 |
-
self.workers.append(worker)
|
| 487 |
-
await worker.start()
|
| 488 |
-
worker_id += 1
|
| 489 |
-
|
| 490 |
-
total = self.config.fast_workers + self.config.medium_workers + self.config.slow_workers
|
| 491 |
-
logger.info(
|
| 492 |
-
f"PriorityWorkerPool started with {total} workers: "
|
| 493 |
-
f"{self.config.fast_workers} fast, {self.config.medium_workers} medium, {self.config.slow_workers} slow"
|
| 494 |
-
)
|
| 495 |
-
|
| 496 |
-
def notify_new_job(self, priority: str):
|
| 497 |
-
"""
|
| 498 |
-
Wake sleeping workers of the specified priority tier.
|
| 499 |
-
Call this when a new job is created to start processing immediately.
|
| 500 |
-
|
| 501 |
-
Args:
|
| 502 |
-
priority: Priority tier ("fast", "medium", or "slow")
|
| 503 |
-
"""
|
| 504 |
-
if priority in self._wake_events:
|
| 505 |
-
self._wake_events[priority].set()
|
| 506 |
-
logger.debug(f"Notified {priority} workers of new job")
|
| 507 |
-
|
| 508 |
-
async def stop(self):
|
| 509 |
-
"""Stop all workers and refund orphaned jobs."""
|
| 510 |
-
self._running = False
|
| 511 |
-
|
| 512 |
-
# Refund credits for any jobs that were processing when server stopped
|
| 513 |
-
await self._refund_orphaned_jobs()
|
| 514 |
-
|
| 515 |
-
for worker in self.workers:
|
| 516 |
-
await worker.stop()
|
| 517 |
-
logger.info("PriorityWorkerPool stopped")
|
| 518 |
-
|
| 519 |
-
async def _refund_orphaned_jobs(self):
|
| 520 |
-
"""Refund credits for jobs abandoned during shutdown."""
|
| 521 |
-
try:
|
| 522 |
-
from services.credit_service.credit_manager import refund_orphaned_jobs
|
| 523 |
-
async with self.session_maker() as session:
|
| 524 |
-
refund_count = await refund_orphaned_jobs(session)
|
| 525 |
-
if refund_count > 0:
|
| 526 |
-
logger.info(f"Shutdown: Refunded {refund_count} orphaned job(s)")
|
| 527 |
-
except ImportError:
|
| 528 |
-
logger.debug("Credit service not available for orphaned job refunds")
|
| 529 |
-
except Exception as e:
|
| 530 |
-
logger.error(f"Error refunding orphaned jobs during shutdown: {e}")
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
# Convenience functions for priority mapping
|
| 534 |
-
def get_priority_for_job_type(job_type: str, mappings: dict) -> str:
|
| 535 |
-
"""Get priority tier for a job type using provided mappings."""
|
| 536 |
-
return mappings.get(job_type, "fast")
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
def get_interval_for_priority(priority: str, config: Optional[WorkerConfig] = None) -> int:
|
| 540 |
-
"""Get polling interval for a priority tier."""
|
| 541 |
-
cfg = config or WorkerConfig.from_env()
|
| 542 |
-
if priority == "fast":
|
| 543 |
-
return cfg.fast_interval
|
| 544 |
-
elif priority == "medium":
|
| 545 |
-
return cfg.medium_interval
|
| 546 |
-
else:
|
| 547 |
-
return cfg.slow_interval
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/debug_gemini_service.py
DELETED
|
@@ -1,165 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Debug script to test Gemini service with API keys from environment.
|
| 3 |
-
Keys should be in GEMINI_KEYS environment variable, comma-separated.
|
| 4 |
-
|
| 5 |
-
Usage:
|
| 6 |
-
GEMINI_KEYS="key1,key2,key3" python tests/debug_gemini_service.py
|
| 7 |
-
"""
|
| 8 |
-
import os
|
| 9 |
-
import sys
|
| 10 |
-
import asyncio
|
| 11 |
-
import logging
|
| 12 |
-
import base64
|
| 13 |
-
from dotenv import load_dotenv
|
| 14 |
-
|
| 15 |
-
# Load environment variables
|
| 16 |
-
load_dotenv()
|
| 17 |
-
|
| 18 |
-
# Add parent directory to path
|
| 19 |
-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 20 |
-
|
| 21 |
-
from services.gemini_service import GeminiService, MODELS
|
| 22 |
-
|
| 23 |
-
# Configure logging
|
| 24 |
-
logging.basicConfig(level=logging.INFO)
|
| 25 |
-
logger = logging.getLogger(__name__)
|
| 26 |
-
|
| 27 |
-
# Test image path
|
| 28 |
-
TEST_IMAGE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test.jpg")
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def load_test_image():
|
| 32 |
-
"""Load test image and return base64 + mime type."""
|
| 33 |
-
if not os.path.exists(TEST_IMAGE_PATH):
|
| 34 |
-
logger.error(f"Test image not found: {TEST_IMAGE_PATH}")
|
| 35 |
-
return None, None
|
| 36 |
-
|
| 37 |
-
with open(TEST_IMAGE_PATH, "rb") as f:
|
| 38 |
-
image_data = f.read()
|
| 39 |
-
|
| 40 |
-
base64_image = base64.b64encode(image_data).decode("utf-8")
|
| 41 |
-
mime_type = "image/jpeg"
|
| 42 |
-
|
| 43 |
-
logger.info(f"Loaded test image: {TEST_IMAGE_PATH} ({len(image_data)} bytes)")
|
| 44 |
-
return base64_image, mime_type
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
async def test_generate_text(service: GeminiService, key_index: int):
|
| 48 |
-
"""Test simple text generation."""
|
| 49 |
-
logger.info(f"[Key {key_index}] Testing text generation...")
|
| 50 |
-
try:
|
| 51 |
-
result = await service.generate_text("Say hello in one word.")
|
| 52 |
-
logger.info(f"[Key {key_index}] Text generation result: {result[:100]}...")
|
| 53 |
-
return True
|
| 54 |
-
except Exception as e:
|
| 55 |
-
logger.error(f"[Key {key_index}] Text generation failed: {e}")
|
| 56 |
-
return False
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
async def test_analyze_image(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
|
| 60 |
-
"""Test image analysis."""
|
| 61 |
-
logger.info(f"[Key {key_index}] Testing image analysis...")
|
| 62 |
-
try:
|
| 63 |
-
result = await service.analyze_image(
|
| 64 |
-
base64_image=base64_image,
|
| 65 |
-
mime_type=mime_type,
|
| 66 |
-
prompt="Describe this image in one sentence."
|
| 67 |
-
)
|
| 68 |
-
logger.info(f"[Key {key_index}] Image analysis result: {result[:100]}...")
|
| 69 |
-
return True
|
| 70 |
-
except Exception as e:
|
| 71 |
-
logger.error(f"[Key {key_index}] Image analysis failed: {e}")
|
| 72 |
-
return False
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
async def test_generate_animation_prompt(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
|
| 76 |
-
"""Test animation prompt generation."""
|
| 77 |
-
logger.info(f"[Key {key_index}] Testing animation prompt generation...")
|
| 78 |
-
try:
|
| 79 |
-
result = await service.generate_animation_prompt(
|
| 80 |
-
base64_image=base64_image,
|
| 81 |
-
mime_type=mime_type
|
| 82 |
-
)
|
| 83 |
-
logger.info(f"[Key {key_index}] Animation prompt result: {result[:100]}...")
|
| 84 |
-
return True
|
| 85 |
-
except Exception as e:
|
| 86 |
-
logger.error(f"[Key {key_index}] Animation prompt generation failed: {e}")
|
| 87 |
-
return False
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
async def test_key(api_key: str, key_index: int, base64_image: str, mime_type: str):
|
| 91 |
-
"""Test all basic operations with a single API key."""
|
| 92 |
-
logger.info(f"\n{'='*50}")
|
| 93 |
-
logger.info(f"Testing Key {key_index}: {api_key[:10]}...{api_key[-4:]}")
|
| 94 |
-
logger.info(f"{'='*50}")
|
| 95 |
-
|
| 96 |
-
try:
|
| 97 |
-
service = GeminiService(api_key)
|
| 98 |
-
except Exception as e:
|
| 99 |
-
logger.error(f"[Key {key_index}] Failed to initialize service: {e}")
|
| 100 |
-
return {"key_index": key_index, "valid": False, "error": str(e)}
|
| 101 |
-
|
| 102 |
-
results = {
|
| 103 |
-
"key_index": key_index,
|
| 104 |
-
"key_preview": f"{api_key[:10]}...{api_key[-4:]}",
|
| 105 |
-
"text_generation": await test_generate_text(service, key_index),
|
| 106 |
-
"image_analysis": await test_analyze_image(service, key_index, base64_image, mime_type),
|
| 107 |
-
"animation_prompt": await test_generate_animation_prompt(service, key_index, base64_image, mime_type),
|
| 108 |
-
}
|
| 109 |
-
|
| 110 |
-
results["valid"] = all([
|
| 111 |
-
results["text_generation"],
|
| 112 |
-
results["image_analysis"],
|
| 113 |
-
results["animation_prompt"]
|
| 114 |
-
])
|
| 115 |
-
|
| 116 |
-
return results
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
async def main():
|
| 120 |
-
# Load test image
|
| 121 |
-
base64_image, mime_type = load_test_image()
|
| 122 |
-
if not base64_image:
|
| 123 |
-
logger.error("Cannot run tests without test image. Please add test.jpg to project root.")
|
| 124 |
-
return
|
| 125 |
-
|
| 126 |
-
gemini_keys_str = os.getenv("GEMINI_KEYS", "")
|
| 127 |
-
|
| 128 |
-
if not gemini_keys_str:
|
| 129 |
-
logger.error("GEMINI_KEYS environment variable not set.")
|
| 130 |
-
logger.info("Usage: GEMINI_KEYS='key1,key2,key3' python tests/debug_gemini_service.py")
|
| 131 |
-
return
|
| 132 |
-
|
| 133 |
-
keys = [k.strip() for k in gemini_keys_str.split(",") if k.strip()]
|
| 134 |
-
|
| 135 |
-
if not keys:
|
| 136 |
-
logger.error("No valid keys found in GEMINI_KEYS.")
|
| 137 |
-
return
|
| 138 |
-
|
| 139 |
-
logger.info(f"Found {len(keys)} API key(s) to test.")
|
| 140 |
-
logger.info(f"Available models: {MODELS}")
|
| 141 |
-
|
| 142 |
-
all_results = []
|
| 143 |
-
for i, key in enumerate(keys):
|
| 144 |
-
result = await test_key(key, i + 1, base64_image, mime_type)
|
| 145 |
-
all_results.append(result)
|
| 146 |
-
|
| 147 |
-
# Summary
|
| 148 |
-
logger.info(f"\n{'='*50}")
|
| 149 |
-
logger.info("SUMMARY")
|
| 150 |
-
logger.info(f"{'='*50}")
|
| 151 |
-
|
| 152 |
-
valid_count = sum(1 for r in all_results if r.get("valid", False))
|
| 153 |
-
logger.info(f"Valid keys: {valid_count}/{len(keys)}")
|
| 154 |
-
|
| 155 |
-
for result in all_results:
|
| 156 |
-
status = "✓ VALID" if result.get("valid") else "✗ INVALID"
|
| 157 |
-
logger.info(f" Key {result['key_index']}: {status}")
|
| 158 |
-
if not result.get("valid"):
|
| 159 |
-
for test_name in ["text_generation", "image_analysis", "animation_prompt"]:
|
| 160 |
-
if test_name in result and not result[test_name]:
|
| 161 |
-
logger.info(f" - {test_name}: FAILED")
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
if __name__ == "__main__":
|
| 165 |
-
asyncio.run(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_fal_service.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for Fal.ai Service.
|
| 3 |
+
|
| 4 |
+
Tests cover:
|
| 5 |
+
1. Initialization & API key handling
|
| 6 |
+
2. Video generation
|
| 7 |
+
3. Error handling
|
| 8 |
+
4. Mock mode
|
| 9 |
+
"""
|
| 10 |
+
import pytest
|
| 11 |
+
import asyncio
|
| 12 |
+
import os
|
| 13 |
+
from unittest.mock import patch, MagicMock, AsyncMock
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# =============================================================================
|
| 17 |
+
# 1. Initialization & Configuration Tests
|
| 18 |
+
# =============================================================================
|
| 19 |
+
|
| 20 |
+
class TestFalServiceInit:
|
| 21 |
+
"""Test FalService initialization and configuration."""
|
| 22 |
+
|
| 23 |
+
def test_init_with_explicit_api_key(self):
|
| 24 |
+
"""Service initializes with explicit API key."""
|
| 25 |
+
with patch.dict(os.environ, {"FAL_KEY": "env-key"}):
|
| 26 |
+
from services.fal_service import FalService
|
| 27 |
+
|
| 28 |
+
service = FalService(api_key="test-key-123")
|
| 29 |
+
|
| 30 |
+
assert service.api_key == "test-key-123"
|
| 31 |
+
|
| 32 |
+
def test_init_with_env_fallback(self):
|
| 33 |
+
"""Service falls back to environment variable for API key."""
|
| 34 |
+
with patch.dict(os.environ, {"FAL_KEY": "env-key-456"}):
|
| 35 |
+
from services.fal_service import FalService
|
| 36 |
+
|
| 37 |
+
service = FalService()
|
| 38 |
+
|
| 39 |
+
assert service.api_key == "env-key-456"
|
| 40 |
+
|
| 41 |
+
def test_init_fails_without_api_key(self):
|
| 42 |
+
"""Service raises error when no API key available."""
|
| 43 |
+
with patch.dict(os.environ, {}, clear=True):
|
| 44 |
+
os.environ.pop("FAL_KEY", None)
|
| 45 |
+
|
| 46 |
+
from services.fal_service import get_fal_api_key
|
| 47 |
+
|
| 48 |
+
with pytest.raises(ValueError, match="FAL_KEY not configured"):
|
| 49 |
+
get_fal_api_key()
|
| 50 |
+
|
| 51 |
+
def test_models_dict_has_required_entries(self):
|
| 52 |
+
"""MODELS dictionary has all required model names."""
|
| 53 |
+
from services.fal_service import MODELS
|
| 54 |
+
|
| 55 |
+
assert "video_generation" in MODELS
|
| 56 |
+
assert "veo3" in MODELS["video_generation"].lower() or "image-to-video" in MODELS["video_generation"]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# =============================================================================
|
| 60 |
+
# 2. Video Generation Tests
|
| 61 |
+
# =============================================================================
|
| 62 |
+
|
| 63 |
+
class TestFalVideoGeneration:
|
| 64 |
+
"""Test video generation methods."""
|
| 65 |
+
|
| 66 |
+
@pytest.mark.asyncio
|
| 67 |
+
async def test_start_video_generation_mock_mode(self):
|
| 68 |
+
"""Video generation works in mock mode."""
|
| 69 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 70 |
+
with patch('services.fal_service.api_client.MOCK_MODE', True):
|
| 71 |
+
from services.fal_service import FalService
|
| 72 |
+
|
| 73 |
+
service = FalService(api_key="test-key")
|
| 74 |
+
result = await service.start_video_generation(
|
| 75 |
+
base64_image="base64data",
|
| 76 |
+
mime_type="image/jpeg",
|
| 77 |
+
prompt="Animate this"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
assert result["done"] is True
|
| 81 |
+
assert result["status"] == "completed"
|
| 82 |
+
assert "video_url" in result
|
| 83 |
+
assert "fal_request_id" in result
|
| 84 |
+
|
| 85 |
+
@pytest.mark.asyncio
|
| 86 |
+
async def test_start_video_generation_success(self):
|
| 87 |
+
"""Video generation returns video URL on success."""
|
| 88 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 89 |
+
with patch('services.fal_service.api_client.MOCK_MODE', False):
|
| 90 |
+
with patch('services.fal_service.api_client.asyncio.to_thread') as mock_to_thread:
|
| 91 |
+
from services.fal_service import FalService
|
| 92 |
+
|
| 93 |
+
# Mock fal_client response
|
| 94 |
+
mock_result = {
|
| 95 |
+
"video": {"url": "https://fal.ai/video.mp4"},
|
| 96 |
+
"request_id": "req-123"
|
| 97 |
+
}
|
| 98 |
+
mock_to_thread.return_value = mock_result
|
| 99 |
+
|
| 100 |
+
service = FalService(api_key="test-key")
|
| 101 |
+
result = await service.start_video_generation(
|
| 102 |
+
base64_image="base64data",
|
| 103 |
+
mime_type="image/jpeg",
|
| 104 |
+
prompt="Animate this"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
assert result["done"] is True
|
| 108 |
+
assert result["status"] == "completed"
|
| 109 |
+
assert result["video_url"] == "https://fal.ai/video.mp4"
|
| 110 |
+
|
| 111 |
+
@pytest.mark.asyncio
|
| 112 |
+
async def test_start_video_generation_no_video_url(self):
|
| 113 |
+
"""Video generation returns failed when no URL in response."""
|
| 114 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 115 |
+
with patch('services.fal_service.api_client.MOCK_MODE', False):
|
| 116 |
+
with patch('services.fal_service.api_client.asyncio.to_thread') as mock_to_thread:
|
| 117 |
+
from services.fal_service import FalService
|
| 118 |
+
|
| 119 |
+
# Mock response without video URL
|
| 120 |
+
mock_result = {"status": "error"}
|
| 121 |
+
mock_to_thread.return_value = mock_result
|
| 122 |
+
|
| 123 |
+
service = FalService(api_key="test-key")
|
| 124 |
+
result = await service.start_video_generation(
|
| 125 |
+
base64_image="base64data",
|
| 126 |
+
mime_type="image/jpeg",
|
| 127 |
+
prompt="Animate this"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
assert result["done"] is True
|
| 131 |
+
assert result["status"] == "failed"
|
| 132 |
+
assert "error" in result
|
| 133 |
+
|
| 134 |
+
@pytest.mark.asyncio
|
| 135 |
+
async def test_start_video_generation_with_params(self):
|
| 136 |
+
"""Video generation passes aspect_ratio and resolution."""
|
| 137 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 138 |
+
with patch('services.fal_service.api_client.MOCK_MODE', False):
|
| 139 |
+
with patch('services.fal_service.api_client.asyncio.to_thread') as mock_to_thread:
|
| 140 |
+
from services.fal_service import FalService
|
| 141 |
+
|
| 142 |
+
mock_result = {"video": {"url": "https://fal.ai/video.mp4"}}
|
| 143 |
+
mock_to_thread.return_value = mock_result
|
| 144 |
+
|
| 145 |
+
service = FalService(api_key="test-key")
|
| 146 |
+
await service.start_video_generation(
|
| 147 |
+
base64_image="base64data",
|
| 148 |
+
mime_type="image/jpeg",
|
| 149 |
+
prompt="Animate",
|
| 150 |
+
aspect_ratio="9:16",
|
| 151 |
+
resolution="720p"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Verify arguments were passed
|
| 155 |
+
call_args = mock_to_thread.call_args
|
| 156 |
+
arguments = call_args.kwargs.get("arguments") or call_args[1].get("arguments")
|
| 157 |
+
assert arguments["aspect_ratio"] == "9:16"
|
| 158 |
+
assert arguments["resolution"] == "720p"
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# =============================================================================
|
| 162 |
+
# 3. Error Handling Tests
|
| 163 |
+
# =============================================================================
|
| 164 |
+
|
| 165 |
+
class TestFalErrorHandling:
|
| 166 |
+
"""Test error handling methods."""
|
| 167 |
+
|
| 168 |
+
def test_handle_api_error_401(self):
|
| 169 |
+
"""_handle_api_error raises ValueError for 401."""
|
| 170 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 171 |
+
from services.fal_service import FalService
|
| 172 |
+
|
| 173 |
+
service = FalService(api_key="test-key")
|
| 174 |
+
|
| 175 |
+
with pytest.raises(ValueError, match="Authentication failed"):
|
| 176 |
+
service._handle_api_error(Exception("401 Unauthorized"), "test")
|
| 177 |
+
|
| 178 |
+
def test_handle_api_error_402(self):
|
| 179 |
+
"""_handle_api_error raises ValueError for 402."""
|
| 180 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 181 |
+
from services.fal_service import FalService
|
| 182 |
+
|
| 183 |
+
service = FalService(api_key="test-key")
|
| 184 |
+
|
| 185 |
+
with pytest.raises(ValueError, match="Insufficient credits"):
|
| 186 |
+
service._handle_api_error(Exception("402 Payment Required"), "test")
|
| 187 |
+
|
| 188 |
+
def test_handle_api_error_429(self):
|
| 189 |
+
"""_handle_api_error raises ValueError for 429."""
|
| 190 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 191 |
+
from services.fal_service import FalService
|
| 192 |
+
|
| 193 |
+
service = FalService(api_key="test-key")
|
| 194 |
+
|
| 195 |
+
with pytest.raises(ValueError, match="Rate limit"):
|
| 196 |
+
service._handle_api_error(Exception("429 Rate limit exceeded"), "test")
|
| 197 |
+
|
| 198 |
+
def test_handle_api_error_reraises_other(self):
|
| 199 |
+
"""_handle_api_error re-raises non-handled errors."""
|
| 200 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 201 |
+
from services.fal_service import FalService
|
| 202 |
+
|
| 203 |
+
service = FalService(api_key="test-key")
|
| 204 |
+
|
| 205 |
+
with pytest.raises(RuntimeError, match="Connection timeout"):
|
| 206 |
+
service._handle_api_error(RuntimeError("Connection timeout"), "test")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# =============================================================================
|
| 210 |
+
# 4. Video Download Tests
|
| 211 |
+
# =============================================================================
|
| 212 |
+
|
| 213 |
+
class TestFalVideoDownload:
|
| 214 |
+
"""Test download_video method."""
|
| 215 |
+
|
| 216 |
+
@pytest.mark.asyncio
|
| 217 |
+
async def test_download_video_saves_file(self):
|
| 218 |
+
"""download_video saves file and returns filename."""
|
| 219 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 220 |
+
from services.fal_service import FalService
|
| 221 |
+
|
| 222 |
+
# Mock httpx client at module level
|
| 223 |
+
with patch('httpx.AsyncClient') as mock_client:
|
| 224 |
+
mock_response = MagicMock()
|
| 225 |
+
mock_response.content = b"fake video data"
|
| 226 |
+
mock_response.raise_for_status = MagicMock()
|
| 227 |
+
|
| 228 |
+
mock_client_instance = AsyncMock()
|
| 229 |
+
mock_client_instance.get.return_value = mock_response
|
| 230 |
+
mock_client_instance.__aenter__.return_value = mock_client_instance
|
| 231 |
+
mock_client_instance.__aexit__.return_value = None
|
| 232 |
+
mock_client.return_value = mock_client_instance
|
| 233 |
+
|
| 234 |
+
# Mock file operations
|
| 235 |
+
with patch('services.fal_service.api_client.os.makedirs'):
|
| 236 |
+
mock_file = MagicMock()
|
| 237 |
+
with patch('builtins.open', MagicMock(return_value=mock_file)):
|
| 238 |
+
mock_file.__enter__ = MagicMock(return_value=mock_file)
|
| 239 |
+
mock_file.__exit__ = MagicMock(return_value=False)
|
| 240 |
+
|
| 241 |
+
service = FalService(api_key="test-key")
|
| 242 |
+
result = await service.download_video(
|
| 243 |
+
"https://fal.ai/video.mp4",
|
| 244 |
+
"test-req-123"
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
assert result == "test-req-123.mp4"
|
| 248 |
+
mock_file.write.assert_called_once_with(b"fake video data")
|
| 249 |
+
|
| 250 |
+
@pytest.mark.asyncio
|
| 251 |
+
async def test_download_video_http_error(self):
|
| 252 |
+
"""download_video raises error on HTTP failure."""
|
| 253 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 254 |
+
from services.fal_service import FalService
|
| 255 |
+
|
| 256 |
+
with patch('httpx.AsyncClient') as mock_client:
|
| 257 |
+
mock_client_instance = AsyncMock()
|
| 258 |
+
mock_client_instance.get.side_effect = Exception("Connection refused")
|
| 259 |
+
mock_client_instance.__aenter__.return_value = mock_client_instance
|
| 260 |
+
mock_client_instance.__aexit__.return_value = None
|
| 261 |
+
mock_client.return_value = mock_client_instance
|
| 262 |
+
|
| 263 |
+
service = FalService(api_key="test-key")
|
| 264 |
+
|
| 265 |
+
with pytest.raises(ValueError, match="Failed to download"):
|
| 266 |
+
await service.download_video(
|
| 267 |
+
"https://fal.ai/video.mp4",
|
| 268 |
+
"test-req-123"
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# =============================================================================
|
| 273 |
+
# 5. Check Status Tests
|
| 274 |
+
# =============================================================================
|
| 275 |
+
|
| 276 |
+
class TestFalCheckStatus:
|
| 277 |
+
"""Test check_video_status method."""
|
| 278 |
+
|
| 279 |
+
@pytest.mark.asyncio
|
| 280 |
+
async def test_check_status_returns_completed(self):
|
| 281 |
+
"""check_video_status returns completed (fal.ai is sync)."""
|
| 282 |
+
with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
|
| 283 |
+
from services.fal_service import FalService
|
| 284 |
+
|
| 285 |
+
service = FalService(api_key="test-key")
|
| 286 |
+
result = await service.check_video_status("req-123")
|
| 287 |
+
|
| 288 |
+
assert result["done"] is True
|
| 289 |
+
assert result["status"] == "completed"
|
| 290 |
+
assert result["fal_request_id"] == "req-123"
|
tests/test_gemini_service.py
DELETED
|
@@ -1,814 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Rigorous Tests for Gemini AI Service.
|
| 3 |
-
|
| 4 |
-
Tests cover:
|
| 5 |
-
1. Initialization & API key handling
|
| 6 |
-
2. Concurrency semaphores
|
| 7 |
-
3. Text generation
|
| 8 |
-
4. Animation prompt generation
|
| 9 |
-
5. Image analysis & editing
|
| 10 |
-
6. Video generation, status checking, downloading
|
| 11 |
-
7. Error handling
|
| 12 |
-
"""
|
| 13 |
-
import pytest
|
| 14 |
-
import asyncio
|
| 15 |
-
import os
|
| 16 |
-
import tempfile
|
| 17 |
-
from unittest.mock import patch, MagicMock, AsyncMock, PropertyMock
|
| 18 |
-
from datetime import datetime
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
# =============================================================================
|
| 22 |
-
# 1. Initialization & Configuration Tests
|
| 23 |
-
# =============================================================================
|
| 24 |
-
|
| 25 |
-
class TestGeminiServiceInit:
|
| 26 |
-
"""Test GeminiService initialization and configuration."""
|
| 27 |
-
|
| 28 |
-
def test_init_with_explicit_api_key(self):
|
| 29 |
-
"""Service initializes with explicit API key."""
|
| 30 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 31 |
-
from services.gemini_service import GeminiService
|
| 32 |
-
|
| 33 |
-
service = GeminiService(api_key="test-key-123")
|
| 34 |
-
|
| 35 |
-
assert service.api_key == "test-key-123"
|
| 36 |
-
mock_genai.Client.assert_called_once_with(api_key="test-key-123")
|
| 37 |
-
|
| 38 |
-
def test_init_with_env_fallback(self):
|
| 39 |
-
"""Service falls back to environment variable for API key."""
|
| 40 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 41 |
-
with patch.dict(os.environ, {"GEMINI_API_KEY": "env-key-456"}):
|
| 42 |
-
from services.gemini_service import GeminiService
|
| 43 |
-
|
| 44 |
-
service = GeminiService()
|
| 45 |
-
|
| 46 |
-
assert service.api_key == "env-key-456"
|
| 47 |
-
|
| 48 |
-
def test_init_fails_without_api_key(self):
|
| 49 |
-
"""Service raises error when no API key available."""
|
| 50 |
-
with patch.dict(os.environ, {}, clear=True):
|
| 51 |
-
# Remove GEMINI_API_KEY if present
|
| 52 |
-
os.environ.pop("GEMINI_API_KEY", None)
|
| 53 |
-
os.environ.pop("GEMINI_API_KEYS", None)
|
| 54 |
-
|
| 55 |
-
from services.gemini_service import get_gemini_api_key
|
| 56 |
-
|
| 57 |
-
with pytest.raises(ValueError, match="Server Authentication Error"):
|
| 58 |
-
get_gemini_api_key()
|
| 59 |
-
|
| 60 |
-
def test_models_dict_has_required_entries(self):
|
| 61 |
-
"""MODELS dictionary has all required model names."""
|
| 62 |
-
from services.gemini_service import MODELS
|
| 63 |
-
|
| 64 |
-
assert "text_generation" in MODELS
|
| 65 |
-
assert "image_edit" in MODELS
|
| 66 |
-
assert "video_generation" in MODELS
|
| 67 |
-
assert all(isinstance(v, str) for v in MODELS.values())
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
# =============================================================================
|
| 71 |
-
# 2. Semaphore Concurrency Tests
|
| 72 |
-
# =============================================================================
|
| 73 |
-
|
| 74 |
-
class TestSemaphoreConcurrency:
|
| 75 |
-
"""Test concurrency control via semaphores."""
|
| 76 |
-
|
| 77 |
-
def test_video_semaphore_respects_limit(self):
|
| 78 |
-
"""Video semaphore uses MAX_CONCURRENT_VIDEOS."""
|
| 79 |
-
# Reset global
|
| 80 |
-
import services.gemini_service as gs
|
| 81 |
-
gs._video_semaphore = None
|
| 82 |
-
|
| 83 |
-
with patch.object(gs, 'MAX_CONCURRENT_VIDEOS', 3):
|
| 84 |
-
gs._video_semaphore = None # Reset
|
| 85 |
-
sem = gs.get_video_semaphore()
|
| 86 |
-
# Semaphore internal value
|
| 87 |
-
assert sem._value == 3
|
| 88 |
-
|
| 89 |
-
def test_image_semaphore_respects_limit(self):
|
| 90 |
-
"""Image semaphore uses MAX_CONCURRENT_IMAGES."""
|
| 91 |
-
import services.gemini_service as gs
|
| 92 |
-
gs._image_semaphore = None
|
| 93 |
-
|
| 94 |
-
with patch.object(gs, 'MAX_CONCURRENT_IMAGES', 5):
|
| 95 |
-
gs._image_semaphore = None
|
| 96 |
-
sem = gs.get_image_semaphore()
|
| 97 |
-
assert sem._value == 5
|
| 98 |
-
|
| 99 |
-
def test_text_semaphore_respects_limit(self):
|
| 100 |
-
"""Text semaphore uses MAX_CONCURRENT_TEXT."""
|
| 101 |
-
import services.gemini_service as gs
|
| 102 |
-
gs._text_semaphore = None
|
| 103 |
-
|
| 104 |
-
with patch.object(gs, 'MAX_CONCURRENT_TEXT', 10):
|
| 105 |
-
gs._text_semaphore = None
|
| 106 |
-
sem = gs.get_text_semaphore()
|
| 107 |
-
assert sem._value == 10
|
| 108 |
-
|
| 109 |
-
def test_semaphores_are_singletons(self):
|
| 110 |
-
"""Calling get_*_semaphore multiple times returns same object."""
|
| 111 |
-
import services.gemini_service as gs
|
| 112 |
-
gs._video_semaphore = None
|
| 113 |
-
gs._image_semaphore = None
|
| 114 |
-
gs._text_semaphore = None
|
| 115 |
-
|
| 116 |
-
video1 = gs.get_video_semaphore()
|
| 117 |
-
video2 = gs.get_video_semaphore()
|
| 118 |
-
assert video1 is video2
|
| 119 |
-
|
| 120 |
-
image1 = gs.get_image_semaphore()
|
| 121 |
-
image2 = gs.get_image_semaphore()
|
| 122 |
-
assert image1 is image2
|
| 123 |
-
|
| 124 |
-
text1 = gs.get_text_semaphore()
|
| 125 |
-
text2 = gs.get_text_semaphore()
|
| 126 |
-
assert text1 is text2
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
# =============================================================================
|
| 130 |
-
# 3. Text Generation Tests
|
| 131 |
-
# =============================================================================
|
| 132 |
-
|
| 133 |
-
class TestTextGeneration:
|
| 134 |
-
"""Test generate_text method."""
|
| 135 |
-
|
| 136 |
-
@pytest.mark.asyncio
|
| 137 |
-
async def test_generate_text_success(self):
|
| 138 |
-
"""generate_text returns text on success."""
|
| 139 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 140 |
-
from services.gemini_service import GeminiService
|
| 141 |
-
|
| 142 |
-
# Mock response
|
| 143 |
-
mock_response = MagicMock()
|
| 144 |
-
mock_response.text = "Generated text response"
|
| 145 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 146 |
-
|
| 147 |
-
service = GeminiService(api_key="test-key")
|
| 148 |
-
result = await service.generate_text("Hello world")
|
| 149 |
-
|
| 150 |
-
assert result == "Generated text response"
|
| 151 |
-
|
| 152 |
-
@pytest.mark.asyncio
|
| 153 |
-
async def test_generate_text_with_custom_model(self):
|
| 154 |
-
"""generate_text uses custom model when provided."""
|
| 155 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 156 |
-
from services.gemini_service import GeminiService
|
| 157 |
-
|
| 158 |
-
mock_response = MagicMock()
|
| 159 |
-
mock_response.text = "Custom model response"
|
| 160 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 161 |
-
|
| 162 |
-
service = GeminiService(api_key="test-key")
|
| 163 |
-
result = await service.generate_text("Hello", model="custom-model")
|
| 164 |
-
|
| 165 |
-
# Verify custom model was used
|
| 166 |
-
call_args = mock_genai.Client.return_value.models.generate_content.call_args
|
| 167 |
-
assert call_args.kwargs.get('model') == "custom-model"
|
| 168 |
-
|
| 169 |
-
@pytest.mark.asyncio
|
| 170 |
-
async def test_generate_text_empty_response(self):
|
| 171 |
-
"""generate_text returns empty string for None response."""
|
| 172 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 173 |
-
from services.gemini_service import GeminiService
|
| 174 |
-
|
| 175 |
-
mock_response = MagicMock()
|
| 176 |
-
mock_response.text = None
|
| 177 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 178 |
-
|
| 179 |
-
service = GeminiService(api_key="test-key")
|
| 180 |
-
result = await service.generate_text("Hello")
|
| 181 |
-
|
| 182 |
-
assert result == ""
|
| 183 |
-
|
| 184 |
-
@pytest.mark.asyncio
|
| 185 |
-
async def test_generate_text_api_error_404(self):
|
| 186 |
-
"""generate_text raises ValueError for 404 error."""
|
| 187 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 188 |
-
from services.gemini_service import GeminiService
|
| 189 |
-
|
| 190 |
-
mock_genai.Client.return_value.models.generate_content.side_effect = Exception("404 NOT_FOUND")
|
| 191 |
-
|
| 192 |
-
service = GeminiService(api_key="test-key")
|
| 193 |
-
|
| 194 |
-
with pytest.raises(ValueError, match="Model not found"):
|
| 195 |
-
await service.generate_text("Hello")
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
# =============================================================================
|
| 199 |
-
# 4. Animation Prompt Tests
|
| 200 |
-
# =============================================================================
|
| 201 |
-
|
| 202 |
-
class TestAnimationPrompt:
|
| 203 |
-
"""Test generate_animation_prompt method."""
|
| 204 |
-
|
| 205 |
-
@pytest.mark.asyncio
|
| 206 |
-
async def test_generate_animation_prompt_default(self):
|
| 207 |
-
"""generate_animation_prompt uses default prompt."""
|
| 208 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 209 |
-
with patch('services.gemini_service.types'):
|
| 210 |
-
from services.gemini_service import GeminiService
|
| 211 |
-
|
| 212 |
-
mock_response = MagicMock()
|
| 213 |
-
mock_response.text = "Subtle zoom with camera pan"
|
| 214 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 215 |
-
|
| 216 |
-
service = GeminiService(api_key="test-key")
|
| 217 |
-
result = await service.generate_animation_prompt(
|
| 218 |
-
base64_image="base64data",
|
| 219 |
-
mime_type="image/jpeg"
|
| 220 |
-
)
|
| 221 |
-
|
| 222 |
-
assert result == "Subtle zoom with camera pan"
|
| 223 |
-
|
| 224 |
-
@pytest.mark.asyncio
|
| 225 |
-
async def test_generate_animation_prompt_custom(self):
|
| 226 |
-
"""generate_animation_prompt uses custom prompt when provided."""
|
| 227 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 228 |
-
with patch('services.gemini_service.types'):
|
| 229 |
-
from services.gemini_service import GeminiService
|
| 230 |
-
|
| 231 |
-
mock_response = MagicMock()
|
| 232 |
-
mock_response.text = "Custom animation"
|
| 233 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 234 |
-
|
| 235 |
-
service = GeminiService(api_key="test-key")
|
| 236 |
-
result = await service.generate_animation_prompt(
|
| 237 |
-
base64_image="base64data",
|
| 238 |
-
mime_type="image/jpeg",
|
| 239 |
-
custom_prompt="Make it dramatic"
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
assert result == "Custom animation"
|
| 243 |
-
|
| 244 |
-
@pytest.mark.asyncio
|
| 245 |
-
async def test_generate_animation_prompt_fallback(self):
|
| 246 |
-
"""generate_animation_prompt returns fallback on empty response."""
|
| 247 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 248 |
-
with patch('services.gemini_service.types'):
|
| 249 |
-
from services.gemini_service import GeminiService
|
| 250 |
-
|
| 251 |
-
mock_response = MagicMock()
|
| 252 |
-
mock_response.text = None
|
| 253 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 254 |
-
|
| 255 |
-
service = GeminiService(api_key="test-key")
|
| 256 |
-
result = await service.generate_animation_prompt(
|
| 257 |
-
base64_image="base64data",
|
| 258 |
-
mime_type="image/jpeg"
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
assert result == "Cinematic subtle movement"
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
# =============================================================================
|
| 265 |
-
# 5. Image Analysis Tests
|
| 266 |
-
# =============================================================================
|
| 267 |
-
|
| 268 |
-
class TestImageAnalysis:
|
| 269 |
-
"""Test analyze_image method."""
|
| 270 |
-
|
| 271 |
-
@pytest.mark.asyncio
|
| 272 |
-
async def test_analyze_image_success(self):
|
| 273 |
-
"""analyze_image returns analysis text."""
|
| 274 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 275 |
-
with patch('services.gemini_service.types'):
|
| 276 |
-
from services.gemini_service import GeminiService
|
| 277 |
-
|
| 278 |
-
mock_response = MagicMock()
|
| 279 |
-
mock_response.text = "This image shows a sunset over mountains"
|
| 280 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 281 |
-
|
| 282 |
-
service = GeminiService(api_key="test-key")
|
| 283 |
-
result = await service.analyze_image(
|
| 284 |
-
base64_image="base64data",
|
| 285 |
-
mime_type="image/jpeg",
|
| 286 |
-
prompt="Describe this image"
|
| 287 |
-
)
|
| 288 |
-
|
| 289 |
-
assert result == "This image shows a sunset over mountains"
|
| 290 |
-
|
| 291 |
-
@pytest.mark.asyncio
|
| 292 |
-
async def test_analyze_image_empty_response(self):
|
| 293 |
-
"""analyze_image returns empty string for None response."""
|
| 294 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 295 |
-
with patch('services.gemini_service.types'):
|
| 296 |
-
from services.gemini_service import GeminiService
|
| 297 |
-
|
| 298 |
-
mock_response = MagicMock()
|
| 299 |
-
mock_response.text = None
|
| 300 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 301 |
-
|
| 302 |
-
service = GeminiService(api_key="test-key")
|
| 303 |
-
result = await service.analyze_image(
|
| 304 |
-
base64_image="base64data",
|
| 305 |
-
mime_type="image/jpeg",
|
| 306 |
-
prompt="Describe"
|
| 307 |
-
)
|
| 308 |
-
|
| 309 |
-
assert result == ""
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
# =============================================================================
|
| 313 |
-
# 6. Image Editing Tests
|
| 314 |
-
# =============================================================================
|
| 315 |
-
|
| 316 |
-
class TestImageEditing:
|
| 317 |
-
"""Test edit_image method."""
|
| 318 |
-
|
| 319 |
-
@pytest.mark.asyncio
|
| 320 |
-
async def test_edit_image_returns_data_uri(self):
|
| 321 |
-
"""edit_image returns base64 data URI."""
|
| 322 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 323 |
-
from services.gemini_service import GeminiService
|
| 324 |
-
|
| 325 |
-
# Create mock response structure
|
| 326 |
-
mock_inline_data = MagicMock()
|
| 327 |
-
mock_inline_data.data = "base64imagedata"
|
| 328 |
-
mock_inline_data.mime_type = "image/png"
|
| 329 |
-
|
| 330 |
-
mock_part = MagicMock()
|
| 331 |
-
mock_part.inline_data = mock_inline_data
|
| 332 |
-
|
| 333 |
-
mock_content = MagicMock()
|
| 334 |
-
mock_content.parts = [mock_part]
|
| 335 |
-
|
| 336 |
-
mock_candidate = MagicMock()
|
| 337 |
-
mock_candidate.content = mock_content
|
| 338 |
-
|
| 339 |
-
mock_response = MagicMock()
|
| 340 |
-
mock_response.candidates = [mock_candidate]
|
| 341 |
-
|
| 342 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 343 |
-
|
| 344 |
-
service = GeminiService(api_key="test-key")
|
| 345 |
-
result = await service.edit_image(
|
| 346 |
-
base64_image="input-base64",
|
| 347 |
-
mime_type="image/jpeg",
|
| 348 |
-
prompt="Make it colorful"
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
assert result == "data:image/png;base64,base64imagedata"
|
| 352 |
-
|
| 353 |
-
@pytest.mark.asyncio
|
| 354 |
-
async def test_edit_image_no_candidates(self):
|
| 355 |
-
"""edit_image raises error when no candidates returned."""
|
| 356 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 357 |
-
from services.gemini_service import GeminiService
|
| 358 |
-
|
| 359 |
-
mock_response = MagicMock()
|
| 360 |
-
mock_response.candidates = []
|
| 361 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 362 |
-
|
| 363 |
-
service = GeminiService(api_key="test-key")
|
| 364 |
-
|
| 365 |
-
with pytest.raises(ValueError, match="No candidates returned"):
|
| 366 |
-
await service.edit_image(
|
| 367 |
-
base64_image="input-base64",
|
| 368 |
-
mime_type="image/jpeg",
|
| 369 |
-
prompt="Edit"
|
| 370 |
-
)
|
| 371 |
-
|
| 372 |
-
@pytest.mark.asyncio
|
| 373 |
-
async def test_edit_image_no_image_data(self):
|
| 374 |
-
"""edit_image raises error when no image data in parts."""
|
| 375 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 376 |
-
from services.gemini_service import GeminiService
|
| 377 |
-
|
| 378 |
-
# Part without inline_data
|
| 379 |
-
mock_part = MagicMock()
|
| 380 |
-
mock_part.inline_data = None
|
| 381 |
-
|
| 382 |
-
mock_content = MagicMock()
|
| 383 |
-
mock_content.parts = [mock_part]
|
| 384 |
-
|
| 385 |
-
mock_candidate = MagicMock()
|
| 386 |
-
mock_candidate.content = mock_content
|
| 387 |
-
|
| 388 |
-
mock_response = MagicMock()
|
| 389 |
-
mock_response.candidates = [mock_candidate]
|
| 390 |
-
|
| 391 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 392 |
-
|
| 393 |
-
service = GeminiService(api_key="test-key")
|
| 394 |
-
|
| 395 |
-
with pytest.raises(ValueError, match="No image data found"):
|
| 396 |
-
await service.edit_image(
|
| 397 |
-
base64_image="input-base64",
|
| 398 |
-
mime_type="image/jpeg",
|
| 399 |
-
prompt="Edit"
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
-
@pytest.mark.asyncio
|
| 403 |
-
async def test_edit_image_default_prompt(self):
|
| 404 |
-
"""edit_image uses default prompt when empty."""
|
| 405 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 406 |
-
with patch('services.gemini_service.types'):
|
| 407 |
-
from services.gemini_service import GeminiService
|
| 408 |
-
|
| 409 |
-
mock_inline_data = MagicMock()
|
| 410 |
-
mock_inline_data.data = "base64data"
|
| 411 |
-
mock_inline_data.mime_type = "image/png"
|
| 412 |
-
|
| 413 |
-
mock_part = MagicMock()
|
| 414 |
-
mock_part.inline_data = mock_inline_data
|
| 415 |
-
|
| 416 |
-
mock_content = MagicMock()
|
| 417 |
-
mock_content.parts = [mock_part]
|
| 418 |
-
|
| 419 |
-
mock_candidate = MagicMock()
|
| 420 |
-
mock_candidate.content = mock_content
|
| 421 |
-
|
| 422 |
-
mock_response = MagicMock()
|
| 423 |
-
mock_response.candidates = [mock_candidate]
|
| 424 |
-
|
| 425 |
-
mock_genai.Client.return_value.models.generate_content.return_value = mock_response
|
| 426 |
-
|
| 427 |
-
service = GeminiService(api_key="test-key")
|
| 428 |
-
result = await service.edit_image(
|
| 429 |
-
base64_image="input",
|
| 430 |
-
mime_type="image/jpeg",
|
| 431 |
-
prompt="" # Empty prompt
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
assert "data:" in result
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
# =============================================================================
|
| 438 |
-
# 7. Video Generation Tests
|
| 439 |
-
# =============================================================================
|
| 440 |
-
|
| 441 |
-
class TestVideoGeneration:
|
| 442 |
-
"""Test start_video_generation method."""
|
| 443 |
-
|
| 444 |
-
@pytest.mark.asyncio
|
| 445 |
-
async def test_start_video_returns_operation_dict(self):
|
| 446 |
-
"""start_video_generation returns operation dictionary."""
|
| 447 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 448 |
-
with patch('services.gemini_service.types'):
|
| 449 |
-
from services.gemini_service import GeminiService
|
| 450 |
-
|
| 451 |
-
mock_operation = MagicMock()
|
| 452 |
-
mock_operation.name = "operations/video-123"
|
| 453 |
-
mock_operation.done = False
|
| 454 |
-
|
| 455 |
-
mock_genai.Client.return_value.models.generate_videos.return_value = mock_operation
|
| 456 |
-
|
| 457 |
-
service = GeminiService(api_key="test-key")
|
| 458 |
-
result = await service.start_video_generation(
|
| 459 |
-
base64_image="base64data",
|
| 460 |
-
mime_type="image/jpeg",
|
| 461 |
-
prompt="Animate this"
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
assert result["gemini_operation_name"] == "operations/video-123"
|
| 465 |
-
assert result["done"] == False
|
| 466 |
-
assert result["status"] == "pending"
|
| 467 |
-
|
| 468 |
-
@pytest.mark.asyncio
|
| 469 |
-
async def test_start_video_completed_immediately(self):
|
| 470 |
-
"""start_video_generation returns completed when done=True."""
|
| 471 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 472 |
-
with patch('services.gemini_service.types'):
|
| 473 |
-
from services.gemini_service import GeminiService
|
| 474 |
-
|
| 475 |
-
mock_operation = MagicMock()
|
| 476 |
-
mock_operation.name = "operations/video-123"
|
| 477 |
-
mock_operation.done = True
|
| 478 |
-
|
| 479 |
-
mock_genai.Client.return_value.models.generate_videos.return_value = mock_operation
|
| 480 |
-
|
| 481 |
-
service = GeminiService(api_key="test-key")
|
| 482 |
-
result = await service.start_video_generation(
|
| 483 |
-
base64_image="base64data",
|
| 484 |
-
mime_type="image/jpeg",
|
| 485 |
-
prompt="Animate this"
|
| 486 |
-
)
|
| 487 |
-
|
| 488 |
-
assert result["status"] == "completed"
|
| 489 |
-
|
| 490 |
-
@pytest.mark.asyncio
|
| 491 |
-
async def test_start_video_with_params(self):
|
| 492 |
-
"""start_video_generation passes aspect_ratio and resolution."""
|
| 493 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 494 |
-
with patch('services.gemini_service.types'):
|
| 495 |
-
from services.gemini_service import GeminiService
|
| 496 |
-
|
| 497 |
-
mock_operation = MagicMock()
|
| 498 |
-
mock_operation.name = "operations/video-123"
|
| 499 |
-
mock_operation.done = False
|
| 500 |
-
|
| 501 |
-
mock_genai.Client.return_value.models.generate_videos.return_value = mock_operation
|
| 502 |
-
|
| 503 |
-
service = GeminiService(api_key="test-key")
|
| 504 |
-
await service.start_video_generation(
|
| 505 |
-
base64_image="base64data",
|
| 506 |
-
mime_type="image/jpeg",
|
| 507 |
-
prompt="Animate",
|
| 508 |
-
aspect_ratio="9:16",
|
| 509 |
-
resolution="1080p",
|
| 510 |
-
number_of_videos=2
|
| 511 |
-
)
|
| 512 |
-
|
| 513 |
-
# Verify config was passed
|
| 514 |
-
call_args = mock_genai.Client.return_value.models.generate_videos.call_args
|
| 515 |
-
assert call_args is not None
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
# =============================================================================
|
| 519 |
-
# 8. Video Status Checking Tests
|
| 520 |
-
# =============================================================================
|
| 521 |
-
|
| 522 |
-
class TestVideoStatusChecking:
|
| 523 |
-
"""Test check_video_status method."""
|
| 524 |
-
|
| 525 |
-
@pytest.mark.asyncio
|
| 526 |
-
async def test_check_status_pending(self):
|
| 527 |
-
"""check_video_status returns pending when not done."""
|
| 528 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 529 |
-
from services.gemini_service import GeminiService
|
| 530 |
-
|
| 531 |
-
mock_operation = MagicMock()
|
| 532 |
-
mock_operation.done = False
|
| 533 |
-
mock_operation.error = None
|
| 534 |
-
|
| 535 |
-
mock_genai.Client.return_value.operations.get.return_value = mock_operation
|
| 536 |
-
|
| 537 |
-
service = GeminiService(api_key="test-key")
|
| 538 |
-
result = await service.check_video_status("operations/video-123")
|
| 539 |
-
|
| 540 |
-
assert result["done"] == False
|
| 541 |
-
assert result["status"] == "pending"
|
| 542 |
-
|
| 543 |
-
@pytest.mark.asyncio
|
| 544 |
-
async def test_check_status_completed_with_url(self):
|
| 545 |
-
"""check_video_status returns completed with video URL."""
|
| 546 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 547 |
-
from services.gemini_service import GeminiService
|
| 548 |
-
|
| 549 |
-
# Build nested mock structure
|
| 550 |
-
mock_video = MagicMock()
|
| 551 |
-
mock_video.uri = "https://storage.googleapis.com/video.mp4"
|
| 552 |
-
|
| 553 |
-
mock_generated_video = MagicMock()
|
| 554 |
-
mock_generated_video.video = mock_video
|
| 555 |
-
|
| 556 |
-
mock_result = MagicMock()
|
| 557 |
-
mock_result.generated_videos = [mock_generated_video]
|
| 558 |
-
|
| 559 |
-
mock_operation = MagicMock()
|
| 560 |
-
mock_operation.done = True
|
| 561 |
-
mock_operation.error = None
|
| 562 |
-
mock_operation.result = mock_result
|
| 563 |
-
|
| 564 |
-
mock_genai.Client.return_value.operations.get.return_value = mock_operation
|
| 565 |
-
|
| 566 |
-
service = GeminiService(api_key="test-api-key")
|
| 567 |
-
result = await service.check_video_status("operations/video-123")
|
| 568 |
-
|
| 569 |
-
assert result["done"] == True
|
| 570 |
-
assert result["status"] == "completed"
|
| 571 |
-
assert "video_url" in result
|
| 572 |
-
assert "test-api-key" in result["video_url"] # API key appended
|
| 573 |
-
|
| 574 |
-
@pytest.mark.asyncio
|
| 575 |
-
async def test_check_status_operation_error(self):
|
| 576 |
-
"""check_video_status returns failed on operation error."""
|
| 577 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 578 |
-
from services.gemini_service import GeminiService
|
| 579 |
-
|
| 580 |
-
mock_error = MagicMock()
|
| 581 |
-
mock_error.message = "Content blocked by policy"
|
| 582 |
-
|
| 583 |
-
mock_operation = MagicMock()
|
| 584 |
-
mock_operation.done = True
|
| 585 |
-
mock_operation.error = mock_error
|
| 586 |
-
|
| 587 |
-
mock_genai.Client.return_value.operations.get.return_value = mock_operation
|
| 588 |
-
|
| 589 |
-
service = GeminiService(api_key="test-key")
|
| 590 |
-
result = await service.check_video_status("operations/video-123")
|
| 591 |
-
|
| 592 |
-
assert result["done"] == True
|
| 593 |
-
assert result["status"] == "failed"
|
| 594 |
-
assert "error" in result
|
| 595 |
-
|
| 596 |
-
@pytest.mark.asyncio
|
| 597 |
-
async def test_check_status_404_expired(self):
|
| 598 |
-
"""check_video_status handles 404 for expired operation."""
|
| 599 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 600 |
-
from services.gemini_service import GeminiService
|
| 601 |
-
|
| 602 |
-
mock_genai.Client.return_value.operations.get.side_effect = Exception("404 NOT_FOUND")
|
| 603 |
-
|
| 604 |
-
service = GeminiService(api_key="test-key")
|
| 605 |
-
result = await service.check_video_status("operations/expired-123")
|
| 606 |
-
|
| 607 |
-
assert result["done"] == True
|
| 608 |
-
assert result["status"] == "failed"
|
| 609 |
-
assert "expired" in result["error"].lower()
|
| 610 |
-
|
| 611 |
-
@pytest.mark.asyncio
|
| 612 |
-
async def test_check_status_no_video_uri(self):
|
| 613 |
-
"""check_video_status returns failed when no video URI."""
|
| 614 |
-
with patch('services.gemini_service.genai') as mock_genai:
|
| 615 |
-
from services.gemini_service import GeminiService
|
| 616 |
-
|
| 617 |
-
mock_result = MagicMock()
|
| 618 |
-
mock_result.generated_videos = [] # Empty
|
| 619 |
-
|
| 620 |
-
mock_operation = MagicMock()
|
| 621 |
-
mock_operation.done = True
|
| 622 |
-
mock_operation.error = None
|
| 623 |
-
mock_operation.result = mock_result
|
| 624 |
-
|
| 625 |
-
mock_genai.Client.return_value.operations.get.return_value = mock_operation
|
| 626 |
-
|
| 627 |
-
service = GeminiService(api_key="test-key")
|
| 628 |
-
result = await service.check_video_status("operations/video-123")
|
| 629 |
-
|
| 630 |
-
assert result["status"] == "failed"
|
| 631 |
-
assert "safety filters" in result["error"].lower()
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
# =============================================================================
|
| 635 |
-
# 9. Video Download Tests
|
| 636 |
-
# =============================================================================
|
| 637 |
-
|
| 638 |
-
class TestVideoDownload:
|
| 639 |
-
"""Test download_video method."""
|
| 640 |
-
|
| 641 |
-
@pytest.mark.asyncio
|
| 642 |
-
async def test_download_video_saves_file(self):
|
| 643 |
-
"""download_video saves file and returns filename."""
|
| 644 |
-
with patch('services.gemini_service.genai'):
|
| 645 |
-
from services.gemini_service import GeminiService, DOWNLOADS_DIR
|
| 646 |
-
|
| 647 |
-
with patch('httpx.AsyncClient') as mock_client:
|
| 648 |
-
mock_response = MagicMock()
|
| 649 |
-
mock_response.content = b"fake video data"
|
| 650 |
-
mock_response.raise_for_status = MagicMock()
|
| 651 |
-
|
| 652 |
-
mock_client_instance = AsyncMock()
|
| 653 |
-
mock_client_instance.get.return_value = mock_response
|
| 654 |
-
mock_client_instance.__aenter__.return_value = mock_client_instance
|
| 655 |
-
mock_client_instance.__aexit__.return_value = None
|
| 656 |
-
mock_client.return_value = mock_client_instance
|
| 657 |
-
|
| 658 |
-
service = GeminiService(api_key="test-key")
|
| 659 |
-
|
| 660 |
-
# Use temp directory
|
| 661 |
-
with tempfile.TemporaryDirectory() as temp_dir:
|
| 662 |
-
with patch.object(
|
| 663 |
-
__import__('services.gemini_service', fromlist=['DOWNLOADS_DIR']),
|
| 664 |
-
'DOWNLOADS_DIR',
|
| 665 |
-
temp_dir
|
| 666 |
-
):
|
| 667 |
-
result = await service.download_video(
|
| 668 |
-
"https://example.com/video.mp4",
|
| 669 |
-
"test-op-123"
|
| 670 |
-
)
|
| 671 |
-
|
| 672 |
-
assert result == "test-op-123.mp4"
|
| 673 |
-
|
| 674 |
-
@pytest.mark.asyncio
|
| 675 |
-
async def test_download_video_http_error(self):
|
| 676 |
-
"""download_video raises error on HTTP failure."""
|
| 677 |
-
with patch('services.gemini_service.genai'):
|
| 678 |
-
from services.gemini_service import GeminiService
|
| 679 |
-
|
| 680 |
-
with patch('httpx.AsyncClient') as mock_client:
|
| 681 |
-
mock_client_instance = AsyncMock()
|
| 682 |
-
mock_client_instance.get.side_effect = Exception("Connection refused")
|
| 683 |
-
mock_client_instance.__aenter__.return_value = mock_client_instance
|
| 684 |
-
mock_client_instance.__aexit__.return_value = None
|
| 685 |
-
mock_client.return_value = mock_client_instance
|
| 686 |
-
|
| 687 |
-
service = GeminiService(api_key="test-key")
|
| 688 |
-
|
| 689 |
-
with pytest.raises(ValueError, match="Failed to download"):
|
| 690 |
-
await service.download_video(
|
| 691 |
-
"https://example.com/video.mp4",
|
| 692 |
-
"test-op-123"
|
| 693 |
-
)
|
| 694 |
-
|
| 695 |
-
@pytest.mark.asyncio
|
| 696 |
-
async def test_download_video_follows_redirects(self):
|
| 697 |
-
"""download_video client is configured to follow redirects."""
|
| 698 |
-
with patch('services.gemini_service.genai'):
|
| 699 |
-
from services.gemini_service import GeminiService
|
| 700 |
-
|
| 701 |
-
with patch('httpx.AsyncClient') as mock_client:
|
| 702 |
-
mock_response = MagicMock()
|
| 703 |
-
mock_response.content = b"video data"
|
| 704 |
-
mock_response.raise_for_status = MagicMock()
|
| 705 |
-
|
| 706 |
-
mock_client_instance = AsyncMock()
|
| 707 |
-
mock_client_instance.get.return_value = mock_response
|
| 708 |
-
mock_client_instance.__aenter__.return_value = mock_client_instance
|
| 709 |
-
mock_client_instance.__aexit__.return_value = None
|
| 710 |
-
mock_client.return_value = mock_client_instance
|
| 711 |
-
|
| 712 |
-
service = GeminiService(api_key="test-key")
|
| 713 |
-
|
| 714 |
-
with tempfile.TemporaryDirectory() as temp_dir:
|
| 715 |
-
with patch('services.gemini_service.DOWNLOADS_DIR', temp_dir):
|
| 716 |
-
await service.download_video(
|
| 717 |
-
"https://example.com/video.mp4",
|
| 718 |
-
"redirect-test"
|
| 719 |
-
)
|
| 720 |
-
|
| 721 |
-
# Verify follow_redirects=True was passed
|
| 722 |
-
mock_client.assert_called_with(timeout=120.0, follow_redirects=True)
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
# =============================================================================
|
| 726 |
-
# 10. Error Handling Tests
|
| 727 |
-
# =============================================================================
|
| 728 |
-
|
| 729 |
-
class TestErrorHandling:
|
| 730 |
-
"""Test _handle_api_error method."""
|
| 731 |
-
|
| 732 |
-
def test_handle_api_error_404(self):
|
| 733 |
-
"""_handle_api_error raises ValueError for 404."""
|
| 734 |
-
with patch('services.gemini_service.genai'):
|
| 735 |
-
from services.gemini_service import GeminiService
|
| 736 |
-
|
| 737 |
-
service = GeminiService(api_key="test-key")
|
| 738 |
-
|
| 739 |
-
with pytest.raises(ValueError, match="Model not found"):
|
| 740 |
-
service._handle_api_error(Exception("Error 404"), "test-model")
|
| 741 |
-
|
| 742 |
-
def test_handle_api_error_not_found(self):
|
| 743 |
-
"""_handle_api_error handles NOT_FOUND in message."""
|
| 744 |
-
with patch('services.gemini_service.genai'):
|
| 745 |
-
from services.gemini_service import GeminiService
|
| 746 |
-
|
| 747 |
-
service = GeminiService(api_key="test-key")
|
| 748 |
-
|
| 749 |
-
with pytest.raises(ValueError, match="Model not found"):
|
| 750 |
-
service._handle_api_error(Exception("NOT_FOUND: resource"), "test-model")
|
| 751 |
-
|
| 752 |
-
def test_handle_api_error_entity_not_found(self):
|
| 753 |
-
"""_handle_api_error handles 'Requested entity was not found'."""
|
| 754 |
-
with patch('services.gemini_service.genai'):
|
| 755 |
-
from services.gemini_service import GeminiService
|
| 756 |
-
|
| 757 |
-
service = GeminiService(api_key="test-key")
|
| 758 |
-
|
| 759 |
-
with pytest.raises(ValueError, match="Model not found"):
|
| 760 |
-
service._handle_api_error(
|
| 761 |
-
Exception("Requested entity was not found"),
|
| 762 |
-
"test-model"
|
| 763 |
-
)
|
| 764 |
-
|
| 765 |
-
def test_handle_api_error_bracket_5_pattern(self):
|
| 766 |
-
"""_handle_api_error handles [5, pattern."""
|
| 767 |
-
with patch('services.gemini_service.genai'):
|
| 768 |
-
from services.gemini_service import GeminiService
|
| 769 |
-
|
| 770 |
-
service = GeminiService(api_key="test-key")
|
| 771 |
-
|
| 772 |
-
with pytest.raises(ValueError, match="Model not found"):
|
| 773 |
-
service._handle_api_error(
|
| 774 |
-
Exception("Response [5, 'NOT_FOUND']"),
|
| 775 |
-
"test-model"
|
| 776 |
-
)
|
| 777 |
-
|
| 778 |
-
def test_handle_api_error_reraises_other(self):
|
| 779 |
-
"""_handle_api_error re-raises non-404 errors."""
|
| 780 |
-
with patch('services.gemini_service.genai'):
|
| 781 |
-
from services.gemini_service import GeminiService
|
| 782 |
-
|
| 783 |
-
service = GeminiService(api_key="test-key")
|
| 784 |
-
|
| 785 |
-
with pytest.raises(RuntimeError, match="Connection timeout"):
|
| 786 |
-
service._handle_api_error(
|
| 787 |
-
RuntimeError("Connection timeout"),
|
| 788 |
-
"test-model"
|
| 789 |
-
)
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
# =============================================================================
|
| 793 |
-
# 11. Downloads Directory Tests
|
| 794 |
-
# =============================================================================
|
| 795 |
-
|
| 796 |
-
class TestDownloadsDirectory:
|
| 797 |
-
"""Test downloads directory handling."""
|
| 798 |
-
|
| 799 |
-
def test_downloads_dir_exists(self):
|
| 800 |
-
"""DOWNLOADS_DIR is created on module import."""
|
| 801 |
-
from services.gemini_service import DOWNLOADS_DIR
|
| 802 |
-
|
| 803 |
-
assert os.path.exists(DOWNLOADS_DIR)
|
| 804 |
-
assert os.path.isdir(DOWNLOADS_DIR)
|
| 805 |
-
|
| 806 |
-
def test_downloads_dir_is_in_project(self):
|
| 807 |
-
"""DOWNLOADS_DIR is within project directory."""
|
| 808 |
-
from services.gemini_service import DOWNLOADS_DIR
|
| 809 |
-
|
| 810 |
-
assert "downloads" in DOWNLOADS_DIR
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
if __name__ == "__main__":
|
| 814 |
-
pytest.main([__file__, "-v"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/{test_worker_pool.py → test_worker_pool.py.archived}
RENAMED
|
File without changes
|