jebin2 commited on
Commit
693e4e3
·
1 Parent(s): 2dbfc89

refactor: remove worker pool, use direct fal.ai API calls

Browse files

- Remove services/worker/ and priority_worker_pool.py
- Remove services/gemini_service/ (replaced by fal_service)
- Update gemini.py to call FalService directly
- Use fal.ai submit() for non-blocking job start
- Use fal.ai status() for non-blocking status check
- Remove worker startup/shutdown from app.py
- All API responses now immediate, no background processing

.env.example CHANGED
@@ -97,6 +97,21 @@ JOB_PER_API_KEY=2
97
  # Enable mock mode for testing without consuming API credits
98
  # GEMINI_MOCK_MODE=true
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # -----------------------------------------------------------------------------
101
  # Email Configuration (Optional)
102
  # -----------------------------------------------------------------------------
 
97
  # Enable mock mode for testing without consuming API credits
98
  # GEMINI_MOCK_MODE=true
99
 
100
+ # -----------------------------------------------------------------------------
101
+ # AI Provider Configuration
102
+ # -----------------------------------------------------------------------------
103
+ # Which AI provider to use for video generation
104
+ # Options: "google" (Gemini/Veo) or "fal" (fal.ai)
105
+ AI_PROVIDER=google
106
+
107
+ # Fal.ai API Key (required if AI_PROVIDER=fal)
108
+ # Get from: https://fal.ai/dashboard/keys
109
+ # Note: fal_client expects this env var to be named FAL_KEY
110
+ # FAL_KEY=your-fal-api-key
111
+
112
+ # Enable mock mode for fal.ai testing without consuming API credits
113
+ # FAL_MOCK_MODE=true
114
+
115
  # -----------------------------------------------------------------------------
116
  # Email Configuration (Optional)
117
  # -----------------------------------------------------------------------------
app.py CHANGED
@@ -159,24 +159,12 @@ async def lifespan(app: FastAPI):
159
  )
160
  logger.info("✅ Audit Service configured")
161
 
162
- # Register API Key Service configuration
163
- from services.gemini_service import APIKeyServiceConfig
164
- APIKeyServiceConfig.register(
165
- rotation_strategy="least_used", # or "round_robin"
166
- cooldown_seconds=60, # Wait 1 min after quota error
167
- max_requests_per_minute=60,
168
- retry_on_quota_error=True # Auto-retry with different key
169
- )
170
- logger.info("✅ API Key Service configured")
171
-
172
- # Worker Pool Section
173
  logger.info("")
174
- logger.info("👷 [WORKER POOL]")
 
175
 
176
- # Start background job worker
177
- from services.gemini_service import start_worker, stop_worker
178
- await start_worker()
179
- logger.info("✅ Worker pool started")
180
 
181
  # Log CORS configuration
182
  allowed_origins = os.getenv("CORS_ORIGINS").split(",")
@@ -189,17 +177,15 @@ async def lifespan(app: FastAPI):
189
  logger.info("═" * 60)
190
  logger.info(" 🚀 API Gateway Ready")
191
  logger.info(" • Database: ✅ Ready")
192
- logger.info(" • Services: 5 initialized (DB, Auth, Credit, Audit, API Key)")
193
- logger.info(" • Workers: 15 active")
194
  logger.info(" • Endpoint: http://0.0.0.0:8000")
195
  logger.info("═" * 60)
196
  logger.info("")
197
 
198
  yield
199
 
200
- # Stop background job worker
201
- await stop_worker()
202
- logger.info("Background job worker stopped")
203
 
204
  # Shutdown: Upload DB to Drive
205
  logger.info("Shutdown: Uploading database to Google Drive...")
@@ -230,10 +216,6 @@ from services.audit_service import AuditMiddleware
230
  app.add_middleware(AuditMiddleware)
231
 
232
 
233
- from services.gemini_service import APIKeyMiddleware
234
- app.add_middleware(APIKeyMiddleware)
235
-
236
-
237
  from services.auth_service import AuthMiddleware
238
  app.add_middleware(AuthMiddleware)
239
 
 
159
  )
160
  logger.info("✅ Audit Service configured")
161
 
162
+ # Job Processing Info
 
 
 
 
 
 
 
 
 
 
163
  logger.info("")
164
+ logger.info(" [JOB PROCESSING]")
165
+ logger.info("✅ Using inline processor (fire-and-forget async)")
166
 
167
+
 
 
 
168
 
169
  # Log CORS configuration
170
  allowed_origins = os.getenv("CORS_ORIGINS").split(",")
 
177
  logger.info("═" * 60)
178
  logger.info(" 🚀 API Gateway Ready")
179
  logger.info(" • Database: ✅ Ready")
180
+ logger.info(" • Services: 4 initialized (DB, Auth, Credit, Audit)")
181
+ logger.info(" • Processing: Inline (no workers)")
182
  logger.info(" • Endpoint: http://0.0.0.0:8000")
183
  logger.info("═" * 60)
184
  logger.info("")
185
 
186
  yield
187
 
188
+ # No worker cleanup needed - inline processor uses fire-and-forget tasks
 
 
189
 
190
  # Shutdown: Upload DB to Drive
191
  logger.info("Shutdown: Uploading database to Google Drive...")
 
216
  app.add_middleware(AuditMiddleware)
217
 
218
 
 
 
 
 
219
  from services.auth_service import AuthMiddleware
220
  app.add_middleware(AuthMiddleware)
221
 
requirements.txt CHANGED
@@ -1,19 +1,18 @@
1
- # FastAPI URL Blink Application Dependencies
2
- fastapi>=0.104.0
3
- uvicorn[standard]>=0.24.0
4
- sqlalchemy>=2.0.0
5
- aiosqlite>=0.19.0
6
- cryptography>=41.0.0
7
- pydantic>=2.0.0
8
- httpx>=0.25.0
9
-
10
- passlib[bcrypt]>=1.7.4
11
- email-validator>=2.0.0
12
- python-dotenv>=1.0.0
13
- google-api-python-client>=2.0.0
14
- google-auth-oauthlib>=1.0.0
15
- google-auth-httplib2>=0.1.0
16
- google-genai>=1.0.0
17
- PyJWT>=2.8.0
18
- razorpay>=1.4.0
19
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.12
2
+ uvicorn[standard]==0.34.3
3
+ sqlalchemy==2.0.41
4
+ aiosqlite==0.21.0
5
+ cryptography==45.0.5
6
+ pydantic==2.11.7
7
+ httpx==0.28.1
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ passlib[bcrypt]==1.7.4
10
+ email-validator==2.2.0
11
+ python-dotenv==1.1.1
12
+ google-api-python-client==2.187.0
13
+ google-auth-oauthlib==1.2.1
14
+ google-auth-httplib2==0.2.0
15
+ google-genai==1.57.0
16
+ PyJWT==2.10.1
17
+ razorpay==2.0.0
18
+ fal-client==0.5.9
routers/gemini.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- Gemini Router - API endpoints for Gemini AI services.
 
3
  """
4
  import os
5
  import uuid
@@ -12,10 +13,10 @@ from sqlalchemy import select, func
12
 
13
  from core.database import get_db
14
  from core.models import User, GeminiJob
15
- from services.gemini_service import MODELS, DOWNLOADS_DIR
16
  from datetime import datetime
17
 
18
- router = APIRouter(prefix="/gemini", tags=["gemini"])
19
 
20
 
21
 
@@ -71,18 +72,15 @@ async def create_job(
71
  input_data: dict,
72
  credits_reserved: int = 0
73
  ) -> GeminiJob:
74
- """Create a new job in the queue."""
75
- from services.gemini_service.job_processor import get_priority_for_job_type, get_pool
76
-
77
  job_id = f"job_{uuid.uuid4().hex[:16]}"
78
- priority = get_priority_for_job_type(job_type)
79
 
80
  job = GeminiJob(
81
  job_id=job_id,
82
  user_id=user.id,
83
  job_type=job_type,
84
  status="queued",
85
- priority=priority,
86
  input_data=input_data,
87
  credits_reserved=credits_reserved
88
  )
@@ -90,8 +88,22 @@ async def create_job(
90
  await db.commit()
91
  await db.refresh(job)
92
 
93
-
94
- get_pool().notify_new_job(priority)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  return job
97
 
@@ -322,7 +334,7 @@ async def get_job_status(
322
  req: Request,
323
  db: AsyncSession = Depends(get_db)
324
  ):
325
- """Get job status and update if processing."""
326
  user = req.state.user
327
  query = select(GeminiJob).where(
328
  GeminiJob.job_id == job_id,
@@ -337,13 +349,30 @@ async def get_job_status(
337
  detail="Job not found"
338
  )
339
 
340
-
341
  if job.status == "processing" and job.job_type == "video" and job.third_party_id:
342
- from services.gemini_service.job_processor import GeminiJobProcessor
343
- processor = GeminiJobProcessor()
344
- job = await processor.check_status(job, db)
345
- await db.commit()
346
- await db.refresh(job)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  response = {
349
  "success": True,
@@ -353,23 +382,17 @@ async def get_job_status(
353
  "created_at": job.created_at.isoformat() if job.created_at else None,
354
  "credits_remaining": user.credits
355
  }
356
-
357
 
358
  if job.job_type == "video" and job.input_data:
359
  response["prompt"] = job.input_data.get("prompt")
360
 
361
- if job.status == "queued":
362
- response["position"] = await get_queue_position(db, job.job_id)
363
-
364
  if job.status == "processing":
365
  response["started_at"] = job.started_at.isoformat() if job.started_at else None
366
 
367
  if job.status == "completed":
368
  response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
369
-
370
-
371
- if job.output_data and "prompt" in job.output_data:
372
- response["prompt"] = job.output_data["prompt"]
373
 
374
  if job.status == "failed":
375
  response["error"] = job.error_message
 
1
  """
2
+ Video Router - API endpoints for AI video generation services.
3
+ Uses fal.ai for video generation.
4
  """
5
  import os
6
  import uuid
 
13
 
14
  from core.database import get_db
15
  from core.models import User, GeminiJob
16
+ from services.fal_service import MODELS
17
  from datetime import datetime
18
 
19
+ router = APIRouter(prefix="/gemini", tags=["video"])
20
 
21
 
22
 
 
72
  input_data: dict,
73
  credits_reserved: int = 0
74
  ) -> GeminiJob:
75
+ """Create job and start processing on fal.ai (non-blocking)."""
 
 
76
  job_id = f"job_{uuid.uuid4().hex[:16]}"
 
77
 
78
  job = GeminiJob(
79
  job_id=job_id,
80
  user_id=user.id,
81
  job_type=job_type,
82
  status="queued",
83
+ priority="slow",
84
  input_data=input_data,
85
  credits_reserved=credits_reserved
86
  )
 
88
  await db.commit()
89
  await db.refresh(job)
90
 
91
+ # Start fal.ai job immediately (non-blocking)
92
+ if job_type == "video":
93
+ from services.fal_service import FalService
94
+ fal = FalService()
95
+ result = await fal.start_video_generation(
96
+ base64_image=input_data.get("base64_image", ""),
97
+ mime_type=input_data.get("mime_type", "image/jpeg"),
98
+ prompt=input_data.get("prompt", "Animate this image with subtle movement."),
99
+ aspect_ratio=input_data.get("aspect_ratio", "16:9"),
100
+ resolution=input_data.get("resolution", "720p"),
101
+ )
102
+ job.status = "processing"
103
+ job.started_at = datetime.utcnow()
104
+ job.third_party_id = result.get("fal_request_id")
105
+ job.api_response = result
106
+ await db.commit()
107
 
108
  return job
109
 
 
334
  req: Request,
335
  db: AsyncSession = Depends(get_db)
336
  ):
337
+ """Get job status. Checks fal.ai if processing (non-blocking)."""
338
  user = req.state.user
339
  query = select(GeminiJob).where(
340
  GeminiJob.job_id == job_id,
 
349
  detail="Job not found"
350
  )
351
 
352
+ # Check fal.ai status if processing (non-blocking)
353
  if job.status == "processing" and job.job_type == "video" and job.third_party_id:
354
+ from services.fal_service import FalService
355
+ fal = FalService()
356
+ result = await fal.check_video_status(job.third_party_id)
357
+
358
+ if result.get("done"):
359
+ if result.get("status") == "completed":
360
+ job.status = "completed"
361
+ job.output_data = {"video_url": result.get("video_url")}
362
+ job.completed_at = datetime.utcnow()
363
+ else:
364
+ job.status = "failed"
365
+ job.error_message = result.get("error", "Unknown error")
366
+ job.completed_at = datetime.utcnow()
367
+
368
+ # Handle credits on completion
369
+ if job.credits_reserved > 0:
370
+ try:
371
+ from services.credit_service.credit_manager import handle_job_completion
372
+ await handle_job_completion(db, job)
373
+ except Exception:
374
+ pass
375
+ await db.commit()
376
 
377
  response = {
378
  "success": True,
 
382
  "created_at": job.created_at.isoformat() if job.created_at else None,
383
  "credits_remaining": user.credits
384
  }
 
385
 
386
  if job.job_type == "video" and job.input_data:
387
  response["prompt"] = job.input_data.get("prompt")
388
 
 
 
 
389
  if job.status == "processing":
390
  response["started_at"] = job.started_at.isoformat() if job.started_at else None
391
 
392
  if job.status == "completed":
393
  response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
394
+ if job.output_data and job.output_data.get("video_url"):
395
+ response["video_url"] = job.output_data.get("video_url")
 
 
396
 
397
  if job.status == "failed":
398
  response["error"] = job.error_message
services/fal_service/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fal.ai Service for video generation.
3
+ Provides access to Veo 3.1 and other models through fal.ai's unified API.
4
+ """
5
+ from services.fal_service.api_client import (
6
+ FalService,
7
+ MODELS,
8
+ MOCK_MODE,
9
+ get_fal_api_key,
10
+ )
11
+
12
+ __all__ = [
13
+ "FalService",
14
+ "MODELS",
15
+ "MOCK_MODE",
16
+ "get_fal_api_key",
17
+ ]
services/fal_service/api_client.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fal.ai Service for video generation.
3
+ Python implementation using fal-client SDK.
4
+ Uses server-side API key from environment.
5
+ """
6
+ import asyncio
7
+ import logging
8
+ import os
9
+ from typing import Optional, Literal
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Model names - easily configurable
14
+ MODELS = {
15
+ "video_generation": "fal-ai/veo3.1/fast/image-to-video"
16
+ }
17
+
18
+ # Type aliases
19
+ AspectRatio = Literal["16:9", "9:16", "auto"]
20
+ Resolution = Literal["720p", "1080p"]
21
+
22
+ # Mock mode for local testing (set FAL_MOCK_MODE=true to skip real API calls)
23
+ MOCK_MODE = os.getenv("FAL_MOCK_MODE", "false").lower() == "true"
24
+
25
+ # Sample video URL for mock mode
26
+ MOCK_VIDEO_URL = "https://v3b.fal.media/files/mock/mock-video.mp4"
27
+
28
+
29
+ def get_fal_api_key() -> str:
30
+ """Get Fal.ai API key from environment."""
31
+ api_key = os.getenv("FAL_KEY")
32
+ if not api_key:
33
+ raise ValueError("Server Authentication Error: FAL_KEY not configured")
34
+ return api_key
35
+
36
+
37
+ class FalService:
38
+ """
39
+ Fal.ai Service for video generation.
40
+ Uses server-side API key from environment (FAL_KEY).
41
+ """
42
+
43
+ def __init__(self, api_key: Optional[str] = None):
44
+ """Initialize the Fal.ai client with API key from env or provided."""
45
+ self.api_key = api_key or get_fal_api_key()
46
+ # fal_client reads from FAL_KEY env var by default
47
+ # Set it explicitly if a custom key is provided
48
+ if api_key:
49
+ os.environ["FAL_KEY"] = api_key
50
+
51
+ def _handle_api_error(self, error: Exception, context: str):
52
+ """Handle API errors with descriptive messages."""
53
+ msg = str(error)
54
+ if "401" in msg or "Unauthorized" in msg:
55
+ raise ValueError(
56
+ f"Authentication failed ({context}). Check your FAL_KEY is valid."
57
+ )
58
+ if "402" in msg or "Payment Required" in msg:
59
+ raise ValueError(
60
+ f"Insufficient credits ({context}). Add credits at fal.ai."
61
+ )
62
+ if "429" in msg or "Rate limit" in msg.lower():
63
+ raise ValueError(
64
+ f"Rate limit exceeded ({context}). Wait and retry."
65
+ )
66
+ raise error
67
+
68
+ async def start_video_generation(
69
+ self,
70
+ base64_image: str,
71
+ mime_type: str,
72
+ prompt: str,
73
+ aspect_ratio: AspectRatio = "16:9",
74
+ resolution: Resolution = "720p",
75
+ number_of_videos: int = 1
76
+ ) -> dict:
77
+ """
78
+ Start video generation using Fal.ai Veo 3.1 model.
79
+ Unlike Gemini, fal.ai subscribe() handles polling internally,
80
+ so this returns the completed video directly.
81
+
82
+ Returns dict with:
83
+ - fal_request_id: Request ID for reference
84
+ - done: Always True (fal.ai waits for completion)
85
+ - status: "completed" or "failed"
86
+ - video_url: URL to the generated video
87
+ """
88
+ # Mock mode for testing without API credits
89
+ if MOCK_MODE:
90
+ import uuid
91
+ mock_request_id = f"mock_fal_{uuid.uuid4().hex[:16]}"
92
+ logger.info(f"[MOCK MODE] Video generation: {mock_request_id}")
93
+ await asyncio.sleep(2) # Simulate API delay
94
+ return {
95
+ "fal_request_id": mock_request_id,
96
+ "done": True,
97
+ "status": "completed",
98
+ "video_url": MOCK_VIDEO_URL
99
+ }
100
+
101
+ try:
102
+ import fal_client
103
+
104
+ # Use submit() instead of subscribe() - returns immediately without waiting
105
+ # This starts the job and returns a request_id for status checking
106
+ handle = await asyncio.to_thread(
107
+ fal_client.submit,
108
+ MODELS["video_generation"],
109
+ arguments={
110
+ "prompt": prompt,
111
+ "image_url": f"data:{mime_type};base64,{base64_image}",
112
+ "aspect_ratio": aspect_ratio,
113
+ "resolution": resolution,
114
+ "generate_audio": True,
115
+ },
116
+ )
117
+
118
+ # Get the request ID from the handle
119
+ request_id = handle.request_id if hasattr(handle, 'request_id') else str(handle)
120
+
121
+ return {
122
+ "fal_request_id": request_id,
123
+ "done": False,
124
+ "status": "processing",
125
+ }
126
+
127
+ except ImportError:
128
+ raise ValueError(
129
+ "fal-client package not installed. Run: pip install fal-client"
130
+ )
131
+ except Exception as error:
132
+ self._handle_api_error(error, MODELS["video_generation"])
133
+
134
+ async def check_video_status(self, fal_request_id: str) -> dict:
135
+ """
136
+ Check the status of a video generation request.
137
+ Returns immediately with current status (does not wait).
138
+ """
139
+ # Mock mode for testing
140
+ if MOCK_MODE:
141
+ import random
142
+ # Simulate completion after a few checks
143
+ if random.random() > 0.7:
144
+ return {
145
+ "fal_request_id": fal_request_id,
146
+ "done": True,
147
+ "status": "completed",
148
+ "video_url": MOCK_VIDEO_URL
149
+ }
150
+ return {
151
+ "fal_request_id": fal_request_id,
152
+ "done": False,
153
+ "status": "processing"
154
+ }
155
+
156
+ try:
157
+ import fal_client
158
+
159
+ # Get status without waiting
160
+ status = await asyncio.to_thread(
161
+ fal_client.status,
162
+ MODELS["video_generation"],
163
+ fal_request_id,
164
+ with_logs=False
165
+ )
166
+
167
+ # Check if completed
168
+ if hasattr(status, 'status'):
169
+ if status.status == "COMPLETED":
170
+ # Get the result
171
+ result = await asyncio.to_thread(
172
+ fal_client.result,
173
+ MODELS["video_generation"],
174
+ fal_request_id
175
+ )
176
+
177
+ # Extract video URL
178
+ video_url = None
179
+ if isinstance(result, dict) and "video" in result:
180
+ video_url = result["video"].get("url")
181
+ elif hasattr(result, "video") and hasattr(result.video, "url"):
182
+ video_url = result.video.url
183
+
184
+ return {
185
+ "fal_request_id": fal_request_id,
186
+ "done": True,
187
+ "status": "completed",
188
+ "video_url": video_url
189
+ }
190
+ elif status.status == "FAILED":
191
+ return {
192
+ "fal_request_id": fal_request_id,
193
+ "done": True,
194
+ "status": "failed",
195
+ "error": getattr(status, 'error', 'Unknown error')
196
+ }
197
+ else:
198
+ # Still processing (IN_QUEUE, IN_PROGRESS)
199
+ return {
200
+ "fal_request_id": fal_request_id,
201
+ "done": False,
202
+ "status": "processing"
203
+ }
204
+
205
+ # Fallback - assume still processing
206
+ return {
207
+ "fal_request_id": fal_request_id,
208
+ "done": False,
209
+ "status": "processing"
210
+ }
211
+
212
+ except ImportError:
213
+ raise ValueError(
214
+ "fal-client package not installed. Run: pip install fal-client"
215
+ )
216
+ except Exception as error:
217
+ logger.error(f"Error checking status for {fal_request_id}: {error}")
218
+ return {
219
+ "fal_request_id": fal_request_id,
220
+ "done": False,
221
+ "status": "processing",
222
+ "error": str(error)
223
+ }
224
+
225
+ async def download_video(self, video_url: str, request_id: str) -> str:
226
+ """
227
+ Download video from fal.ai to local storage.
228
+ Returns the local filename.
229
+ """
230
+ import httpx
231
+
232
+ # Use same downloads directory as Gemini service
233
+ downloads_dir = os.path.join(
234
+ os.path.dirname(os.path.dirname(__file__)),
235
+ "downloads"
236
+ )
237
+ os.makedirs(downloads_dir, exist_ok=True)
238
+
239
+ filename = f"{request_id}.mp4"
240
+ filepath = os.path.join(downloads_dir, filename)
241
+
242
+ try:
243
+ async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client:
244
+ response = await client.get(video_url)
245
+ response.raise_for_status()
246
+
247
+ with open(filepath, 'wb') as f:
248
+ f.write(response.content)
249
+
250
+ logger.info(f"Downloaded video to {filepath}")
251
+ return filename
252
+ except Exception as e:
253
+ logger.error(f"Failed to download video: {e}")
254
+ raise ValueError(f"Failed to download video: {e}")
services/gemini_service/__init__.py DELETED
@@ -1,55 +0,0 @@
1
- """
2
- Gemini Service - AI-powered image and video generation
3
-
4
- Provides:
5
- - Text generation
6
- - Image editing
7
- - Video generation
8
- - Job processing and background workers
9
- """
10
-
11
- # API Client exports
12
- from services.gemini_service.api_client import (
13
- GeminiService,
14
- MODELS,
15
- DOWNLOADS_DIR,
16
- get_gemini_api_key,
17
- MOCK_MODE,
18
- MOCK_VIDEO_URL,
19
- )
20
-
21
- # Job Processor exports
22
- from services.gemini_service.job_processor import (
23
- GeminiJobProcessor,
24
- PriorityWorkerPool,
25
- get_pool,
26
- get_priority_for_job_type,
27
- start_worker,
28
- stop_worker,
29
- )
30
-
31
- # API Key Middleware exports # Added
32
- from services.gemini_service.api_key_config import APIKeyServiceConfig # Added
33
- from services.gemini_service.api_key_middleware import APIKeyMiddleware # Added
34
-
35
- __all__ = [
36
- # API Client
37
- 'GeminiService',
38
- 'MODELS',
39
- 'DOWNLOADS_DIR',
40
- 'get_gemini_api_key',
41
- 'MOCK_MODE',
42
- 'MOCK_VIDEO_URL',
43
-
44
- # Job Processor
45
- 'GeminiJobProcessor',
46
- 'PriorityWorkerPool',
47
- 'get_pool',
48
- 'get_priority_for_job_type',
49
- 'start_worker',
50
- 'stop_worker',
51
-
52
- # API Key Middleware
53
- 'APIKeyServiceConfig',
54
- 'APIKeyMiddleware',
55
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/gemini_service/api_client.py DELETED
@@ -1,401 +0,0 @@
1
- """
2
- Gemini AI Service for image and video generation.
3
- Python port of the TypeScript geminiService.ts
4
- Uses server-side API key from environment.
5
- """
6
- import asyncio
7
- import logging
8
- import os
9
- import uuid
10
- import httpx
11
- from typing import Optional, Literal
12
- from google import genai
13
- from google.genai import types
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
- # Model names - easily configurable
18
- MODELS = {
19
- "text_generation": "gemini-2.5-flash",
20
- "image_edit": "gemini-2.5-flash-image",
21
- "video_generation": "veo-3.1-generate-preview"
22
- }
23
-
24
- # Type aliases
25
- AspectRatio = Literal["16:9", "9:16"]
26
- Resolution = Literal["720p", "1080p"]
27
-
28
- # Video downloads directory
29
- DOWNLOADS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "downloads")
30
-
31
- # Ensure downloads directory exists
32
- os.makedirs(DOWNLOADS_DIR, exist_ok=True)
33
-
34
- # Mock mode for local testing (set GEMINI_MOCK_MODE=true to skip real API calls)
35
- MOCK_MODE = os.getenv("GEMINI_MOCK_MODE", "false").lower() == "true"
36
- MOCK_MODE_SLEEP_TIME = os.getenv("GEMINI_MOCK_MODE_SLEEP_TIME", "0.5")
37
-
38
- # Sample video URL for mock mode (a public test video)
39
- MOCK_VIDEO_URL = "https://video.twimg.com/amplify_video/1994083297756848128/vid/avc1/576x576/ue31qU0xts8L9tXD.mp4?tag=21"
40
-
41
- # Concurrency limits from environment (defaults)
42
- MAX_CONCURRENT_VIDEOS = int(os.getenv("MAX_CONCURRENT_VIDEOS", "2"))
43
- MAX_CONCURRENT_IMAGES = int(os.getenv("MAX_CONCURRENT_IMAGES", "5"))
44
- MAX_CONCURRENT_TEXT = int(os.getenv("MAX_CONCURRENT_TEXT", "10"))
45
-
46
- # Semaphores for concurrency control
47
- _video_semaphore: Optional[asyncio.Semaphore] = None
48
- _image_semaphore: Optional[asyncio.Semaphore] = None
49
- _text_semaphore: Optional[asyncio.Semaphore] = None
50
-
51
-
52
- def get_video_semaphore() -> asyncio.Semaphore:
53
- """Get or create video semaphore."""
54
- global _video_semaphore
55
- if _video_semaphore is None:
56
- _video_semaphore = asyncio.Semaphore(MAX_CONCURRENT_VIDEOS)
57
- logger.info(f"Video semaphore initialized with limit: {MAX_CONCURRENT_VIDEOS}")
58
- return _video_semaphore
59
-
60
-
61
- def get_image_semaphore() -> asyncio.Semaphore:
62
- """Get or create image semaphore."""
63
- global _image_semaphore
64
- if _image_semaphore is None:
65
- _image_semaphore = asyncio.Semaphore(MAX_CONCURRENT_IMAGES)
66
- logger.info(f"Image semaphore initialized with limit: {MAX_CONCURRENT_IMAGES}")
67
- return _image_semaphore
68
-
69
-
70
- def get_text_semaphore() -> asyncio.Semaphore:
71
- """Get or create text semaphore."""
72
- global _text_semaphore
73
- if _text_semaphore is None:
74
- _text_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TEXT)
75
- logger.info(f"Text semaphore initialized with limit: {MAX_CONCURRENT_TEXT}")
76
- return _text_semaphore
77
-
78
-
79
- def get_gemini_api_key() -> str:
80
- """Get Gemini API key from environment."""
81
- api_key = os.getenv("GEMINI_API_KEY")
82
- if not api_key:
83
- raise ValueError("Server Authentication Error with GEMINI")
84
- return api_key
85
-
86
-
87
- class GeminiService:
88
- """
89
- Gemini AI Service for text, image, and video generation.
90
- Uses server-side API key from environment.
91
- """
92
-
93
- def __init__(self, api_key: Optional[str] = None):
94
- """Initialize the Gemini client with API key from env or provided."""
95
- self.api_key = api_key or get_gemini_api_key()
96
- self.client = genai.Client(api_key=self.api_key)
97
-
98
- def _handle_api_error(self, error: Exception, context: str):
99
- """Handle API errors with descriptive messages."""
100
- msg = str(error)
101
- if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg or "[5," in msg:
102
- raise ValueError(
103
- f"Model not found ({context}). Ensure your API key project has access to this model. "
104
- "Veo requires a paid account."
105
- )
106
- raise error
107
-
108
- async def generate_animation_prompt(
109
- self,
110
- base64_image: str,
111
- mime_type: str,
112
- custom_prompt: Optional[str] = None
113
- ) -> str:
114
- """
115
- Analyzes the image to generate a suitable animation prompt.
116
- """
117
- # Mock mode for testing
118
- if MOCK_MODE:
119
- logger.info("[MOCK MODE] Generating animation prompt")
120
- await asyncio.sleep(GEMINI_MOCK_MODE_SLEEP_TIME) # Simulate API delay
121
- return "A gentle breeze rustles through the scene as soft light dances across the surface. The camera slowly zooms in with a subtle parallax effect, creating depth and movement."
122
-
123
- default_prompt = custom_prompt or "Describe how this image could be subtly animated with cinematic movement."
124
- async with get_text_semaphore():
125
- try:
126
- response = await asyncio.to_thread(
127
- self.client.models.generate_content,
128
- model=MODELS["text_generation"],
129
- contents=types.Content(
130
- parts=[
131
- types.Part.from_bytes(
132
- data=base64_image,
133
- mime_type=mime_type
134
- ),
135
- types.Part.from_text(text=default_prompt)
136
- ]
137
- )
138
- )
139
- return response.text or "Cinematic subtle movement"
140
- except Exception as error:
141
- self._handle_api_error(error, MODELS["text_generation"])
142
-
143
- async def edit_image(
144
- self,
145
- base64_image: str,
146
- mime_type: str,
147
- prompt: str
148
- ) -> str:
149
- """
150
- Edit an image using Gemini image model.
151
- Returns base64 data URI of the edited image.
152
- """
153
- # Mock mode for testing - return a sample image
154
- if MOCK_MODE:
155
- logger.info(f"[MOCK MODE] Editing image with prompt: {prompt}")
156
- await asyncio.sleep(1) # Simulate API delay
157
- # Return a small red placeholder image (1x1 pixel)
158
- return "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
159
-
160
- async with get_image_semaphore():
161
- try:
162
- response = await asyncio.to_thread(
163
- self.client.models.generate_content,
164
- model=MODELS["image_edit"],
165
- contents=types.Content(
166
- parts=[
167
- types.Part.from_bytes(
168
- data=base64_image,
169
- mime_type=mime_type
170
- ),
171
- types.Part.from_text(text=prompt or "Enhance this image")
172
- ]
173
- )
174
- )
175
-
176
- candidates = response.candidates
177
- if not candidates:
178
- raise ValueError("No candidates returned from Gemini.")
179
-
180
- for part in candidates[0].content.parts:
181
- if hasattr(part, 'inline_data') and part.inline_data and part.inline_data.data:
182
- result_mime = part.inline_data.mime_type or 'image/png'
183
- return f"data:{result_mime};base64,{part.inline_data.data}"
184
-
185
- raise ValueError("No image data found in the response.")
186
- except Exception as error:
187
- self._handle_api_error(error, MODELS["image_edit"])
188
-
189
- async def start_video_generation(
190
- self,
191
- base64_image: str,
192
- mime_type: str,
193
- prompt: str,
194
- aspect_ratio: AspectRatio = "16:9",
195
- resolution: Resolution = "720p",
196
- number_of_videos: int = 1
197
- ) -> dict:
198
- """
199
- Start video generation using Veo model.
200
- Returns operation details for polling.
201
- """
202
- # Mock mode for testing without API credits
203
- if MOCK_MODE:
204
- import uuid
205
- mock_operation_name = f"mock_operation_{uuid.uuid4().hex[:16]}"
206
- logger.info(f"[MOCK MODE] Starting video generation: {mock_operation_name}")
207
- return {
208
- "gemini_operation_name": mock_operation_name,
209
- "done": False,
210
- "status": "pending"
211
- }
212
-
213
- async with get_video_semaphore():
214
- try:
215
- # Start video generation
216
- operation = await asyncio.to_thread(
217
- self.client.models.generate_videos,
218
- model=MODELS["video_generation"],
219
- prompt=prompt,
220
- image=types.Image(
221
- image_bytes=base64_image,
222
- mime_type=mime_type
223
- ),
224
- config=types.GenerateVideosConfig(
225
- number_of_videos=number_of_videos,
226
- resolution=resolution,
227
- aspect_ratio=aspect_ratio
228
- )
229
- )
230
-
231
- # Return operation details
232
- return {
233
- "gemini_operation_name": operation.name,
234
- "done": operation.done,
235
- "status": "completed" if operation.done else "pending"
236
- }
237
- except Exception as error:
238
- self._handle_api_error(error, MODELS["video_generation"])
239
-
240
- async def check_video_status(self, gemini_operation_name: str) -> dict:
241
- """
242
- Check the status of a video generation operation.
243
- Returns status and video URL if complete.
244
- """
245
- # Mock mode for testing without API credits
246
- if MOCK_MODE:
247
- # Simulate processing time: complete after 2 checks (track via a simple mechanism)
248
- # For simplicity, always return completed with mock video URL
249
- logger.info(f"[MOCK MODE] Checking video status: {gemini_operation_name}")
250
- await asyncio.sleep(2) # Simulate API delay
251
- return {
252
- "gemini_operation_name": gemini_operation_name,
253
- "done": True,
254
- "status": "completed",
255
- "video_url": MOCK_VIDEO_URL
256
- }
257
-
258
- try:
259
- # Get operation status using the operation object
260
- # First, we need to recreate the operation from the name
261
- from google.genai.types import GenerateVideosOperation
262
-
263
- operation = await asyncio.to_thread(
264
- self.client.operations.get,
265
- GenerateVideosOperation(name=gemini_operation_name, done=False)
266
- )
267
-
268
- if not operation.done:
269
- return {
270
- "gemini_operation_name": gemini_operation_name,
271
- "done": False,
272
- "status": "pending"
273
- }
274
-
275
- # Check for error - handle both string and object types
276
- if operation.error:
277
- error_msg = operation.error
278
- if hasattr(operation.error, 'message'):
279
- error_msg = operation.error.message
280
- return {
281
- "gemini_operation_name": gemini_operation_name,
282
- "done": True,
283
- "status": "failed",
284
- "error": str(error_msg) or "Unknown error"
285
- }
286
-
287
- # Extract video URI from result
288
- result = operation.result
289
- if result and hasattr(result, 'generated_videos') and result.generated_videos:
290
- video = result.generated_videos[0]
291
- if hasattr(video, 'video') and video.video and hasattr(video.video, 'uri'):
292
- video_uri = video.video.uri
293
- return {
294
- "gemini_operation_name": gemini_operation_name,
295
- "done": True,
296
- "status": "completed",
297
- "video_url": f"{video_uri}&key={self.api_key}"
298
- }
299
-
300
- return {
301
- "gemini_operation_name": gemini_operation_name,
302
- "done": True,
303
- "status": "failed",
304
- "error": "No video URI returned. May be due to safety filters."
305
- }
306
-
307
- except Exception as error:
308
- msg = str(error)
309
- if "404" in msg or "NOT_FOUND" in msg or "Requested entity was not found" in msg:
310
- return {
311
- "gemini_operation_name": gemini_operation_name,
312
- "done": True,
313
- "status": "failed",
314
- "error": "Operation not found (404). It may have expired."
315
- }
316
- raise error
317
-
318
- async def download_video(self, video_url: str, operation_id: str) -> str:
319
- """
320
- Download video from Gemini to local storage.
321
- Returns the local filename.
322
- """
323
- filename = f"{operation_id}.mp4"
324
- filepath = os.path.join(DOWNLOADS_DIR, filename)
325
-
326
- try:
327
- # follow_redirects=True is required as Gemini returns 302 redirects
328
- async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client:
329
- response = await client.get(video_url)
330
- response.raise_for_status()
331
-
332
- with open(filepath, 'wb') as f:
333
- f.write(response.content)
334
-
335
- logger.info(f"Downloaded video to {filepath}")
336
- return filename
337
- except Exception as e:
338
- logger.error(f"Failed to download video: {e}")
339
- raise ValueError(f"Failed to download video: {e}")
340
-
341
- async def generate_text(
342
- self,
343
- prompt: str,
344
- model: Optional[str] = None
345
- ) -> str:
346
- """
347
- Simple text generation with Gemini.
348
- """
349
- # Mock mode for testing
350
- if MOCK_MODE:
351
- logger.info(f"[MOCK MODE] Generating text for prompt: {prompt[:50]}...")
352
- await asyncio.sleep(MOCK_MODE_SLEEP_TIME) # Simulate API delay
353
- return f"This is a mock response for your prompt: '{prompt[:100]}...'. In production, this would be generated by Gemini AI."
354
-
355
- model_name = model or MODELS["text_generation"]
356
- async with get_text_semaphore():
357
- try:
358
- response = await asyncio.to_thread(
359
- self.client.models.generate_content,
360
- model=model_name,
361
- contents=types.Content(
362
- parts=[types.Part.from_text(text=prompt)]
363
- )
364
- )
365
- return response.text or ""
366
- except Exception as error:
367
- self._handle_api_error(error, model_name)
368
-
369
- async def analyze_image(
370
- self,
371
- base64_image: str,
372
- mime_type: str,
373
- prompt: str
374
- ) -> str:
375
- """
376
- Analyze image with custom prompt.
377
- """
378
- # Mock mode for testing
379
- if MOCK_MODE:
380
- logger.info(f"[MOCK MODE] Analyzing image with prompt: {prompt[:50]}...")
381
- await asyncio.sleep(MOCK_MODE_SLEEP_TIME) # Simulate API delay
382
- return f"Mock analysis result: The image appears to show a scene that matches your query '{prompt[:50]}...'. This is placeholder content for testing."
383
-
384
- async with get_text_semaphore():
385
- try:
386
- response = await asyncio.to_thread(
387
- self.client.models.generate_content,
388
- model=MODELS["text_generation"],
389
- contents=types.Content(
390
- parts=[
391
- types.Part.from_bytes(
392
- data=base64_image,
393
- mime_type=mime_type
394
- ),
395
- types.Part.from_text(text=prompt)
396
- ]
397
- )
398
- )
399
- return response.text or ""
400
- except Exception as error:
401
- self._handle_api_error(error, MODELS["text_generation"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/gemini_service/api_key_config.py DELETED
@@ -1,100 +0,0 @@
1
- """
2
- API Key Service Configuration
3
-
4
- Configures automatic API key selection and rotation via middleware.
5
- """
6
- from typing import List, Optional
7
- import os
8
- import logging
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class APIKeyServiceConfig:
14
- """Configuration for API key middleware."""
15
-
16
- _rotation_strategy: str = "least_used" # or "round_robin"
17
- _cooldown_seconds: int = 60
18
- _max_requests_per_minute: int = 60
19
- _retry_on_quota_error: bool = True
20
- _api_keys: Optional[List[str]] = None
21
-
22
- @classmethod
23
- def register(
24
- cls,
25
- rotation_strategy: str = "least_used",
26
- cooldown_seconds: int = 60,
27
- max_requests_per_minute: int = 60,
28
- retry_on_quota_error: bool = True
29
- ) -> None:
30
- """
31
- Register API key service configuration.
32
-
33
- Args:
34
- rotation_strategy: "least_used" or "round_robin"
35
- cooldown_seconds: Time to wait before reusing a key after quota error
36
- max_requests_per_minute: Rate limit per key
37
- retry_on_quota_error: Auto-retry with different key on 429
38
-
39
- Example:
40
- APIKeyServiceConfig.register(
41
- rotation_strategy="least_used",
42
- cooldown_seconds=60,
43
- retry_on_quota_error=True
44
- )
45
- """
46
- cls._rotation_strategy = rotation_strategy
47
- cls._cooldown_seconds = cooldown_seconds
48
- cls._max_requests_per_minute = max_requests_per_minute
49
- cls._retry_on_quota_error = retry_on_quota_error
50
-
51
- # Load API keys from env
52
- cls._load_api_keys()
53
-
54
- logger.info(
55
- f"API Key Service configured: "
56
- f"keys={len(cls._api_keys or [])}, "
57
- f"strategy={rotation_strategy}, "
58
- f"retry={retry_on_quota_error}"
59
- )
60
-
61
- @classmethod
62
- def _load_api_keys(cls):
63
- """Load API keys from environment variables."""
64
- keys_str = os.getenv("GEMINI_API_KEYS", "")
65
- if not keys_str:
66
- # Fallback to single key
67
- single_key = os.getenv("GEMINI_API_KEY", "")
68
- if single_key:
69
- cls._api_keys = [single_key]
70
- else:
71
- cls._api_keys = []
72
- logger.warning("No Gemini API keys configured!")
73
- else:
74
- cls._api_keys = [k.strip() for k in keys_str.split(",") if k.strip()]
75
-
76
- if cls._api_keys:
77
- logger.info(f"Loaded {len(cls._api_keys)} Gemini API key(s)")
78
-
79
- @classmethod
80
- def get_api_keys(cls) -> List[str]:
81
- """Get loaded API keys."""
82
- if cls._api_keys is None:
83
- cls._load_api_keys()
84
- return cls._api_keys or []
85
-
86
- @classmethod
87
- def get_key_count(cls) -> int:
88
- """Get number of available keys."""
89
- return len(cls.get_api_keys())
90
-
91
- @classmethod
92
- def get_config(cls) -> dict:
93
- """Get current configuration."""
94
- return {
95
- "key_count": cls.get_key_count(),
96
- "rotation_strategy": cls._rotation_strategy,
97
- "cooldown_seconds": cls._cooldown_seconds,
98
- "max_requests_per_minute": cls._max_requests_per_minute,
99
- "retry_on_quota_error": cls._retry_on_quota_error
100
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/gemini_service/api_key_middleware.py DELETED
@@ -1,180 +0,0 @@
1
- """
2
- API Key Middleware - Automatic key selection and rotation
3
-
4
- Automatically selects and injects Gemini API keys for requests.
5
- Handles quota errors with automatic key rotation and retry.
6
- """
7
- import time
8
- import logging
9
- from datetime import datetime, timedelta
10
- from typing import Optional, Dict
11
- from fastapi import Request, Response
12
- from starlette.middleware.base import BaseHTTPMiddleware
13
- from starlette.types import ASGIApp
14
-
15
- from core.database import async_session_maker
16
- from services.gemini_service.api_key_config import APIKeyServiceConfig
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
-
21
- # Track key cooldowns in memory
22
- _key_cooldowns: Dict[int, datetime] = {}
23
-
24
-
25
- class APIKeyMiddleware(BaseHTTPMiddleware):
26
- """
27
- Middleware for automatic API key management.
28
-
29
- Features:
30
- - Automatic key selection based on strategy
31
- - Quota error detection and recovery
32
- - Key cooldown management
33
- - Usage tracking
34
- """
35
-
36
- def __init__(self, app: ASGIApp):
37
- super().__init__(app)
38
-
39
- async def dispatch(self, request: Request, call_next):
40
- """
41
- Process request with automatic API key injection.
42
-
43
- Flow:
44
- 1. Check if Gemini request
45
- 2. Select best available key
46
- 3. Inject into request state
47
- 4. Handle response (quota errors)
48
- """
49
- # Only handle Gemini requests
50
- if not self._is_gemini_request(request):
51
- return await call_next(request)
52
-
53
- # Select API key
54
- try:
55
- key_index, api_key = await self._select_api_key()
56
- request.state.gemini_api_key = api_key
57
- request.state.gemini_key_index = key_index
58
- except ValueError as e:
59
- # No keys available
60
- logger.error(f"No API keys available: {e}")
61
- return Response(
62
- content=f'{{"detail": "{str(e)}"}}',
63
- status_code=503,
64
- media_type="application/json"
65
- )
66
-
67
- # Process request
68
- response = await call_next(request)
69
-
70
- # Handle quota errors
71
- if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
72
- logger.warning(f"Quota error on key {key_index}, attempting retry")
73
-
74
- # Mark key in cooldown
75
- self._mark_cooldown(key_index)
76
-
77
- # Try to select different key
78
- try:
79
- key_index, api_key = await self._select_api_key(exclude_index=key_index)
80
- request.state.gemini_api_key = api_key
81
- request.state.gemini_key_index = key_index
82
-
83
- # Retry request
84
- logger.info(f"Retrying with key {key_index}")
85
- response = await call_next(request)
86
- except ValueError:
87
- # No other keys available
88
- logger.error("All API keys in cooldown or exhausted")
89
-
90
- # Track usage
91
- success = response.status_code < 400
92
- await self._track_usage(key_index, success, response.status_code)
93
-
94
- return response
95
-
96
- def _is_gemini_request(self, request: Request) -> bool:
97
- """Check if request is for Gemini service."""
98
- path = request.url.path
99
- gemini_paths = ["/gemini/", "/api/gemini"]
100
- return any(path.startswith(p) for p in gemini_paths)
101
-
102
- async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
103
- """
104
- Select best available API key.
105
-
106
- Args:
107
- exclude_index: Key index to exclude (e.g., after quota error)
108
-
109
- Returns:
110
- Tuple of (key_index, api_key)
111
-
112
- Raises:
113
- ValueError: If no keys available
114
- """
115
- keys = APIKeyServiceConfig.get_api_keys()
116
- if not keys:
117
- raise ValueError("No API keys configured")
118
-
119
- # Filter out excluded and cooldown keys
120
- available_indices = []
121
- for i in range(len(keys)):
122
- if i == exclude_index:
123
- continue
124
- if self._is_in_cooldown(i):
125
- continue
126
- available_indices.append(i)
127
-
128
- if not available_indices:
129
- raise ValueError("All API keys in cooldown")
130
-
131
- # Select based on strategy
132
- if APIKeyServiceConfig._rotation_strategy == "round_robin":
133
- # Simple round-robin
134
- selected_index = available_indices[0]
135
- else: # least_used
136
- # Get usage stats from DB
137
- async with async_session_maker() as db:
138
- from services.api_key_manager import get_least_used_key
139
- try:
140
- selected_index, _ = await get_least_used_key(db)
141
- if selected_index not in available_indices:
142
- # Fallback to first available
143
- selected_index = available_indices[0]
144
- except Exception as e:
145
- logger.error(f"Error getting least used key: {e}")
146
- selected_index = available_indices[0]
147
-
148
- logger.debug(f"Selected API key index {selected_index}")
149
- return selected_index, keys[selected_index]
150
-
151
- def _is_in_cooldown(self, key_index: int) -> bool:
152
- """Check if key is in cooldown period."""
153
- if key_index not in _key_cooldowns:
154
- return False
155
-
156
- cooldown_until = _key_cooldowns[key_index]
157
- if datetime.utcnow() > cooldown_until:
158
- # Cooldown expired
159
- del _key_cooldowns[key_index]
160
- return False
161
-
162
- return True
163
-
164
- def _mark_cooldown(self, key_index: int):
165
- """Mark key as in cooldown."""
166
- cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
167
- cooldown_until = datetime.utcnow() + timedelta(seconds=cooldown_seconds)
168
- _key_cooldowns[key_index] = cooldown_until
169
- logger.info(f"Key {key_index} in cooldown until {cooldown_until}")
170
-
171
- async def _track_usage(self, key_index: int, success: bool, status_code: int):
172
- """Track API key usage."""
173
- try:
174
- async with async_session_maker() as db:
175
- from services.api_key_manager import record_usage
176
- error_message = f"HTTP {status_code}" if not success else None
177
- await record_usage(db, key_index, success, error_message)
178
- await db.commit()
179
- except Exception as e:
180
- logger.error(f"Failed to track usage: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/gemini_service/job_processor.py DELETED
@@ -1,378 +0,0 @@
1
- """
2
- Gemini Job Worker - Specific implementation using the modular PriorityWorkerPool.
3
-
4
- This file shows how to use the modular PriorityWorkerPool with Gemini-specific
5
- job processing logic.
6
- """
7
- import logging
8
- from datetime import datetime, timedelta
9
- from typing import Optional
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
-
12
- from core.database import DATABASE_URL
13
- from core.models import GeminiJob
14
- from services.priority_worker_pool import (
15
- PriorityWorkerPool,
16
- JobProcessor,
17
- WorkerConfig,
18
- get_interval_for_priority
19
- )
20
- from services.gemini_service.api_client import GeminiService
21
- from services.drive_service import DriveService
22
- import asyncio
23
-
24
- logger = logging.getLogger(__name__)
25
-
26
- # Job type to priority mapping for Gemini jobs
27
- JOB_PRIORITY_MAP = {
28
- "text": "fast",
29
- "analyze": "fast",
30
- "animation_prompt": "fast",
31
- "image": "medium",
32
- "edit_image": "medium",
33
- "video": "slow"
34
- }
35
-
36
-
37
- def get_priority_for_job_type(job_type: str) -> str:
38
- """Get the priority tier for a Gemini job type."""
39
- return JOB_PRIORITY_MAP.get(job_type, "fast")
40
-
41
-
42
- class GeminiJobProcessor(JobProcessor[GeminiJob]):
43
- """Processes Gemini AI jobs (text, image, video generation) with round-robin API keys."""
44
-
45
- def __init__(self):
46
- self.drive_service = DriveService()
47
-
48
- async def _get_service_with_key(self, session: AsyncSession) -> tuple:
49
- """Get a GeminiService with the least-used API key."""
50
- from services.api_key_manager import get_least_used_key
51
- key_index, api_key = await get_least_used_key(session)
52
- return key_index, GeminiService(api_key=api_key)
53
-
54
- async def _record_usage(self, session: AsyncSession, key_index: int, success: bool, error_message: Optional[str] = None):
55
- """Record API key usage after request."""
56
- from services.api_key_manager import record_usage
57
- await record_usage(session, key_index, success, error_message)
58
-
59
- def _handle_error(self, job: GeminiJob, error: Exception, reset_to_queued: bool = False) -> tuple[bool, str]:
60
- """
61
- Handle job errors with retry logic.
62
-
63
- Args:
64
- job: The job object
65
- error: The exception raised
66
- reset_to_queued: Whether to reset status to 'queued' on retry (for process())
67
-
68
- Returns:
69
- Tuple of (success, error_message)
70
- success is False (since it's an error)
71
- error_message is the formatted error string
72
- """
73
- error_str = str(error)
74
- is_retryable = False
75
- log_msg = ""
76
-
77
- # Check for Rate Limit (429)
78
- if "429" in error_str or "ResourceExhausted" in error_str:
79
- is_retryable = True
80
- log_msg = f"Rate limit hit for job {job.job_id}"
81
-
82
- # Check for Auth/Billing errors (401, 403, API key not found, API key not valid, FAILED_PRECONDITION)
83
- elif "401" in error_str or "403" in error_str or "Unauthenticated" in error_str or "PermissionDenied" in error_str or "API key not found" in error_str or "API key not valid" in error_str or "FAILED_PRECONDITION" in error_str:
84
- is_retryable = True
85
- log_msg = f"Auth/Billing error for job {job.job_id}: {error_str}. Rescheduling to try different key."
86
-
87
- # Check for Server errors (500, 503, 504)
88
- elif "500" in error_str or "503" in error_str or "504" in error_str or "INTERNAL" in error_str or "UNAVAILABLE" in error_str or "DEADLINE_EXCEEDED" in error_str:
89
- is_retryable = True
90
- log_msg = f"Server error for job {job.job_id}: {error_str}"
91
-
92
- # Try to parse JSON error details if present
93
- try:
94
- import json
95
- import re
96
- # Look for JSON-like structure in error string
97
- json_match = re.search(r"(\{.*\})", error_str)
98
- if json_match:
99
- job.api_response = json.loads(json_match.group(1))
100
- else:
101
- job.api_response = {"error": error_str}
102
- except Exception:
103
- job.api_response = {"error": error_str}
104
-
105
- if is_retryable:
106
- logger.warning(f"{log_msg}. Rescheduling.")
107
- job.retry_count += 1
108
- config = WorkerConfig.from_env()
109
- # Use a longer delay for these errors (e.g., 30s)
110
- interval = 30
111
- job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
112
-
113
- if reset_to_queued:
114
- job.status = "queued"
115
-
116
- return False, f"Retryable error: {error_str}"
117
- else:
118
- logger.error(f"Error processing job {job.job_id}: {error}")
119
- job.status = "failed"
120
- job.error_message = str(error)
121
- job.completed_at = datetime.utcnow()
122
- return False, str(error)
123
-
124
- async def process(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
125
- """Start processing a new job with round-robin API key."""
126
- key_index, service = await self._get_service_with_key(session)
127
- input_data = job.input_data or {}
128
- success = False
129
- error_msg = None
130
-
131
- try:
132
- if job.job_type == "video":
133
- job = await self._start_video(job, session, service, input_data)
134
- success = True
135
- elif job.job_type == "image":
136
- job = await self._process_image(job, service, input_data)
137
- success = True
138
- elif job.job_type == "text":
139
- job = await self._process_text(job, service, input_data)
140
- success = True
141
- elif job.job_type == "analyze":
142
- job = await self._process_analyze(job, service, input_data)
143
- success = True
144
- elif job.job_type == "animation_prompt":
145
- job = await self._process_animation_prompt(job, service, input_data)
146
- success = True
147
- else:
148
- job.status = "failed"
149
- job.error_message = f"Unknown job type: {job.job_type}"
150
- job.completed_at = datetime.utcnow()
151
- error_msg = job.error_message
152
- except Exception as e:
153
- # Use helper for error handling
154
- # reset_to_queued=True because if we fail to start, we want to try starting again from scratch
155
- success, error_msg = self._handle_error(job, e, reset_to_queued=True)
156
-
157
- # Record usage
158
- await self._record_usage(session, key_index, success, error_msg)
159
-
160
- return job
161
-
162
- async def check_status(self, job: GeminiJob, session: AsyncSession) -> GeminiJob:
163
- """Check status of an in-progress job (video generation)."""
164
- if job.job_type != "video" or not job.third_party_id:
165
- job.status = "failed"
166
- job.error_message = "Invalid job state for status check"
167
- job.completed_at = datetime.utcnow()
168
- return job
169
-
170
- # Use round-robin key for status check
171
- key_index, service = await self._get_service_with_key(session)
172
- success = False
173
- error_msg = None
174
-
175
- try:
176
- status_result = await service.check_video_status(job.third_party_id)
177
- # Save raw response
178
- job.api_response = status_result
179
-
180
- if status_result.get("done"):
181
- if status_result.get("status") == "completed":
182
- video_url = status_result.get("video_url")
183
- if video_url:
184
- # Store video URL - download will happen on-demand when client requests
185
- job.status = "completed"
186
- job.output_data = {"video_url": video_url}
187
- job.error_message = None # Clear any previous error
188
- job.completed_at = datetime.utcnow()
189
- success = True
190
- # Sync DB on success
191
- from services.backup_service import get_backup_service
192
- backup_service = get_backup_service()
193
- await backup_service.backup_async()
194
- else:
195
- job.status = "failed"
196
- job.error_message = "No video URL returned"
197
- job.completed_at = datetime.utcnow()
198
- error_msg = job.error_message
199
- else:
200
- job.status = "failed"
201
- job.error_message = status_result.get("error", "Unknown error")
202
- job.completed_at = datetime.utcnow()
203
- error_msg = job.error_message
204
- else:
205
- # Not done - reschedule
206
- job.retry_count += 1
207
- config = WorkerConfig.from_env()
208
- interval = get_interval_for_priority(job.priority, config)
209
- job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
210
- success = True # Status check succeeded even if video not ready
211
-
212
- except Exception as e:
213
- # Use helper for error handling
214
- # reset_to_queued=False because we want to continue checking status, not restart
215
- success, error_msg = self._handle_error(job, e, reset_to_queued=False)
216
-
217
- # Record usage
218
- await self._record_usage(session, key_index, success, error_msg)
219
-
220
- return job
221
-
222
- # Record usage
223
- await self._record_usage(session, key_index, success, error_msg)
224
-
225
- return job
226
-
227
- async def _start_video(self, job: GeminiJob, session: AsyncSession, service: GeminiService, input_data: dict) -> GeminiJob:
228
- """Start async video generation."""
229
- prompt = input_data.get("prompt", "")
230
-
231
- # If prompt is missing, generate one using the animation template
232
- if not prompt:
233
- try:
234
- import os
235
- template_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "prompt", "animation.md")
236
- if os.path.exists(template_path):
237
- with open(template_path, "r") as f:
238
- template_prompt = f.read().strip()
239
-
240
- logger.info(f"Generating auto-prompt for job {job.job_id} using template")
241
- prompt = await service.generate_animation_prompt(
242
- base64_image=input_data.get("base64_image", ""),
243
- mime_type=input_data.get("mime_type", "image/jpeg"),
244
- custom_prompt=template_prompt
245
- )
246
- logger.info(f"Generated prompt for job {job.job_id}: {prompt}")
247
-
248
- # Update input data with generated prompt for reference
249
- # Create a new dictionary to ensure SQLAlchemy detects the change
250
- new_input_data = dict(input_data)
251
- new_input_data["prompt"] = prompt
252
- job.input_data = new_input_data
253
- # We need to commit this change to DB so it persists
254
- # But session commit happens outside this method usually?
255
- # Actually process() calls this, and process() returns job,
256
- # but doesn't explicitly commit job changes until later?
257
- # The worker loop commits after process() returns.
258
- else:
259
- logger.warning(f"Animation prompt template not found at {template_path}")
260
- except Exception as e:
261
- logger.error(f"Failed to generate auto-prompt: {e}")
262
- # Fallback to empty prompt or error?
263
- # Let's let it proceed with empty prompt which might fail at API level or use API default
264
-
265
- result = await service.start_video_generation(
266
- base64_image=input_data.get("base64_image", ""),
267
- mime_type=input_data.get("mime_type", "image/jpeg"),
268
- prompt=prompt,
269
- aspect_ratio=input_data.get("aspect_ratio", "16:9"),
270
- resolution=input_data.get("resolution", "720p"),
271
- number_of_videos=input_data.get("number_of_videos", 1)
272
- )
273
- job.third_party_id = result.get("gemini_operation_name")
274
- job.api_response = result
275
-
276
- # Schedule first status check
277
- config = WorkerConfig.from_env()
278
- interval = get_interval_for_priority(job.priority, config)
279
- job.next_process_at = datetime.utcnow() + timedelta(seconds=interval)
280
-
281
- return job
282
-
283
- async def _process_image(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
284
- """Process image edit (synchronous)."""
285
- result = await service.edit_image(
286
- base64_image=input_data.get("base64_image", ""),
287
- mime_type=input_data.get("mime_type", "image/jpeg"),
288
- prompt=input_data.get("prompt", "")
289
- )
290
- job.status = "completed"
291
- job.output_data = {"image": result}
292
- # Don't save full base64 image to api_response
293
- job.api_response = {"status": "success", "type": "image_edit"}
294
- job.completed_at = datetime.utcnow()
295
- # Sync DB on success
296
- from services.backup_service import get_backup_service
297
- backup_service = get_backup_service()
298
- await backup_service.backup_async()
299
- return job
300
-
301
- async def _process_text(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
302
- """Process text generation (synchronous)."""
303
- result = await service.generate_text(
304
- prompt=input_data.get("prompt", ""),
305
- model=input_data.get("model")
306
- )
307
- job.status = "completed"
308
- job.output_data = {"text": result}
309
- job.api_response = {"result": result}
310
- job.completed_at = datetime.utcnow()
311
- # Sync DB on success
312
- from services.backup_service import get_backup_service
313
- backup_service = get_backup_service()
314
- await backup_service.backup_async()
315
- return job
316
-
317
- async def _process_analyze(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
318
- """Process image analysis (synchronous)."""
319
- result = await service.analyze_image(
320
- base64_image=input_data.get("base64_image", ""),
321
- mime_type=input_data.get("mime_type", "image/jpeg"),
322
- prompt=input_data.get("prompt", "")
323
- )
324
- job.status = "completed"
325
- job.output_data = {"analysis": result}
326
- job.api_response = {"result": result}
327
- job.completed_at = datetime.utcnow()
328
- # Sync DB on success
329
- from services.backup_service import get_backup_service
330
- backup_service = get_backup_service()
331
- await backup_service.backup_async()
332
- return job
333
-
334
- async def _process_animation_prompt(self, job: GeminiJob, service: GeminiService, input_data: dict) -> GeminiJob:
335
- """Process animation prompt generation (synchronous)."""
336
- result = await service.generate_animation_prompt(
337
- base64_image=input_data.get("base64_image", ""),
338
- mime_type=input_data.get("mime_type", "image/jpeg"),
339
- custom_prompt=input_data.get("custom_prompt")
340
- )
341
- job.status = "completed"
342
- job.output_data = {"prompt": result}
343
- job.api_response = {"result": result}
344
- job.completed_at = datetime.utcnow()
345
- # Sync DB on success
346
- from services.backup_service import get_backup_service
347
- backup_service = get_backup_service()
348
- await backup_service.backup_async()
349
- return job
350
-
351
-
352
- # Singleton pool instance
353
- _pool: Optional[PriorityWorkerPool] = None
354
-
355
-
356
- def get_pool() -> PriorityWorkerPool:
357
- """Get the global Gemini worker pool instance."""
358
- global _pool
359
- if _pool is None:
360
- _pool = PriorityWorkerPool(
361
- database_url=DATABASE_URL,
362
- job_model=GeminiJob,
363
- job_processor=GeminiJobProcessor(),
364
- config=WorkerConfig.from_env()
365
- )
366
- return _pool
367
-
368
-
369
- async def start_worker():
370
- """Start the Gemini job worker pool."""
371
- pool = get_pool()
372
- await pool.start()
373
-
374
-
375
- async def stop_worker():
376
- """Stop the Gemini job worker pool."""
377
- pool = get_pool()
378
- await pool.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/priority_worker_pool.py DELETED
@@ -1,547 +0,0 @@
1
- """
2
- Modular Priority-Tier Worker Pool
3
-
4
- A self-contained, plug-and-play worker pool for processing async jobs
5
- with priority-tier scheduling. Can be used in any Python application.
6
-
7
- Usage:
8
- from services.priority_worker_pool import PriorityWorkerPool, WorkerConfig
9
-
10
- # Define your job processor function
11
- async def process_my_job(job, session):
12
- # Process job and return updated job
13
- job.status = "completed"
14
- job.output_data = {"result": "done"}
15
- return job
16
-
17
- # Configure and start pool
18
- pool = PriorityWorkerPool(
19
- database_url="sqlite+aiosqlite:///./my_db.db",
20
- job_model=MyJobModel,
21
- job_processor=process_my_job,
22
- config=WorkerConfig(fast_workers=5, medium_workers=5, slow_workers=5)
23
- )
24
- await pool.start()
25
-
26
- Environment Variables (optional):
27
- FAST_WORKERS: Number of fast workers (default: 5)
28
- MEDIUM_WORKERS: Number of medium workers (default: 5)
29
- SLOW_WORKERS: Number of slow workers (default: 5)
30
- FAST_INTERVAL: Fast tier polling interval in seconds (default: 5)
31
- MEDIUM_INTERVAL: Medium tier polling interval in seconds (default: 30)
32
- SLOW_INTERVAL: Slow tier polling interval in seconds (default: 60)
33
-
34
- Dependencies:
35
- sqlalchemy[asyncio]>=2.0.0
36
- aiosqlite (for SQLite) or asyncpg (for PostgreSQL)
37
-
38
- Job Model Requirements:
39
- Your job model must have these columns:
40
- - job_id: str (unique identifier)
41
- - status: str (queued, processing, completed, failed, cancelled)
42
- - priority: str (fast, medium, slow)
43
- - next_process_at: datetime (nullable, for rescheduling)
44
- - retry_count: int (default 0)
45
- - created_at: datetime
46
- - started_at: datetime (nullable)
47
- - completed_at: datetime (nullable)
48
- - error_message: str (nullable)
49
- """
50
- import asyncio
51
- import logging
52
- import os
53
- from abc import ABC, abstractmethod
54
- from dataclasses import dataclass, field
55
- from datetime import datetime, timedelta
56
- from typing import Optional, List, Callable, Any, TypeVar, Generic
57
- from sqlalchemy import select, or_, and_
58
- from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
59
-
60
- logger = logging.getLogger(__name__)
61
-
62
- # Generic type for job model
63
- JobType = TypeVar('JobType')
64
-
65
-
66
- @dataclass
67
- class WorkerConfig:
68
- """Configuration for the worker pool."""
69
- fast_workers: int = 5
70
- medium_workers: int = 5
71
- slow_workers: int = 5
72
- fast_interval: int = 2 # seconds
73
- medium_interval: int = 10 # seconds
74
- slow_interval: int = 15 # seconds
75
- max_retries: int = 60 # Max retry attempts before failing
76
- job_per_api_key: int = 1 # Max concurrent jobs per API key
77
-
78
- @classmethod
79
- def from_env(cls) -> 'WorkerConfig':
80
- """Create config from environment variables."""
81
- return cls(
82
- fast_workers=int(os.getenv("FAST_WORKERS", "5")),
83
- medium_workers=int(os.getenv("MEDIUM_WORKERS", "5")),
84
- slow_workers=int(os.getenv("SLOW_WORKERS", "5")),
85
- fast_interval=int(os.getenv("FAST_INTERVAL", "5")),
86
- medium_interval=int(os.getenv("MEDIUM_INTERVAL", "30")),
87
- slow_interval=int(os.getenv("SLOW_INTERVAL", "60")),
88
- job_per_api_key=int(os.getenv("JOB_PER_API_KEY", "1")),
89
- )
90
-
91
-
92
- @dataclass
93
- class PriorityMapping:
94
- """Maps job types to priority tiers."""
95
- mappings: dict = field(default_factory=dict)
96
-
97
- def get_priority(self, job_type: str, default: str = "fast") -> str:
98
- """Get priority for a job type."""
99
- return self.mappings.get(job_type, default)
100
-
101
- def get_interval(self, priority: str, config: WorkerConfig) -> int:
102
- """Get polling interval for a priority tier."""
103
- if priority == "fast":
104
- return config.fast_interval
105
- elif priority == "medium":
106
- return config.medium_interval
107
- else:
108
- return config.slow_interval
109
-
110
-
111
- class JobProcessor(ABC, Generic[JobType]):
112
- """Abstract base class for job processors."""
113
-
114
- @abstractmethod
115
- async def process(self, job: JobType, session: AsyncSession) -> JobType:
116
- """
117
- Process a job and return the updated job.
118
-
119
- Args:
120
- job: The job to process
121
- session: Database session for updates
122
-
123
- Returns:
124
- The updated job with new status/output
125
- """
126
- pass
127
-
128
- @abstractmethod
129
- async def check_status(self, job: JobType, session: AsyncSession) -> JobType:
130
- """
131
- Check status of an in-progress job (for async third-party operations).
132
-
133
- Args:
134
- job: The job to check
135
- session: Database session for updates
136
-
137
- Returns:
138
- The updated job. Set next_process_at to reschedule if not done.
139
- """
140
- pass
141
-
142
-
143
- class PriorityWorker(Generic[JobType]):
144
- """Worker that processes jobs of a specific priority tier."""
145
-
146
- def __init__(
147
- self,
148
- worker_id: int,
149
- priority: str,
150
- poll_interval: int,
151
- session_maker: async_sessionmaker,
152
- job_model: type,
153
- job_processor: JobProcessor[JobType],
154
- max_retries: int = 60,
155
- wake_event: Optional[asyncio.Event] = None,
156
- config: Optional[WorkerConfig] = None
157
- ):
158
- self.worker_id = worker_id
159
- self.priority = priority
160
- self.poll_interval = poll_interval
161
- self.session_maker = session_maker
162
- self.job_model = job_model
163
- self.job_processor = job_processor
164
- self.max_retries = max_retries
165
- self._running = False
166
- self._current_job_id: Optional[str] = None
167
- self._wake_event = wake_event # Event to wake worker immediately when new jobs arrive
168
- self._config = config or WorkerConfig.from_env()
169
-
170
- async def start(self):
171
- """Start the worker polling loop."""
172
- self._running = True
173
- logger.debug(f"Worker {self.worker_id} ({self.priority}) started, polling every {self.poll_interval}s")
174
- asyncio.create_task(self._poll_loop())
175
-
176
- async def stop(self):
177
- """Stop the worker."""
178
- self._running = False
179
- logger.info(f"Worker {self.worker_id} ({self.priority}) stopped")
180
-
181
- async def _poll_loop(self):
182
- """Main polling loop with optimized scheduling.
183
-
184
- Optimizations:
185
- - When no jobs are found, sleep for poll_interval before checking again
186
- - When a job is processed, immediately check for the next job (no waiting)
187
- - This ensures first job starts immediately when queue was empty
188
- - This ensures next job starts immediately after current job finishes
189
- """
190
- while self._running:
191
- job_found = False
192
- try:
193
- job_found = await self._process_one_job()
194
- except Exception as e:
195
- logger.error(f"Worker {self.worker_id}: Error in poll loop: {e}")
196
-
197
- # Only sleep if no job was found - otherwise immediately look for next job
198
- if not job_found:
199
- # Wait on event with timeout - allows immediate wake-up when new job arrives
200
- if self._wake_event:
201
- try:
202
- # Wait for event or timeout (whichever comes first)
203
- await asyncio.wait_for(
204
- self._wake_event.wait(),
205
- timeout=self.poll_interval
206
- )
207
- # Clear event after waking (we'll check for jobs)
208
- self._wake_event.clear()
209
- except asyncio.TimeoutError:
210
- pass # Normal timeout, check for jobs
211
- else:
212
- await asyncio.sleep(self.poll_interval)
213
-
214
- async def _process_one_job(self) -> bool:
215
- """Find and process one job.
216
-
217
- Enforces constraints:
218
- 1. Only one job per user can be in processing state at a time
219
- 2. Total processing jobs limited to JOB_PER_API_KEY * number of API keys
220
-
221
- Returns:
222
- True if a job was found and processed, False if no jobs available
223
- """
224
- async with self.session_maker() as session:
225
- from sqlalchemy import func
226
-
227
- now = datetime.utcnow()
228
-
229
- # Get number of API keys for capacity calculation
230
- try:
231
- from services.api_key_manager import get_key_count
232
- num_api_keys = get_key_count()
233
- max_processing = self._config.job_per_api_key * num_api_keys
234
- except ImportError:
235
- max_processing = 10 # Default fallback
236
-
237
- # Check if we're at max processing capacity (only for new jobs being picked up)
238
- count_query = select(func.count()).where(
239
- self.job_model.status == "processing"
240
- )
241
- count_result = await session.execute(count_query)
242
- current_processing = count_result.scalar() or 0
243
-
244
- # Query for jobs matching this priority tier
245
- query = select(self.job_model).where(
246
- and_(
247
- self.job_model.priority == self.priority,
248
- self.job_model.status.in_(["queued", "processing"]),
249
- or_(
250
- self.job_model.next_process_at.is_(None),
251
- self.job_model.next_process_at <= now
252
- )
253
- )
254
- ).order_by(self.job_model.created_at).limit(1)
255
-
256
- result = await session.execute(query)
257
- job = result.scalar_one_or_none()
258
-
259
- if not job:
260
- return False
261
-
262
- # For queued jobs, apply the constraints
263
- if job.status == "queued":
264
- # Constraint 1: Check if this user already has a job in processing
265
- user_processing_query = select(func.count()).where(
266
- and_(
267
- self.job_model.user_id == job.user_id,
268
- self.job_model.status == "processing"
269
- )
270
- )
271
- user_result = await session.execute(user_processing_query)
272
- user_processing_count = user_result.scalar() or 0
273
-
274
- if user_processing_count > 0:
275
- logger.debug(f"Worker {self.worker_id}: User {job.user_id} already has a job processing, skipping")
276
- return False
277
-
278
- # Constraint 2: Check if we're at max total processing capacity
279
- if current_processing >= max_processing:
280
- logger.debug(f"Worker {self.worker_id}: At max capacity ({current_processing}/{max_processing}), skipping new job")
281
- return False
282
-
283
- self._current_job_id = job.job_id
284
-
285
- try:
286
- await self._process_job(session, job)
287
- return True
288
- except Exception as e:
289
- logger.error(f"Worker {self.worker_id}: Error processing job {job.job_id}: {e}")
290
- job.status = "failed"
291
- job.error_message = str(e)
292
- job.completed_at = datetime.utcnow()
293
- await session.commit()
294
- return True # Job was found, even though it failed
295
- finally:
296
- self._current_job_id = None
297
-
298
- async def _process_job(self, session: AsyncSession, job: JobType):
299
- """Process a single job."""
300
- logger.info(f"Worker {self.worker_id}: Processing job {job.job_id} (status: {job.status})")
301
-
302
- from sqlalchemy import update
303
-
304
- if job.status == "queued":
305
- # New job - try to claim it atomically
306
- # Set next_process_at to future to prevent others from picking it up while we process
307
- next_check = datetime.utcnow() + timedelta(seconds=self.poll_interval * 2)
308
-
309
- stmt = (
310
- update(self.job_model)
311
- .where(
312
- self.job_model.job_id == job.job_id,
313
- self.job_model.status == "queued"
314
- )
315
- .values(
316
- status="processing",
317
- started_at=datetime.utcnow(),
318
- next_process_at=next_check
319
- )
320
- )
321
- result = await session.execute(stmt)
322
- await session.commit()
323
-
324
- if result.rowcount == 0:
325
- logger.info(f"Worker {self.worker_id}: Failed to claim job {job.job_id} (already taken)")
326
- return
327
-
328
- # We claimed it. Refresh and process.
329
- await session.refresh(job)
330
- job = await self.job_processor.process(job, session)
331
-
332
- else:
333
- # Already processing - try to claim for status check
334
- # Ensure we only pick it up if next_process_at matches (or is null/past)
335
- # But the SELECT already filtered for that.
336
- # We just need to ensure no one else grabbed it between SELECT and UPDATE.
337
-
338
- # Update next_process_at to future to lock it for this check
339
- next_check = datetime.utcnow() + timedelta(seconds=self.poll_interval * 2)
340
-
341
- stmt = (
342
- update(self.job_model)
343
- .where(
344
- self.job_model.job_id == job.job_id,
345
- or_(
346
- self.job_model.next_process_at.is_(None),
347
- self.job_model.next_process_at <= datetime.utcnow()
348
- )
349
- )
350
- .values(next_process_at=next_check)
351
- )
352
- result = await session.execute(stmt)
353
- await session.commit()
354
-
355
- if result.rowcount == 0:
356
- logger.info(f"Worker {self.worker_id}: Failed to claim job {job.job_id} for check (already taken)")
357
- return
358
-
359
- await session.refresh(job)
360
- job = await self.job_processor.check_status(job, session)
361
-
362
- # Handle retry limit
363
- if job.status == "processing" and job.retry_count > self.max_retries:
364
- job.status = "failed"
365
- job.error_message = f"Max retries ({self.max_retries}) exceeded"
366
- job.completed_at = datetime.utcnow()
367
-
368
- # Handle credit finalization for jobs with reserved credits
369
- if job.status in ("completed", "failed", "cancelled"):
370
- await self._handle_job_credits(session, job)
371
-
372
- await session.commit()
373
-
374
- async def _handle_job_credits(self, session: AsyncSession, job: JobType):
375
- """Handle credit finalization when job reaches terminal state."""
376
- # Check if job has credits_reserved attribute (credit-enabled jobs)
377
- if not hasattr(job, 'credits_reserved') or job.credits_reserved <= 0:
378
- return
379
-
380
- try:
381
- from services.credit_service.credit_manager import handle_job_completion
382
- await handle_job_completion(session, job)
383
- except ImportError:
384
- # Credit service not available - skip
385
- logger.debug(f"Credit service not available for job {job.job_id}")
386
- except Exception as e:
387
- logger.error(f"Error handling credits for job {job.job_id}: {e}")
388
-
389
-
390
- class PriorityWorkerPool(Generic[JobType]):
391
- """
392
- Modular priority-tier worker pool.
393
-
394
- Can be used with any job model that follows the required schema.
395
- """
396
-
397
- def __init__(
398
- self,
399
- database_url: str,
400
- job_model: type,
401
- job_processor: JobProcessor[JobType],
402
- config: Optional[WorkerConfig] = None
403
- ):
404
- """
405
- Initialize the worker pool.
406
-
407
- Args:
408
- database_url: SQLAlchemy async database URL
409
- job_model: Your ORM model class for jobs
410
- job_processor: Instance of JobProcessor to handle jobs
411
- config: Worker configuration (uses env vars if not provided)
412
- """
413
- self.database_url = database_url
414
- self.job_model = job_model
415
- self.job_processor = job_processor
416
- self.config = config or WorkerConfig.from_env()
417
-
418
- self.engine = create_async_engine(database_url, echo=False)
419
- self.session_maker = async_sessionmaker(
420
- self.engine,
421
- class_=AsyncSession,
422
- expire_on_commit=False
423
- )
424
- self.workers: List[PriorityWorker] = []
425
- self._running = False
426
-
427
- # Wake events for each priority tier - allows immediate job notification
428
- self._wake_events: dict[str, asyncio.Event] = {
429
- "fast": asyncio.Event(),
430
- "medium": asyncio.Event(),
431
- "slow": asyncio.Event()
432
- }
433
-
434
- async def start(self):
435
- """Start all workers."""
436
- self._running = True
437
- worker_id = 0
438
-
439
- # Create fast workers
440
- for i in range(self.config.fast_workers):
441
- worker = PriorityWorker(
442
- worker_id=worker_id,
443
- priority="fast",
444
- poll_interval=self.config.fast_interval,
445
- session_maker=self.session_maker,
446
- job_model=self.job_model,
447
- job_processor=self.job_processor,
448
- max_retries=self.config.max_retries,
449
- wake_event=self._wake_events["fast"],
450
- config=self.config
451
- )
452
- self.workers.append(worker)
453
- await worker.start()
454
- worker_id += 1
455
-
456
- # Create medium workers
457
- for i in range(self.config.medium_workers):
458
- worker = PriorityWorker(
459
- worker_id=worker_id,
460
- priority="medium",
461
- poll_interval=self.config.medium_interval,
462
- session_maker=self.session_maker,
463
- job_model=self.job_model,
464
- job_processor=self.job_processor,
465
- max_retries=self.config.max_retries,
466
- wake_event=self._wake_events["medium"],
467
- config=self.config
468
- )
469
- self.workers.append(worker)
470
- await worker.start()
471
- worker_id += 1
472
-
473
- # Create slow workers
474
- for i in range(self.config.slow_workers):
475
- worker = PriorityWorker(
476
- worker_id=worker_id,
477
- priority="slow",
478
- poll_interval=self.config.slow_interval,
479
- session_maker=self.session_maker,
480
- job_model=self.job_model,
481
- job_processor=self.job_processor,
482
- max_retries=self.config.max_retries,
483
- wake_event=self._wake_events["slow"],
484
- config=self.config
485
- )
486
- self.workers.append(worker)
487
- await worker.start()
488
- worker_id += 1
489
-
490
- total = self.config.fast_workers + self.config.medium_workers + self.config.slow_workers
491
- logger.info(
492
- f"PriorityWorkerPool started with {total} workers: "
493
- f"{self.config.fast_workers} fast, {self.config.medium_workers} medium, {self.config.slow_workers} slow"
494
- )
495
-
496
- def notify_new_job(self, priority: str):
497
- """
498
- Wake sleeping workers of the specified priority tier.
499
- Call this when a new job is created to start processing immediately.
500
-
501
- Args:
502
- priority: Priority tier ("fast", "medium", or "slow")
503
- """
504
- if priority in self._wake_events:
505
- self._wake_events[priority].set()
506
- logger.debug(f"Notified {priority} workers of new job")
507
-
508
- async def stop(self):
509
- """Stop all workers and refund orphaned jobs."""
510
- self._running = False
511
-
512
- # Refund credits for any jobs that were processing when server stopped
513
- await self._refund_orphaned_jobs()
514
-
515
- for worker in self.workers:
516
- await worker.stop()
517
- logger.info("PriorityWorkerPool stopped")
518
-
519
- async def _refund_orphaned_jobs(self):
520
- """Refund credits for jobs abandoned during shutdown."""
521
- try:
522
- from services.credit_service.credit_manager import refund_orphaned_jobs
523
- async with self.session_maker() as session:
524
- refund_count = await refund_orphaned_jobs(session)
525
- if refund_count > 0:
526
- logger.info(f"Shutdown: Refunded {refund_count} orphaned job(s)")
527
- except ImportError:
528
- logger.debug("Credit service not available for orphaned job refunds")
529
- except Exception as e:
530
- logger.error(f"Error refunding orphaned jobs during shutdown: {e}")
531
-
532
-
533
- # Convenience functions for priority mapping
534
- def get_priority_for_job_type(job_type: str, mappings: dict) -> str:
535
- """Get priority tier for a job type using provided mappings."""
536
- return mappings.get(job_type, "fast")
537
-
538
-
539
- def get_interval_for_priority(priority: str, config: Optional[WorkerConfig] = None) -> int:
540
- """Get polling interval for a priority tier."""
541
- cfg = config or WorkerConfig.from_env()
542
- if priority == "fast":
543
- return cfg.fast_interval
544
- elif priority == "medium":
545
- return cfg.medium_interval
546
- else:
547
- return cfg.slow_interval
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/debug_gemini_service.py DELETED
@@ -1,165 +0,0 @@
1
- """
2
- Debug script to test Gemini service with API keys from environment.
3
- Keys should be in GEMINI_KEYS environment variable, comma-separated.
4
-
5
- Usage:
6
- GEMINI_KEYS="key1,key2,key3" python tests/debug_gemini_service.py
7
- """
8
- import os
9
- import sys
10
- import asyncio
11
- import logging
12
- import base64
13
- from dotenv import load_dotenv
14
-
15
- # Load environment variables
16
- load_dotenv()
17
-
18
- # Add parent directory to path
19
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
20
-
21
- from services.gemini_service import GeminiService, MODELS
22
-
23
- # Configure logging
24
- logging.basicConfig(level=logging.INFO)
25
- logger = logging.getLogger(__name__)
26
-
27
- # Test image path
28
- TEST_IMAGE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test.jpg")
29
-
30
-
31
- def load_test_image():
32
- """Load test image and return base64 + mime type."""
33
- if not os.path.exists(TEST_IMAGE_PATH):
34
- logger.error(f"Test image not found: {TEST_IMAGE_PATH}")
35
- return None, None
36
-
37
- with open(TEST_IMAGE_PATH, "rb") as f:
38
- image_data = f.read()
39
-
40
- base64_image = base64.b64encode(image_data).decode("utf-8")
41
- mime_type = "image/jpeg"
42
-
43
- logger.info(f"Loaded test image: {TEST_IMAGE_PATH} ({len(image_data)} bytes)")
44
- return base64_image, mime_type
45
-
46
-
47
- async def test_generate_text(service: GeminiService, key_index: int):
48
- """Test simple text generation."""
49
- logger.info(f"[Key {key_index}] Testing text generation...")
50
- try:
51
- result = await service.generate_text("Say hello in one word.")
52
- logger.info(f"[Key {key_index}] Text generation result: {result[:100]}...")
53
- return True
54
- except Exception as e:
55
- logger.error(f"[Key {key_index}] Text generation failed: {e}")
56
- return False
57
-
58
-
59
- async def test_analyze_image(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
60
- """Test image analysis."""
61
- logger.info(f"[Key {key_index}] Testing image analysis...")
62
- try:
63
- result = await service.analyze_image(
64
- base64_image=base64_image,
65
- mime_type=mime_type,
66
- prompt="Describe this image in one sentence."
67
- )
68
- logger.info(f"[Key {key_index}] Image analysis result: {result[:100]}...")
69
- return True
70
- except Exception as e:
71
- logger.error(f"[Key {key_index}] Image analysis failed: {e}")
72
- return False
73
-
74
-
75
- async def test_generate_animation_prompt(service: GeminiService, key_index: int, base64_image: str, mime_type: str):
76
- """Test animation prompt generation."""
77
- logger.info(f"[Key {key_index}] Testing animation prompt generation...")
78
- try:
79
- result = await service.generate_animation_prompt(
80
- base64_image=base64_image,
81
- mime_type=mime_type
82
- )
83
- logger.info(f"[Key {key_index}] Animation prompt result: {result[:100]}...")
84
- return True
85
- except Exception as e:
86
- logger.error(f"[Key {key_index}] Animation prompt generation failed: {e}")
87
- return False
88
-
89
-
90
- async def test_key(api_key: str, key_index: int, base64_image: str, mime_type: str):
91
- """Test all basic operations with a single API key."""
92
- logger.info(f"\n{'='*50}")
93
- logger.info(f"Testing Key {key_index}: {api_key[:10]}...{api_key[-4:]}")
94
- logger.info(f"{'='*50}")
95
-
96
- try:
97
- service = GeminiService(api_key)
98
- except Exception as e:
99
- logger.error(f"[Key {key_index}] Failed to initialize service: {e}")
100
- return {"key_index": key_index, "valid": False, "error": str(e)}
101
-
102
- results = {
103
- "key_index": key_index,
104
- "key_preview": f"{api_key[:10]}...{api_key[-4:]}",
105
- "text_generation": await test_generate_text(service, key_index),
106
- "image_analysis": await test_analyze_image(service, key_index, base64_image, mime_type),
107
- "animation_prompt": await test_generate_animation_prompt(service, key_index, base64_image, mime_type),
108
- }
109
-
110
- results["valid"] = all([
111
- results["text_generation"],
112
- results["image_analysis"],
113
- results["animation_prompt"]
114
- ])
115
-
116
- return results
117
-
118
-
119
- async def main():
120
- # Load test image
121
- base64_image, mime_type = load_test_image()
122
- if not base64_image:
123
- logger.error("Cannot run tests without test image. Please add test.jpg to project root.")
124
- return
125
-
126
- gemini_keys_str = os.getenv("GEMINI_KEYS", "")
127
-
128
- if not gemini_keys_str:
129
- logger.error("GEMINI_KEYS environment variable not set.")
130
- logger.info("Usage: GEMINI_KEYS='key1,key2,key3' python tests/debug_gemini_service.py")
131
- return
132
-
133
- keys = [k.strip() for k in gemini_keys_str.split(",") if k.strip()]
134
-
135
- if not keys:
136
- logger.error("No valid keys found in GEMINI_KEYS.")
137
- return
138
-
139
- logger.info(f"Found {len(keys)} API key(s) to test.")
140
- logger.info(f"Available models: {MODELS}")
141
-
142
- all_results = []
143
- for i, key in enumerate(keys):
144
- result = await test_key(key, i + 1, base64_image, mime_type)
145
- all_results.append(result)
146
-
147
- # Summary
148
- logger.info(f"\n{'='*50}")
149
- logger.info("SUMMARY")
150
- logger.info(f"{'='*50}")
151
-
152
- valid_count = sum(1 for r in all_results if r.get("valid", False))
153
- logger.info(f"Valid keys: {valid_count}/{len(keys)}")
154
-
155
- for result in all_results:
156
- status = "✓ VALID" if result.get("valid") else "✗ INVALID"
157
- logger.info(f" Key {result['key_index']}: {status}")
158
- if not result.get("valid"):
159
- for test_name in ["text_generation", "image_analysis", "animation_prompt"]:
160
- if test_name in result and not result[test_name]:
161
- logger.info(f" - {test_name}: FAILED")
162
-
163
-
164
- if __name__ == "__main__":
165
- asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_fal_service.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for Fal.ai Service.
3
+
4
+ Tests cover:
5
+ 1. Initialization & API key handling
6
+ 2. Video generation
7
+ 3. Error handling
8
+ 4. Mock mode
9
+ """
10
+ import pytest
11
+ import asyncio
12
+ import os
13
+ from unittest.mock import patch, MagicMock, AsyncMock
14
+
15
+
16
+ # =============================================================================
17
+ # 1. Initialization & Configuration Tests
18
+ # =============================================================================
19
+
20
+ class TestFalServiceInit:
21
+ """Test FalService initialization and configuration."""
22
+
23
+ def test_init_with_explicit_api_key(self):
24
+ """Service initializes with explicit API key."""
25
+ with patch.dict(os.environ, {"FAL_KEY": "env-key"}):
26
+ from services.fal_service import FalService
27
+
28
+ service = FalService(api_key="test-key-123")
29
+
30
+ assert service.api_key == "test-key-123"
31
+
32
+ def test_init_with_env_fallback(self):
33
+ """Service falls back to environment variable for API key."""
34
+ with patch.dict(os.environ, {"FAL_KEY": "env-key-456"}):
35
+ from services.fal_service import FalService
36
+
37
+ service = FalService()
38
+
39
+ assert service.api_key == "env-key-456"
40
+
41
+ def test_init_fails_without_api_key(self):
42
+ """Service raises error when no API key available."""
43
+ with patch.dict(os.environ, {}, clear=True):
44
+ os.environ.pop("FAL_KEY", None)
45
+
46
+ from services.fal_service import get_fal_api_key
47
+
48
+ with pytest.raises(ValueError, match="FAL_KEY not configured"):
49
+ get_fal_api_key()
50
+
51
+ def test_models_dict_has_required_entries(self):
52
+ """MODELS dictionary has all required model names."""
53
+ from services.fal_service import MODELS
54
+
55
+ assert "video_generation" in MODELS
56
+ assert "veo3" in MODELS["video_generation"].lower() or "image-to-video" in MODELS["video_generation"]
57
+
58
+
59
+ # =============================================================================
60
+ # 2. Video Generation Tests
61
+ # =============================================================================
62
+
63
+ class TestFalVideoGeneration:
64
+ """Test video generation methods."""
65
+
66
+ @pytest.mark.asyncio
67
+ async def test_start_video_generation_mock_mode(self):
68
+ """Video generation works in mock mode."""
69
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
70
+ with patch('services.fal_service.api_client.MOCK_MODE', True):
71
+ from services.fal_service import FalService
72
+
73
+ service = FalService(api_key="test-key")
74
+ result = await service.start_video_generation(
75
+ base64_image="base64data",
76
+ mime_type="image/jpeg",
77
+ prompt="Animate this"
78
+ )
79
+
80
+ assert result["done"] is True
81
+ assert result["status"] == "completed"
82
+ assert "video_url" in result
83
+ assert "fal_request_id" in result
84
+
85
+ @pytest.mark.asyncio
86
+ async def test_start_video_generation_success(self):
87
+ """Video generation returns video URL on success."""
88
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
89
+ with patch('services.fal_service.api_client.MOCK_MODE', False):
90
+ with patch('services.fal_service.api_client.asyncio.to_thread') as mock_to_thread:
91
+ from services.fal_service import FalService
92
+
93
+ # Mock fal_client response
94
+ mock_result = {
95
+ "video": {"url": "https://fal.ai/video.mp4"},
96
+ "request_id": "req-123"
97
+ }
98
+ mock_to_thread.return_value = mock_result
99
+
100
+ service = FalService(api_key="test-key")
101
+ result = await service.start_video_generation(
102
+ base64_image="base64data",
103
+ mime_type="image/jpeg",
104
+ prompt="Animate this"
105
+ )
106
+
107
+ assert result["done"] is True
108
+ assert result["status"] == "completed"
109
+ assert result["video_url"] == "https://fal.ai/video.mp4"
110
+
111
+ @pytest.mark.asyncio
112
+ async def test_start_video_generation_no_video_url(self):
113
+ """Video generation returns failed when no URL in response."""
114
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
115
+ with patch('services.fal_service.api_client.MOCK_MODE', False):
116
+ with patch('services.fal_service.api_client.asyncio.to_thread') as mock_to_thread:
117
+ from services.fal_service import FalService
118
+
119
+ # Mock response without video URL
120
+ mock_result = {"status": "error"}
121
+ mock_to_thread.return_value = mock_result
122
+
123
+ service = FalService(api_key="test-key")
124
+ result = await service.start_video_generation(
125
+ base64_image="base64data",
126
+ mime_type="image/jpeg",
127
+ prompt="Animate this"
128
+ )
129
+
130
+ assert result["done"] is True
131
+ assert result["status"] == "failed"
132
+ assert "error" in result
133
+
134
+ @pytest.mark.asyncio
135
+ async def test_start_video_generation_with_params(self):
136
+ """Video generation passes aspect_ratio and resolution."""
137
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
138
+ with patch('services.fal_service.api_client.MOCK_MODE', False):
139
+ with patch('services.fal_service.api_client.asyncio.to_thread') as mock_to_thread:
140
+ from services.fal_service import FalService
141
+
142
+ mock_result = {"video": {"url": "https://fal.ai/video.mp4"}}
143
+ mock_to_thread.return_value = mock_result
144
+
145
+ service = FalService(api_key="test-key")
146
+ await service.start_video_generation(
147
+ base64_image="base64data",
148
+ mime_type="image/jpeg",
149
+ prompt="Animate",
150
+ aspect_ratio="9:16",
151
+ resolution="720p"
152
+ )
153
+
154
+ # Verify arguments were passed
155
+ call_args = mock_to_thread.call_args
156
+ arguments = call_args.kwargs.get("arguments") or call_args[1].get("arguments")
157
+ assert arguments["aspect_ratio"] == "9:16"
158
+ assert arguments["resolution"] == "720p"
159
+
160
+
161
+ # =============================================================================
162
+ # 3. Error Handling Tests
163
+ # =============================================================================
164
+
165
+ class TestFalErrorHandling:
166
+ """Test error handling methods."""
167
+
168
+ def test_handle_api_error_401(self):
169
+ """_handle_api_error raises ValueError for 401."""
170
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
171
+ from services.fal_service import FalService
172
+
173
+ service = FalService(api_key="test-key")
174
+
175
+ with pytest.raises(ValueError, match="Authentication failed"):
176
+ service._handle_api_error(Exception("401 Unauthorized"), "test")
177
+
178
+ def test_handle_api_error_402(self):
179
+ """_handle_api_error raises ValueError for 402."""
180
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
181
+ from services.fal_service import FalService
182
+
183
+ service = FalService(api_key="test-key")
184
+
185
+ with pytest.raises(ValueError, match="Insufficient credits"):
186
+ service._handle_api_error(Exception("402 Payment Required"), "test")
187
+
188
+ def test_handle_api_error_429(self):
189
+ """_handle_api_error raises ValueError for 429."""
190
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
191
+ from services.fal_service import FalService
192
+
193
+ service = FalService(api_key="test-key")
194
+
195
+ with pytest.raises(ValueError, match="Rate limit"):
196
+ service._handle_api_error(Exception("429 Rate limit exceeded"), "test")
197
+
198
+ def test_handle_api_error_reraises_other(self):
199
+ """_handle_api_error re-raises non-handled errors."""
200
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
201
+ from services.fal_service import FalService
202
+
203
+ service = FalService(api_key="test-key")
204
+
205
+ with pytest.raises(RuntimeError, match="Connection timeout"):
206
+ service._handle_api_error(RuntimeError("Connection timeout"), "test")
207
+
208
+
209
+ # =============================================================================
210
+ # 4. Video Download Tests
211
+ # =============================================================================
212
+
213
+ class TestFalVideoDownload:
214
+ """Test download_video method."""
215
+
216
+ @pytest.mark.asyncio
217
+ async def test_download_video_saves_file(self):
218
+ """download_video saves file and returns filename."""
219
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
220
+ from services.fal_service import FalService
221
+
222
+ # Mock httpx client at module level
223
+ with patch('httpx.AsyncClient') as mock_client:
224
+ mock_response = MagicMock()
225
+ mock_response.content = b"fake video data"
226
+ mock_response.raise_for_status = MagicMock()
227
+
228
+ mock_client_instance = AsyncMock()
229
+ mock_client_instance.get.return_value = mock_response
230
+ mock_client_instance.__aenter__.return_value = mock_client_instance
231
+ mock_client_instance.__aexit__.return_value = None
232
+ mock_client.return_value = mock_client_instance
233
+
234
+ # Mock file operations
235
+ with patch('services.fal_service.api_client.os.makedirs'):
236
+ mock_file = MagicMock()
237
+ with patch('builtins.open', MagicMock(return_value=mock_file)):
238
+ mock_file.__enter__ = MagicMock(return_value=mock_file)
239
+ mock_file.__exit__ = MagicMock(return_value=False)
240
+
241
+ service = FalService(api_key="test-key")
242
+ result = await service.download_video(
243
+ "https://fal.ai/video.mp4",
244
+ "test-req-123"
245
+ )
246
+
247
+ assert result == "test-req-123.mp4"
248
+ mock_file.write.assert_called_once_with(b"fake video data")
249
+
250
+ @pytest.mark.asyncio
251
+ async def test_download_video_http_error(self):
252
+ """download_video raises error on HTTP failure."""
253
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
254
+ from services.fal_service import FalService
255
+
256
+ with patch('httpx.AsyncClient') as mock_client:
257
+ mock_client_instance = AsyncMock()
258
+ mock_client_instance.get.side_effect = Exception("Connection refused")
259
+ mock_client_instance.__aenter__.return_value = mock_client_instance
260
+ mock_client_instance.__aexit__.return_value = None
261
+ mock_client.return_value = mock_client_instance
262
+
263
+ service = FalService(api_key="test-key")
264
+
265
+ with pytest.raises(ValueError, match="Failed to download"):
266
+ await service.download_video(
267
+ "https://fal.ai/video.mp4",
268
+ "test-req-123"
269
+ )
270
+
271
+
272
+ # =============================================================================
273
+ # 5. Check Status Tests
274
+ # =============================================================================
275
+
276
+ class TestFalCheckStatus:
277
+ """Test check_video_status method."""
278
+
279
+ @pytest.mark.asyncio
280
+ async def test_check_status_returns_completed(self):
281
+ """check_video_status returns completed (fal.ai is sync)."""
282
+ with patch.dict(os.environ, {"FAL_KEY": "test-key"}):
283
+ from services.fal_service import FalService
284
+
285
+ service = FalService(api_key="test-key")
286
+ result = await service.check_video_status("req-123")
287
+
288
+ assert result["done"] is True
289
+ assert result["status"] == "completed"
290
+ assert result["fal_request_id"] == "req-123"
tests/test_gemini_service.py DELETED
@@ -1,814 +0,0 @@
1
- """
2
- Rigorous Tests for Gemini AI Service.
3
-
4
- Tests cover:
5
- 1. Initialization & API key handling
6
- 2. Concurrency semaphores
7
- 3. Text generation
8
- 4. Animation prompt generation
9
- 5. Image analysis & editing
10
- 6. Video generation, status checking, downloading
11
- 7. Error handling
12
- """
13
- import pytest
14
- import asyncio
15
- import os
16
- import tempfile
17
- from unittest.mock import patch, MagicMock, AsyncMock, PropertyMock
18
- from datetime import datetime
19
-
20
-
21
- # =============================================================================
22
- # 1. Initialization & Configuration Tests
23
- # =============================================================================
24
-
25
- class TestGeminiServiceInit:
26
- """Test GeminiService initialization and configuration."""
27
-
28
- def test_init_with_explicit_api_key(self):
29
- """Service initializes with explicit API key."""
30
- with patch('services.gemini_service.genai') as mock_genai:
31
- from services.gemini_service import GeminiService
32
-
33
- service = GeminiService(api_key="test-key-123")
34
-
35
- assert service.api_key == "test-key-123"
36
- mock_genai.Client.assert_called_once_with(api_key="test-key-123")
37
-
38
- def test_init_with_env_fallback(self):
39
- """Service falls back to environment variable for API key."""
40
- with patch('services.gemini_service.genai') as mock_genai:
41
- with patch.dict(os.environ, {"GEMINI_API_KEY": "env-key-456"}):
42
- from services.gemini_service import GeminiService
43
-
44
- service = GeminiService()
45
-
46
- assert service.api_key == "env-key-456"
47
-
48
- def test_init_fails_without_api_key(self):
49
- """Service raises error when no API key available."""
50
- with patch.dict(os.environ, {}, clear=True):
51
- # Remove GEMINI_API_KEY if present
52
- os.environ.pop("GEMINI_API_KEY", None)
53
- os.environ.pop("GEMINI_API_KEYS", None)
54
-
55
- from services.gemini_service import get_gemini_api_key
56
-
57
- with pytest.raises(ValueError, match="Server Authentication Error"):
58
- get_gemini_api_key()
59
-
60
- def test_models_dict_has_required_entries(self):
61
- """MODELS dictionary has all required model names."""
62
- from services.gemini_service import MODELS
63
-
64
- assert "text_generation" in MODELS
65
- assert "image_edit" in MODELS
66
- assert "video_generation" in MODELS
67
- assert all(isinstance(v, str) for v in MODELS.values())
68
-
69
-
70
- # =============================================================================
71
- # 2. Semaphore Concurrency Tests
72
- # =============================================================================
73
-
74
- class TestSemaphoreConcurrency:
75
- """Test concurrency control via semaphores."""
76
-
77
- def test_video_semaphore_respects_limit(self):
78
- """Video semaphore uses MAX_CONCURRENT_VIDEOS."""
79
- # Reset global
80
- import services.gemini_service as gs
81
- gs._video_semaphore = None
82
-
83
- with patch.object(gs, 'MAX_CONCURRENT_VIDEOS', 3):
84
- gs._video_semaphore = None # Reset
85
- sem = gs.get_video_semaphore()
86
- # Semaphore internal value
87
- assert sem._value == 3
88
-
89
- def test_image_semaphore_respects_limit(self):
90
- """Image semaphore uses MAX_CONCURRENT_IMAGES."""
91
- import services.gemini_service as gs
92
- gs._image_semaphore = None
93
-
94
- with patch.object(gs, 'MAX_CONCURRENT_IMAGES', 5):
95
- gs._image_semaphore = None
96
- sem = gs.get_image_semaphore()
97
- assert sem._value == 5
98
-
99
- def test_text_semaphore_respects_limit(self):
100
- """Text semaphore uses MAX_CONCURRENT_TEXT."""
101
- import services.gemini_service as gs
102
- gs._text_semaphore = None
103
-
104
- with patch.object(gs, 'MAX_CONCURRENT_TEXT', 10):
105
- gs._text_semaphore = None
106
- sem = gs.get_text_semaphore()
107
- assert sem._value == 10
108
-
109
- def test_semaphores_are_singletons(self):
110
- """Calling get_*_semaphore multiple times returns same object."""
111
- import services.gemini_service as gs
112
- gs._video_semaphore = None
113
- gs._image_semaphore = None
114
- gs._text_semaphore = None
115
-
116
- video1 = gs.get_video_semaphore()
117
- video2 = gs.get_video_semaphore()
118
- assert video1 is video2
119
-
120
- image1 = gs.get_image_semaphore()
121
- image2 = gs.get_image_semaphore()
122
- assert image1 is image2
123
-
124
- text1 = gs.get_text_semaphore()
125
- text2 = gs.get_text_semaphore()
126
- assert text1 is text2
127
-
128
-
129
- # =============================================================================
130
- # 3. Text Generation Tests
131
- # =============================================================================
132
-
133
- class TestTextGeneration:
134
- """Test generate_text method."""
135
-
136
- @pytest.mark.asyncio
137
- async def test_generate_text_success(self):
138
- """generate_text returns text on success."""
139
- with patch('services.gemini_service.genai') as mock_genai:
140
- from services.gemini_service import GeminiService
141
-
142
- # Mock response
143
- mock_response = MagicMock()
144
- mock_response.text = "Generated text response"
145
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
146
-
147
- service = GeminiService(api_key="test-key")
148
- result = await service.generate_text("Hello world")
149
-
150
- assert result == "Generated text response"
151
-
152
- @pytest.mark.asyncio
153
- async def test_generate_text_with_custom_model(self):
154
- """generate_text uses custom model when provided."""
155
- with patch('services.gemini_service.genai') as mock_genai:
156
- from services.gemini_service import GeminiService
157
-
158
- mock_response = MagicMock()
159
- mock_response.text = "Custom model response"
160
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
161
-
162
- service = GeminiService(api_key="test-key")
163
- result = await service.generate_text("Hello", model="custom-model")
164
-
165
- # Verify custom model was used
166
- call_args = mock_genai.Client.return_value.models.generate_content.call_args
167
- assert call_args.kwargs.get('model') == "custom-model"
168
-
169
- @pytest.mark.asyncio
170
- async def test_generate_text_empty_response(self):
171
- """generate_text returns empty string for None response."""
172
- with patch('services.gemini_service.genai') as mock_genai:
173
- from services.gemini_service import GeminiService
174
-
175
- mock_response = MagicMock()
176
- mock_response.text = None
177
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
178
-
179
- service = GeminiService(api_key="test-key")
180
- result = await service.generate_text("Hello")
181
-
182
- assert result == ""
183
-
184
- @pytest.mark.asyncio
185
- async def test_generate_text_api_error_404(self):
186
- """generate_text raises ValueError for 404 error."""
187
- with patch('services.gemini_service.genai') as mock_genai:
188
- from services.gemini_service import GeminiService
189
-
190
- mock_genai.Client.return_value.models.generate_content.side_effect = Exception("404 NOT_FOUND")
191
-
192
- service = GeminiService(api_key="test-key")
193
-
194
- with pytest.raises(ValueError, match="Model not found"):
195
- await service.generate_text("Hello")
196
-
197
-
198
- # =============================================================================
199
- # 4. Animation Prompt Tests
200
- # =============================================================================
201
-
202
- class TestAnimationPrompt:
203
- """Test generate_animation_prompt method."""
204
-
205
- @pytest.mark.asyncio
206
- async def test_generate_animation_prompt_default(self):
207
- """generate_animation_prompt uses default prompt."""
208
- with patch('services.gemini_service.genai') as mock_genai:
209
- with patch('services.gemini_service.types'):
210
- from services.gemini_service import GeminiService
211
-
212
- mock_response = MagicMock()
213
- mock_response.text = "Subtle zoom with camera pan"
214
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
215
-
216
- service = GeminiService(api_key="test-key")
217
- result = await service.generate_animation_prompt(
218
- base64_image="base64data",
219
- mime_type="image/jpeg"
220
- )
221
-
222
- assert result == "Subtle zoom with camera pan"
223
-
224
- @pytest.mark.asyncio
225
- async def test_generate_animation_prompt_custom(self):
226
- """generate_animation_prompt uses custom prompt when provided."""
227
- with patch('services.gemini_service.genai') as mock_genai:
228
- with patch('services.gemini_service.types'):
229
- from services.gemini_service import GeminiService
230
-
231
- mock_response = MagicMock()
232
- mock_response.text = "Custom animation"
233
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
234
-
235
- service = GeminiService(api_key="test-key")
236
- result = await service.generate_animation_prompt(
237
- base64_image="base64data",
238
- mime_type="image/jpeg",
239
- custom_prompt="Make it dramatic"
240
- )
241
-
242
- assert result == "Custom animation"
243
-
244
- @pytest.mark.asyncio
245
- async def test_generate_animation_prompt_fallback(self):
246
- """generate_animation_prompt returns fallback on empty response."""
247
- with patch('services.gemini_service.genai') as mock_genai:
248
- with patch('services.gemini_service.types'):
249
- from services.gemini_service import GeminiService
250
-
251
- mock_response = MagicMock()
252
- mock_response.text = None
253
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
254
-
255
- service = GeminiService(api_key="test-key")
256
- result = await service.generate_animation_prompt(
257
- base64_image="base64data",
258
- mime_type="image/jpeg"
259
- )
260
-
261
- assert result == "Cinematic subtle movement"
262
-
263
-
264
- # =============================================================================
265
- # 5. Image Analysis Tests
266
- # =============================================================================
267
-
268
- class TestImageAnalysis:
269
- """Test analyze_image method."""
270
-
271
- @pytest.mark.asyncio
272
- async def test_analyze_image_success(self):
273
- """analyze_image returns analysis text."""
274
- with patch('services.gemini_service.genai') as mock_genai:
275
- with patch('services.gemini_service.types'):
276
- from services.gemini_service import GeminiService
277
-
278
- mock_response = MagicMock()
279
- mock_response.text = "This image shows a sunset over mountains"
280
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
281
-
282
- service = GeminiService(api_key="test-key")
283
- result = await service.analyze_image(
284
- base64_image="base64data",
285
- mime_type="image/jpeg",
286
- prompt="Describe this image"
287
- )
288
-
289
- assert result == "This image shows a sunset over mountains"
290
-
291
- @pytest.mark.asyncio
292
- async def test_analyze_image_empty_response(self):
293
- """analyze_image returns empty string for None response."""
294
- with patch('services.gemini_service.genai') as mock_genai:
295
- with patch('services.gemini_service.types'):
296
- from services.gemini_service import GeminiService
297
-
298
- mock_response = MagicMock()
299
- mock_response.text = None
300
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
301
-
302
- service = GeminiService(api_key="test-key")
303
- result = await service.analyze_image(
304
- base64_image="base64data",
305
- mime_type="image/jpeg",
306
- prompt="Describe"
307
- )
308
-
309
- assert result == ""
310
-
311
-
312
- # =============================================================================
313
- # 6. Image Editing Tests
314
- # =============================================================================
315
-
316
- class TestImageEditing:
317
- """Test edit_image method."""
318
-
319
- @pytest.mark.asyncio
320
- async def test_edit_image_returns_data_uri(self):
321
- """edit_image returns base64 data URI."""
322
- with patch('services.gemini_service.genai') as mock_genai:
323
- from services.gemini_service import GeminiService
324
-
325
- # Create mock response structure
326
- mock_inline_data = MagicMock()
327
- mock_inline_data.data = "base64imagedata"
328
- mock_inline_data.mime_type = "image/png"
329
-
330
- mock_part = MagicMock()
331
- mock_part.inline_data = mock_inline_data
332
-
333
- mock_content = MagicMock()
334
- mock_content.parts = [mock_part]
335
-
336
- mock_candidate = MagicMock()
337
- mock_candidate.content = mock_content
338
-
339
- mock_response = MagicMock()
340
- mock_response.candidates = [mock_candidate]
341
-
342
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
343
-
344
- service = GeminiService(api_key="test-key")
345
- result = await service.edit_image(
346
- base64_image="input-base64",
347
- mime_type="image/jpeg",
348
- prompt="Make it colorful"
349
- )
350
-
351
- assert result == "data:image/png;base64,base64imagedata"
352
-
353
- @pytest.mark.asyncio
354
- async def test_edit_image_no_candidates(self):
355
- """edit_image raises error when no candidates returned."""
356
- with patch('services.gemini_service.genai') as mock_genai:
357
- from services.gemini_service import GeminiService
358
-
359
- mock_response = MagicMock()
360
- mock_response.candidates = []
361
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
362
-
363
- service = GeminiService(api_key="test-key")
364
-
365
- with pytest.raises(ValueError, match="No candidates returned"):
366
- await service.edit_image(
367
- base64_image="input-base64",
368
- mime_type="image/jpeg",
369
- prompt="Edit"
370
- )
371
-
372
- @pytest.mark.asyncio
373
- async def test_edit_image_no_image_data(self):
374
- """edit_image raises error when no image data in parts."""
375
- with patch('services.gemini_service.genai') as mock_genai:
376
- from services.gemini_service import GeminiService
377
-
378
- # Part without inline_data
379
- mock_part = MagicMock()
380
- mock_part.inline_data = None
381
-
382
- mock_content = MagicMock()
383
- mock_content.parts = [mock_part]
384
-
385
- mock_candidate = MagicMock()
386
- mock_candidate.content = mock_content
387
-
388
- mock_response = MagicMock()
389
- mock_response.candidates = [mock_candidate]
390
-
391
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
392
-
393
- service = GeminiService(api_key="test-key")
394
-
395
- with pytest.raises(ValueError, match="No image data found"):
396
- await service.edit_image(
397
- base64_image="input-base64",
398
- mime_type="image/jpeg",
399
- prompt="Edit"
400
- )
401
-
402
- @pytest.mark.asyncio
403
- async def test_edit_image_default_prompt(self):
404
- """edit_image uses default prompt when empty."""
405
- with patch('services.gemini_service.genai') as mock_genai:
406
- with patch('services.gemini_service.types'):
407
- from services.gemini_service import GeminiService
408
-
409
- mock_inline_data = MagicMock()
410
- mock_inline_data.data = "base64data"
411
- mock_inline_data.mime_type = "image/png"
412
-
413
- mock_part = MagicMock()
414
- mock_part.inline_data = mock_inline_data
415
-
416
- mock_content = MagicMock()
417
- mock_content.parts = [mock_part]
418
-
419
- mock_candidate = MagicMock()
420
- mock_candidate.content = mock_content
421
-
422
- mock_response = MagicMock()
423
- mock_response.candidates = [mock_candidate]
424
-
425
- mock_genai.Client.return_value.models.generate_content.return_value = mock_response
426
-
427
- service = GeminiService(api_key="test-key")
428
- result = await service.edit_image(
429
- base64_image="input",
430
- mime_type="image/jpeg",
431
- prompt="" # Empty prompt
432
- )
433
-
434
- assert "data:" in result
435
-
436
-
437
- # =============================================================================
438
- # 7. Video Generation Tests
439
- # =============================================================================
440
-
441
- class TestVideoGeneration:
442
- """Test start_video_generation method."""
443
-
444
- @pytest.mark.asyncio
445
- async def test_start_video_returns_operation_dict(self):
446
- """start_video_generation returns operation dictionary."""
447
- with patch('services.gemini_service.genai') as mock_genai:
448
- with patch('services.gemini_service.types'):
449
- from services.gemini_service import GeminiService
450
-
451
- mock_operation = MagicMock()
452
- mock_operation.name = "operations/video-123"
453
- mock_operation.done = False
454
-
455
- mock_genai.Client.return_value.models.generate_videos.return_value = mock_operation
456
-
457
- service = GeminiService(api_key="test-key")
458
- result = await service.start_video_generation(
459
- base64_image="base64data",
460
- mime_type="image/jpeg",
461
- prompt="Animate this"
462
- )
463
-
464
- assert result["gemini_operation_name"] == "operations/video-123"
465
- assert result["done"] == False
466
- assert result["status"] == "pending"
467
-
468
- @pytest.mark.asyncio
469
- async def test_start_video_completed_immediately(self):
470
- """start_video_generation returns completed when done=True."""
471
- with patch('services.gemini_service.genai') as mock_genai:
472
- with patch('services.gemini_service.types'):
473
- from services.gemini_service import GeminiService
474
-
475
- mock_operation = MagicMock()
476
- mock_operation.name = "operations/video-123"
477
- mock_operation.done = True
478
-
479
- mock_genai.Client.return_value.models.generate_videos.return_value = mock_operation
480
-
481
- service = GeminiService(api_key="test-key")
482
- result = await service.start_video_generation(
483
- base64_image="base64data",
484
- mime_type="image/jpeg",
485
- prompt="Animate this"
486
- )
487
-
488
- assert result["status"] == "completed"
489
-
490
- @pytest.mark.asyncio
491
- async def test_start_video_with_params(self):
492
- """start_video_generation passes aspect_ratio and resolution."""
493
- with patch('services.gemini_service.genai') as mock_genai:
494
- with patch('services.gemini_service.types'):
495
- from services.gemini_service import GeminiService
496
-
497
- mock_operation = MagicMock()
498
- mock_operation.name = "operations/video-123"
499
- mock_operation.done = False
500
-
501
- mock_genai.Client.return_value.models.generate_videos.return_value = mock_operation
502
-
503
- service = GeminiService(api_key="test-key")
504
- await service.start_video_generation(
505
- base64_image="base64data",
506
- mime_type="image/jpeg",
507
- prompt="Animate",
508
- aspect_ratio="9:16",
509
- resolution="1080p",
510
- number_of_videos=2
511
- )
512
-
513
- # Verify config was passed
514
- call_args = mock_genai.Client.return_value.models.generate_videos.call_args
515
- assert call_args is not None
516
-
517
-
518
- # =============================================================================
519
- # 8. Video Status Checking Tests
520
- # =============================================================================
521
-
522
- class TestVideoStatusChecking:
523
- """Test check_video_status method."""
524
-
525
- @pytest.mark.asyncio
526
- async def test_check_status_pending(self):
527
- """check_video_status returns pending when not done."""
528
- with patch('services.gemini_service.genai') as mock_genai:
529
- from services.gemini_service import GeminiService
530
-
531
- mock_operation = MagicMock()
532
- mock_operation.done = False
533
- mock_operation.error = None
534
-
535
- mock_genai.Client.return_value.operations.get.return_value = mock_operation
536
-
537
- service = GeminiService(api_key="test-key")
538
- result = await service.check_video_status("operations/video-123")
539
-
540
- assert result["done"] == False
541
- assert result["status"] == "pending"
542
-
543
- @pytest.mark.asyncio
544
- async def test_check_status_completed_with_url(self):
545
- """check_video_status returns completed with video URL."""
546
- with patch('services.gemini_service.genai') as mock_genai:
547
- from services.gemini_service import GeminiService
548
-
549
- # Build nested mock structure
550
- mock_video = MagicMock()
551
- mock_video.uri = "https://storage.googleapis.com/video.mp4"
552
-
553
- mock_generated_video = MagicMock()
554
- mock_generated_video.video = mock_video
555
-
556
- mock_result = MagicMock()
557
- mock_result.generated_videos = [mock_generated_video]
558
-
559
- mock_operation = MagicMock()
560
- mock_operation.done = True
561
- mock_operation.error = None
562
- mock_operation.result = mock_result
563
-
564
- mock_genai.Client.return_value.operations.get.return_value = mock_operation
565
-
566
- service = GeminiService(api_key="test-api-key")
567
- result = await service.check_video_status("operations/video-123")
568
-
569
- assert result["done"] == True
570
- assert result["status"] == "completed"
571
- assert "video_url" in result
572
- assert "test-api-key" in result["video_url"] # API key appended
573
-
574
- @pytest.mark.asyncio
575
- async def test_check_status_operation_error(self):
576
- """check_video_status returns failed on operation error."""
577
- with patch('services.gemini_service.genai') as mock_genai:
578
- from services.gemini_service import GeminiService
579
-
580
- mock_error = MagicMock()
581
- mock_error.message = "Content blocked by policy"
582
-
583
- mock_operation = MagicMock()
584
- mock_operation.done = True
585
- mock_operation.error = mock_error
586
-
587
- mock_genai.Client.return_value.operations.get.return_value = mock_operation
588
-
589
- service = GeminiService(api_key="test-key")
590
- result = await service.check_video_status("operations/video-123")
591
-
592
- assert result["done"] == True
593
- assert result["status"] == "failed"
594
- assert "error" in result
595
-
596
- @pytest.mark.asyncio
597
- async def test_check_status_404_expired(self):
598
- """check_video_status handles 404 for expired operation."""
599
- with patch('services.gemini_service.genai') as mock_genai:
600
- from services.gemini_service import GeminiService
601
-
602
- mock_genai.Client.return_value.operations.get.side_effect = Exception("404 NOT_FOUND")
603
-
604
- service = GeminiService(api_key="test-key")
605
- result = await service.check_video_status("operations/expired-123")
606
-
607
- assert result["done"] == True
608
- assert result["status"] == "failed"
609
- assert "expired" in result["error"].lower()
610
-
611
- @pytest.mark.asyncio
612
- async def test_check_status_no_video_uri(self):
613
- """check_video_status returns failed when no video URI."""
614
- with patch('services.gemini_service.genai') as mock_genai:
615
- from services.gemini_service import GeminiService
616
-
617
- mock_result = MagicMock()
618
- mock_result.generated_videos = [] # Empty
619
-
620
- mock_operation = MagicMock()
621
- mock_operation.done = True
622
- mock_operation.error = None
623
- mock_operation.result = mock_result
624
-
625
- mock_genai.Client.return_value.operations.get.return_value = mock_operation
626
-
627
- service = GeminiService(api_key="test-key")
628
- result = await service.check_video_status("operations/video-123")
629
-
630
- assert result["status"] == "failed"
631
- assert "safety filters" in result["error"].lower()
632
-
633
-
634
- # =============================================================================
635
- # 9. Video Download Tests
636
- # =============================================================================
637
-
638
- class TestVideoDownload:
639
- """Test download_video method."""
640
-
641
- @pytest.mark.asyncio
642
- async def test_download_video_saves_file(self):
643
- """download_video saves file and returns filename."""
644
- with patch('services.gemini_service.genai'):
645
- from services.gemini_service import GeminiService, DOWNLOADS_DIR
646
-
647
- with patch('httpx.AsyncClient') as mock_client:
648
- mock_response = MagicMock()
649
- mock_response.content = b"fake video data"
650
- mock_response.raise_for_status = MagicMock()
651
-
652
- mock_client_instance = AsyncMock()
653
- mock_client_instance.get.return_value = mock_response
654
- mock_client_instance.__aenter__.return_value = mock_client_instance
655
- mock_client_instance.__aexit__.return_value = None
656
- mock_client.return_value = mock_client_instance
657
-
658
- service = GeminiService(api_key="test-key")
659
-
660
- # Use temp directory
661
- with tempfile.TemporaryDirectory() as temp_dir:
662
- with patch.object(
663
- __import__('services.gemini_service', fromlist=['DOWNLOADS_DIR']),
664
- 'DOWNLOADS_DIR',
665
- temp_dir
666
- ):
667
- result = await service.download_video(
668
- "https://example.com/video.mp4",
669
- "test-op-123"
670
- )
671
-
672
- assert result == "test-op-123.mp4"
673
-
674
- @pytest.mark.asyncio
675
- async def test_download_video_http_error(self):
676
- """download_video raises error on HTTP failure."""
677
- with patch('services.gemini_service.genai'):
678
- from services.gemini_service import GeminiService
679
-
680
- with patch('httpx.AsyncClient') as mock_client:
681
- mock_client_instance = AsyncMock()
682
- mock_client_instance.get.side_effect = Exception("Connection refused")
683
- mock_client_instance.__aenter__.return_value = mock_client_instance
684
- mock_client_instance.__aexit__.return_value = None
685
- mock_client.return_value = mock_client_instance
686
-
687
- service = GeminiService(api_key="test-key")
688
-
689
- with pytest.raises(ValueError, match="Failed to download"):
690
- await service.download_video(
691
- "https://example.com/video.mp4",
692
- "test-op-123"
693
- )
694
-
695
- @pytest.mark.asyncio
696
- async def test_download_video_follows_redirects(self):
697
- """download_video client is configured to follow redirects."""
698
- with patch('services.gemini_service.genai'):
699
- from services.gemini_service import GeminiService
700
-
701
- with patch('httpx.AsyncClient') as mock_client:
702
- mock_response = MagicMock()
703
- mock_response.content = b"video data"
704
- mock_response.raise_for_status = MagicMock()
705
-
706
- mock_client_instance = AsyncMock()
707
- mock_client_instance.get.return_value = mock_response
708
- mock_client_instance.__aenter__.return_value = mock_client_instance
709
- mock_client_instance.__aexit__.return_value = None
710
- mock_client.return_value = mock_client_instance
711
-
712
- service = GeminiService(api_key="test-key")
713
-
714
- with tempfile.TemporaryDirectory() as temp_dir:
715
- with patch('services.gemini_service.DOWNLOADS_DIR', temp_dir):
716
- await service.download_video(
717
- "https://example.com/video.mp4",
718
- "redirect-test"
719
- )
720
-
721
- # Verify follow_redirects=True was passed
722
- mock_client.assert_called_with(timeout=120.0, follow_redirects=True)
723
-
724
-
725
- # =============================================================================
726
- # 10. Error Handling Tests
727
- # =============================================================================
728
-
729
- class TestErrorHandling:
730
- """Test _handle_api_error method."""
731
-
732
- def test_handle_api_error_404(self):
733
- """_handle_api_error raises ValueError for 404."""
734
- with patch('services.gemini_service.genai'):
735
- from services.gemini_service import GeminiService
736
-
737
- service = GeminiService(api_key="test-key")
738
-
739
- with pytest.raises(ValueError, match="Model not found"):
740
- service._handle_api_error(Exception("Error 404"), "test-model")
741
-
742
- def test_handle_api_error_not_found(self):
743
- """_handle_api_error handles NOT_FOUND in message."""
744
- with patch('services.gemini_service.genai'):
745
- from services.gemini_service import GeminiService
746
-
747
- service = GeminiService(api_key="test-key")
748
-
749
- with pytest.raises(ValueError, match="Model not found"):
750
- service._handle_api_error(Exception("NOT_FOUND: resource"), "test-model")
751
-
752
- def test_handle_api_error_entity_not_found(self):
753
- """_handle_api_error handles 'Requested entity was not found'."""
754
- with patch('services.gemini_service.genai'):
755
- from services.gemini_service import GeminiService
756
-
757
- service = GeminiService(api_key="test-key")
758
-
759
- with pytest.raises(ValueError, match="Model not found"):
760
- service._handle_api_error(
761
- Exception("Requested entity was not found"),
762
- "test-model"
763
- )
764
-
765
- def test_handle_api_error_bracket_5_pattern(self):
766
- """_handle_api_error handles [5, pattern."""
767
- with patch('services.gemini_service.genai'):
768
- from services.gemini_service import GeminiService
769
-
770
- service = GeminiService(api_key="test-key")
771
-
772
- with pytest.raises(ValueError, match="Model not found"):
773
- service._handle_api_error(
774
- Exception("Response [5, 'NOT_FOUND']"),
775
- "test-model"
776
- )
777
-
778
- def test_handle_api_error_reraises_other(self):
779
- """_handle_api_error re-raises non-404 errors."""
780
- with patch('services.gemini_service.genai'):
781
- from services.gemini_service import GeminiService
782
-
783
- service = GeminiService(api_key="test-key")
784
-
785
- with pytest.raises(RuntimeError, match="Connection timeout"):
786
- service._handle_api_error(
787
- RuntimeError("Connection timeout"),
788
- "test-model"
789
- )
790
-
791
-
792
- # =============================================================================
793
- # 11. Downloads Directory Tests
794
- # =============================================================================
795
-
796
- class TestDownloadsDirectory:
797
- """Test downloads directory handling."""
798
-
799
- def test_downloads_dir_exists(self):
800
- """DOWNLOADS_DIR is created on module import."""
801
- from services.gemini_service import DOWNLOADS_DIR
802
-
803
- assert os.path.exists(DOWNLOADS_DIR)
804
- assert os.path.isdir(DOWNLOADS_DIR)
805
-
806
- def test_downloads_dir_is_in_project(self):
807
- """DOWNLOADS_DIR is within project directory."""
808
- from services.gemini_service import DOWNLOADS_DIR
809
-
810
- assert "downloads" in DOWNLOADS_DIR
811
-
812
-
813
- if __name__ == "__main__":
814
- pytest.main([__file__, "-v"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/{test_worker_pool.py → test_worker_pool.py.archived} RENAMED
File without changes