Spaces:
Sleeping
Sleeping
| """ | |
| Gemini Router - API endpoints for Gemini AI services. | |
| """ | |
| import os | |
| import uuid | |
| from fastapi import APIRouter, Depends, HTTPException, status, Request | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel, Field | |
| from typing import Optional, Literal | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy import select, func | |
| from core.database import get_db | |
| from core.models import User, GeminiJob | |
| from services.gemini_service import MODELS, DOWNLOADS_DIR | |
| from datetime import datetime | |
| router = APIRouter(prefix="/gemini", tags=["gemini"]) | |
class GenerateAnimationPromptRequest(BaseModel):
    """Request body for queuing an animation-prompt generation job."""
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image (e.g., image/png)")
    custom_prompt: Optional[str] = Field(None, description="Optional custom prompt for analysis")
class EditImageRequest(BaseModel):
    """Request body for queuing an image edit job."""
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image")
    prompt: str = Field(..., description="Edit instructions")
class GenerateVideoRequest(BaseModel):
    """Request body for queuing a video generation job."""
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image")
    prompt: str = Field(..., description="Video generation prompt")
    aspect_ratio: Literal["16:9", "9:16"] = Field("16:9", description="Video aspect ratio")
    resolution: Literal["720p", "1080p"] = Field("720p", description="Video resolution")
    number_of_videos: int = Field(1, ge=1, le=4, description="Number of videos to generate")
class GenerateTextRequest(BaseModel):
    """Request body for queuing a text generation job."""
    prompt: str = Field(..., description="Text prompt")
    model: Optional[str] = Field(None, description="Model to use (defaults to gemini-2.5-flash)")
class AnalyzeImageRequest(BaseModel):
    """Request body for queuing an image analysis job."""
    base64_image: str = Field(..., description="Base64 encoded image data")
    mime_type: str = Field(..., description="MIME type of the image")
    prompt: str = Field(..., description="Analysis prompt")
async def get_queue_position(db: AsyncSession, job_id: str) -> int:
    """Return the 1-based queue position of the given job.

    Counts queued jobs created strictly before this job's creation time
    and adds one.
    """
    created_at_sq = (
        select(GeminiJob.created_at)
        .where(GeminiJob.job_id == job_id)
        .scalar_subquery()
    )
    stmt = select(func.count()).where(
        GeminiJob.status == "queued",
        GeminiJob.created_at < created_at_sq,
    )
    jobs_ahead = (await db.execute(stmt)).scalar()
    return jobs_ahead + 1
async def create_job(
    db: AsyncSession,
    user: User,
    job_type: str,
    input_data: dict,
    credits_reserved: int = 0
) -> GeminiJob:
    """Insert a new GeminiJob row in 'queued' state and wake the worker pool.

    Priority is derived from the job type; the pool is notified only after
    the row is committed so a worker can actually find it.
    """
    # Local import — presumably avoids a circular import at module load time.
    from services.gemini_service.job_processor import get_priority_for_job_type, get_pool
    priority = get_priority_for_job_type(job_type)
    new_job = GeminiJob(
        job_id=f"job_{uuid.uuid4().hex[:16]}",
        user_id=user.id,
        job_type=job_type,
        status="queued",
        priority=priority,
        input_data=input_data,
        credits_reserved=credits_reserved,
    )
    db.add(new_job)
    await db.commit()
    await db.refresh(new_job)
    get_pool().notify_new_job(priority)
    return new_job
async def generate_animation_prompt(
    req: Request,
    request: GenerateAnimationPromptRequest,
    db: AsyncSession = Depends(get_db)
):
    """Queue an animation prompt generation job for the current user.

    Returns the new job id, its position in the queue, and the caller's
    remaining credit balance.
    """
    current_user = req.state.user
    payload = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "custom_prompt": request.custom_prompt,
    }
    queued_job = await create_job(
        db=db,
        user=current_user,
        job_type="animation_prompt",
        input_data=payload,
        credits_reserved=req.state.credits_reserved,
    )
    return {
        "success": True,
        "job_id": queued_job.job_id,
        "status": "queued",
        "position": await get_queue_position(db, queued_job.job_id),
        "credits_remaining": current_user.credits,
    }
async def edit_image(
    req: Request,
    request: EditImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Queue an image edit job and report its queue position."""
    user = req.state.user
    job = await create_job(
        db=db,
        user=user,
        job_type="image",
        input_data={
            "base64_image": request.base64_image,
            "mime_type": request.mime_type,
            "prompt": request.prompt,
        },
        credits_reserved=req.state.credits_reserved,
    )
    queue_position = await get_queue_position(db, job.job_id)
    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": queue_position,
        "credits_remaining": user.credits,
    }
async def generate_video(
    req: Request,
    request: GenerateVideoRequest,
    db: AsyncSession = Depends(get_db)
):
    """Queue a video generation job.

    Video-specific options (aspect ratio, resolution, clip count) travel
    with the job in its input payload.
    """
    user = req.state.user
    video_options = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "prompt": request.prompt,
        "aspect_ratio": request.aspect_ratio,
        "resolution": request.resolution,
        "number_of_videos": request.number_of_videos,
    }
    job = await create_job(
        db=db,
        user=user,
        job_type="video",
        input_data=video_options,
        credits_reserved=req.state.credits_reserved,
    )
    reply = {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": await get_queue_position(db, job.job_id),
        "credits_remaining": user.credits,
    }
    return reply
async def generate_text(
    req: Request,
    request: GenerateTextRequest,
    db: AsyncSession = Depends(get_db)
):
    """Queue a text generation job for the current user."""
    user = req.state.user
    reserved = req.state.credits_reserved
    job = await create_job(
        db=db,
        user=user,
        job_type="text",
        input_data={"prompt": request.prompt, "model": request.model},
        credits_reserved=reserved,
    )
    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": await get_queue_position(db, job.job_id),
        "credits_remaining": user.credits,
    }
async def analyze_image(
    req: Request,
    request: AnalyzeImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Queue an image analysis job and report its queue position."""
    user = req.state.user
    analysis_input = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "prompt": request.prompt,
    }
    job = await create_job(
        db=db,
        user=user,
        job_type="analyze",
        input_data=analysis_input,
        credits_reserved=req.state.credits_reserved,
    )
    position = await get_queue_position(db, job.job_id)
    return {
        "success": True,
        "job_id": job.job_id,
        "status": "queued",
        "position": position,
        "credits_remaining": user.credits,
    }
async def get_jobs(
    req: Request,
    db: AsyncSession = Depends(get_db),
    page: int = 1,
    limit: int = 20
):
    """Return a paginated, newest-first list of the current user's jobs.

    Args:
        page: 1-based page number; values below 1 are clamped to 1.
        limit: page size; values below 1 are clamped to 1.

    Each item carries id/type/status/timestamps; video jobs also include
    their prompt, failed jobs their error, and completed video jobs with
    an output filename a download URL.
    """
    user = req.state.user
    # FIX: clamp pagination inputs. page=0 (or a negative limit) previously
    # produced a negative OFFSET/LIMIT, which databases reject at query time.
    page = max(page, 1)
    limit = max(limit, 1)
    offset = (page - 1) * limit
    query = select(GeminiJob).where(
        GeminiJob.user_id == user.id
    ).order_by(GeminiJob.created_at.desc()).offset(offset).limit(limit)
    result = await db.execute(query)
    jobs = result.scalars().all()
    # Separate total count so the client can render pagination controls.
    count_query = select(func.count()).where(GeminiJob.user_id == user.id)
    count_result = await db.execute(count_query)
    total_count = count_result.scalar()
    job_list = []
    for job in jobs:
        job_item = {
            "job_id": job.job_id,
            "job_type": job.job_type,
            "status": job.status,
            "created_at": job.created_at.isoformat() if job.created_at else None,
            "completed_at": job.completed_at.isoformat() if job.completed_at else None,
        }
        if job.job_type == "video" and job.input_data:
            job_item["prompt"] = job.input_data.get("prompt")
        if job.status == "failed":
            job_item["error"] = job.error_message
        if (
            job.status == "completed"
            and job.job_type == "video"
            and job.output_data
            and job.output_data.get("filename")
        ):
            job_item["download_url"] = f"/gemini/download/{job.job_id}"
        job_list.append(job_item)
    return {
        "success": True,
        "jobs": job_list,
        "total_count": total_count,
        "page": page,
        "limit": limit,
    }
async def get_job_status(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Return the current status of a job owned by the caller.

    In-flight video jobs that have an upstream operation id are polled for
    fresh status before responding. Raises 404 when the job does not exist
    or belongs to another user.
    """
    user = req.state.user
    result = await db.execute(
        select(GeminiJob).where(
            GeminiJob.job_id == job_id,
            GeminiJob.user_id == user.id,
        )
    )
    job = result.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )
    needs_upstream_poll = (
        job.status == "processing"
        and job.job_type == "video"
        and job.third_party_id
    )
    if needs_upstream_poll:
        from services.gemini_service.job_processor import GeminiJobProcessor
        job = await GeminiJobProcessor().check_status(job, db)
        await db.commit()
        await db.refresh(job)
    response = {
        "success": True,
        "job_id": job.job_id,
        "job_type": job.job_type,
        "status": job.status,
        "created_at": job.created_at.isoformat() if job.created_at else None,
        "credits_remaining": user.credits,
    }
    if job.job_type == "video" and job.input_data:
        response["prompt"] = job.input_data.get("prompt")
    # The status values below are mutually exclusive.
    if job.status == "queued":
        response["position"] = await get_queue_position(db, job.job_id)
    elif job.status == "processing":
        response["started_at"] = job.started_at.isoformat() if job.started_at else None
    elif job.status == "completed":
        response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
        if job.output_data and "prompt" in job.output_data:
            response["prompt"] = job.output_data["prompt"]
    elif job.status == "failed":
        response["error"] = job.error_message
        response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None
    return response
async def download_video(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Stream a completed video from the upstream provider to the client.

    FIX: the upstream request is opened and its status checked *before*
    the StreamingResponse is returned. Previously, HTTPException was
    raised inside the streaming generator — but by the time Starlette
    iterates the generator the 200 response has already started, so the
    intended 410/502 errors could never reach the client.

    Raises:
        404 if the job does not exist, is not a video, or is not owned
            by the caller.
        400 if the job is not completed or has no video URL.
        410 if the upstream URL has expired (the job is marked 'expired').
        502 if the upstream fetch fails for any other reason.
    """
    user = req.state.user
    from fastapi.responses import StreamingResponse
    import httpx
    query = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == user.id,
        GeminiJob.job_type == "video"
    )
    result = await db.execute(query)
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )
    if job.status != "completed" or not job.output_data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Video not ready for download"
        )
    video_url = job.output_data.get("video_url")
    if not video_url:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No video URL available"
        )
    # Open the upstream stream now so failures surface as real HTTP errors.
    client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)
    try:
        upstream = await client.send(
            client.build_request("GET", video_url), stream=True
        )
    except httpx.HTTPError as e:
        await client.aclose()
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch video from source: {e}"
        )
    if upstream.status_code in (401, 403, 404, 410):
        await upstream.aclose()
        await client.aclose()
        # The signed URL is no longer valid; remember that on the job.
        job.status = "expired"
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_410_GONE,
            detail="Video URL has expired. Please generate a new video."
        )
    if upstream.status_code >= 400:
        code = upstream.status_code
        await upstream.aclose()
        await client.aclose()
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch video from source: {code}"
        )

    async def stream_video():
        # Response and client are closed after the last chunk so the
        # connection stays open for the whole download.
        try:
            async for chunk in upstream.aiter_bytes(chunk_size=8192):
                yield chunk
        finally:
            await upstream.aclose()
            await client.aclose()

    return StreamingResponse(
        stream_video(),
        media_type="video/mp4",
        headers={
            "Content-Disposition": f'attachment; filename="video_{job_id}.mp4"'
        }
    )
async def cancel_job(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Cancel one of the caller's jobs while it is still waiting in the queue.

    Raises 404 when the job is missing or owned by someone else, and 400
    when the job has already left the 'queued' state.
    """
    user = req.state.user
    lookup = await db.execute(
        select(GeminiJob).where(
            GeminiJob.job_id == job_id,
            GeminiJob.user_id == user.id,
        )
    )
    job = lookup.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )
    if job.status != "queued":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot cancel job with status '{job.status}'. Only queued jobs can be cancelled."
        )
    job.status = "cancelled"
    job.completed_at = datetime.utcnow()
    await db.commit()
    return {
        "success": True,
        "job_id": job.job_id,
        "status": "cancelled",
        "message": "Job cancelled successfully",
    }
async def delete_job(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Soft-delete a job with a conditional credit refund.

    Refund policy (only when credits were reserved and not yet refunded):
      * never submitted upstream (no third_party_id): full refund
      * still queued: refund minus a fixed 2-credit penalty
      * processing or completed: no refund
      * failed: full refund
    A failed refund does not block the deletion; the failure is logged and
    surfaced in the response message.
    """
    import logging
    user = req.state.user
    from services.db_service import QueryService
    qs = QueryService(user, db)
    job = await qs.select().execute_one(
        select(GeminiJob).where(GeminiJob.job_id == job_id)
    )
    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )
    refund_amount = 0
    # FIX: initialize refund_reason so it can never be read unbound if a
    # job status outside the branches below ever occurs.
    refund_reason = None
    message = "Job deleted"
    if job.credits_reserved > 0 and not job.credits_refunded:
        from services.credit_service import CreditTransactionManager
        if not job.third_party_id:
            # Job never reached the upstream provider: full refund.
            refund_amount = job.credits_reserved
            refund_reason = "Job deleted before execution"
        elif job.status == "queued":
            penalty = 2
            refund_amount = max(0, job.credits_reserved - penalty)
            refund_reason = f"Job cancelled while queued (penalty: {penalty} credits)"
        elif job.status in ["processing", "completed"]:
            refund_amount = 0
            refund_reason = None
        elif job.status == "failed":
            refund_amount = job.credits_reserved
            refund_reason = "Job failed - full refund"
        if refund_amount > 0:
            try:
                await CreditTransactionManager.add_credits(
                    session=db,
                    user=user,
                    amount=refund_amount,
                    source="job_deletion",
                    reference_type="job",
                    reference_id=job.job_id,
                    reason=refund_reason,
                    metadata={
                        "job_type": job.job_type,
                        "original_cost": job.credits_reserved,
                        "job_status": job.status
                    }
                )
                job.credits_refunded = True
                await db.commit()
                message = f"Job deleted. {refund_amount} credits refunded."
            except Exception as e:
                # FIX: 'logger' was undefined here, so a refund failure
                # raised NameError instead of returning the fallback message.
                logging.getLogger(__name__).error(
                    f"Failed to refund credits for job {job_id}: {e}"
                )
                # NOTE(review): refund_amount is still reported non-zero on
                # this path even though the refund failed — confirm intent.
                message = "Job deleted (refund failed - contact support)"
    deleted = await qs.delete().soft_delete_one(job)
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete job"
        )
    return {
        "success": True,
        "message": message,
        "refund_amount": refund_amount,
        "new_credit_balance": user.credits
    }
async def get_models():
    """Return the catalog of available Gemini model names."""
    available = MODELS
    return {"models": available}