""" Gemini Router - API endpoints for Gemini AI services. """ import os import uuid from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi.responses import FileResponse from pydantic import BaseModel, Field from typing import Optional, Literal from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func from core.database import get_db from core.models import User, GeminiJob from services.gemini_service import MODELS, DOWNLOADS_DIR from datetime import datetime router = APIRouter(prefix="/gemini", tags=["gemini"]) class GenerateAnimationPromptRequest(BaseModel): base64_image: str = Field(..., description="Base64 encoded image data") mime_type: str = Field(..., description="MIME type of the image (e.g., image/png)") custom_prompt: Optional[str] = Field(None, description="Optional custom prompt for analysis") class EditImageRequest(BaseModel): base64_image: str = Field(..., description="Base64 encoded image data") mime_type: str = Field(..., description="MIME type of the image") prompt: str = Field(..., description="Edit instructions") class GenerateVideoRequest(BaseModel): base64_image: str = Field(..., description="Base64 encoded image data") mime_type: str = Field(..., description="MIME type of the image") prompt: str = Field(..., description="Video generation prompt") aspect_ratio: Literal["16:9", "9:16"] = Field("16:9", description="Video aspect ratio") resolution: Literal["720p", "1080p"] = Field("720p", description="Video resolution") number_of_videos: int = Field(1, ge=1, le=4, description="Number of videos to generate") class GenerateTextRequest(BaseModel): prompt: str = Field(..., description="Text prompt") model: Optional[str] = Field(None, description="Model to use (defaults to gemini-2.5-flash)") class AnalyzeImageRequest(BaseModel): base64_image: str = Field(..., description="Base64 encoded image data") mime_type: str = Field(..., description="MIME type of the image") prompt: str = Field(..., description="Analysis prompt") async def get_queue_position(db: AsyncSession, job_id: str) -> int: """Get the position of a job in the queue.""" query = select(func.count()).where( GeminiJob.status == "queued", GeminiJob.created_at < select(GeminiJob.created_at).where(GeminiJob.job_id == job_id).scalar_subquery() ) result = await db.execute(query) return result.scalar() + 1 async def create_job( db: AsyncSession, user: User, job_type: str, input_data: dict, credits_reserved: int = 0 ) -> GeminiJob: """Create a new job in the queue.""" from services.gemini_service.job_processor import get_priority_for_job_type, get_pool job_id = f"job_{uuid.uuid4().hex[:16]}" priority = get_priority_for_job_type(job_type) job = GeminiJob( job_id=job_id, user_id=user.id, job_type=job_type, status="queued", priority=priority, input_data=input_data, credits_reserved=credits_reserved ) db.add(job) await db.commit() await db.refresh(job) get_pool().notify_new_job(priority) return job @router.post("/generate-animation-prompt") async def generate_animation_prompt( req: Request, request: GenerateAnimationPromptRequest, db: AsyncSession = Depends(get_db) ): """Queue an animation prompt generation job.""" user = req.state.user credits_reserved = req.state.credits_reserved job = await create_job( db=db, user=user, job_type="animation_prompt", input_data={ "base64_image": request.base64_image, "mime_type": request.mime_type, "custom_prompt": request.custom_prompt }, credits_reserved=credits_reserved ) position = await get_queue_position(db, job.job_id) return { "success": True, "job_id": job.job_id, "status": "queued", "position": position, "credits_remaining": user.credits } @router.post("/edit-image") async def edit_image( req: Request, request: EditImageRequest, db: AsyncSession = Depends(get_db) ): """Queue an image edit job.""" user = req.state.user credits_reserved = req.state.credits_reserved job = await create_job( db=db, user=user, job_type="image", input_data={ "base64_image": request.base64_image, "mime_type": request.mime_type, "prompt": request.prompt }, credits_reserved=credits_reserved ) position = await get_queue_position(db, job.job_id) return { "success": True, "job_id": job.job_id, "status": "queued", "position": position, "credits_remaining": user.credits } @router.post("/generate-video") async def generate_video( req: Request, request: GenerateVideoRequest, db: AsyncSession = Depends(get_db) ): """Queue a video generation job.""" user = req.state.user credits_reserved = req.state.credits_reserved job = await create_job( db=db, user=user, job_type="video", input_data={ "base64_image": request.base64_image, "mime_type": request.mime_type, "prompt": request.prompt, "aspect_ratio": request.aspect_ratio, "resolution": request.resolution, "number_of_videos": request.number_of_videos }, credits_reserved=credits_reserved ) position = await get_queue_position(db, job.job_id) return { "success": True, "job_id": job.job_id, "status": "queued", "position": position, "credits_remaining": user.credits } @router.post("/generate-text") async def generate_text( req: Request, request: GenerateTextRequest, db: AsyncSession = Depends(get_db) ): """Queue a text generation job.""" user = req.state.user credits_reserved = req.state.credits_reserved job = await create_job( db=db, user=user, job_type="text", input_data={ "prompt": request.prompt, "model": request.model }, credits_reserved=credits_reserved ) position = await get_queue_position(db, job.job_id) return { "success": True, "job_id": job.job_id, "status": "queued", "position": position, "credits_remaining": user.credits } @router.post("/analyze-image") async def analyze_image( req: Request, request: AnalyzeImageRequest, db: AsyncSession = Depends(get_db) ): """Queue an image analysis job.""" user = req.state.user credits_reserved = req.state.credits_reserved job = await create_job( db=db, user=user, job_type="analyze", input_data={ "base64_image": request.base64_image, "mime_type": request.mime_type, "prompt": request.prompt }, credits_reserved=credits_reserved ) position = await get_queue_position(db, job.job_id) return { "success": True, "job_id": job.job_id, "status": "queued", "position": position, "credits_remaining": user.credits } @router.get("/jobs") async def get_jobs( req: Request, db: AsyncSession = Depends(get_db), page: int = 1, limit: int = 20 ): """Get all jobs for the current user.""" user = req.state.user offset = (page - 1) * limit query = select(GeminiJob).where( GeminiJob.user_id == user.id ).order_by(GeminiJob.created_at.desc()).offset(offset).limit(limit) result = await db.execute(query) jobs = result.scalars().all() count_query = select(func.count()).where(GeminiJob.user_id == user.id) count_result = await db.execute(count_query) total_count = count_result.scalar() job_list = [] for job in jobs: job_item = { "job_id": job.job_id, "job_type": job.job_type, "status": job.status, "created_at": job.created_at.isoformat() if job.created_at else None, "completed_at": job.completed_at.isoformat() if job.completed_at else None, } if job.job_type == "video" and job.input_data: job_item["prompt"] = job.input_data.get("prompt") if job.status == "failed": job_item["error"] = job.error_message if job.status == "completed" and job.job_type == "video" and job.output_data and job.output_data.get("filename"): job_item["download_url"] = f"/gemini/download/{job.job_id}" job_list.append(job_item) return { "success": True, "jobs": job_list, "total_count": total_count, "page": page, "limit": limit } @router.get("/job/{job_id}") async def get_job_status( job_id: str, req: Request, db: AsyncSession = Depends(get_db) ): """Get job status and update if processing.""" user = req.state.user query = select(GeminiJob).where( GeminiJob.job_id == job_id, GeminiJob.user_id == user.id ) result = await db.execute(query) job = result.scalar_one_or_none() if not job: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) if job.status == "processing" and job.job_type == "video" and job.third_party_id: from services.gemini_service.job_processor import GeminiJobProcessor processor = GeminiJobProcessor() job = await processor.check_status(job, db) await db.commit() await db.refresh(job) response = { "success": True, "job_id": job.job_id, "job_type": job.job_type, "status": job.status, "created_at": job.created_at.isoformat() if job.created_at else None, "credits_remaining": user.credits } if job.job_type == "video" and job.input_data: response["prompt"] = job.input_data.get("prompt") if job.status == "queued": response["position"] = await get_queue_position(db, job.job_id) if job.status == "processing": response["started_at"] = job.started_at.isoformat() if job.started_at else None if job.status == "completed": response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None if job.output_data and "prompt" in job.output_data: response["prompt"] = job.output_data["prompt"] if job.status == "failed": response["error"] = job.error_message response["completed_at"] = job.completed_at.isoformat() if job.completed_at else None return response @router.get("/download/{job_id}") async def download_video( job_id: str, req: Request, db: AsyncSession = Depends(get_db) ): """Stream video from Gemini to client.""" user = req.state.user from fastapi.responses import StreamingResponse import httpx query = select(GeminiJob).where( GeminiJob.job_id == job_id, GeminiJob.user_id == user.id, GeminiJob.job_type == "video" ) result = await db.execute(query) job = result.scalar_one_or_none() if not job: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) if job.status != "completed" or not job.output_data: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Video not ready for download" ) video_url = job.output_data.get("video_url") if not video_url: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="No video URL available" ) async def stream_video(): try: async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client: async with client.stream("GET", video_url) as response: response.raise_for_status() async for chunk in response.aiter_bytes(chunk_size=8192): yield chunk except httpx.HTTPStatusError as e: if e.response.status_code in (401, 403, 404, 410): job.status = "expired" await db.commit() raise HTTPException( status_code=status.HTTP_410_GONE, detail="Video URL has expired. Please generate a new video." ) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Failed to fetch video from source: {e.response.status_code}" ) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to stream video: {str(e)}" ) return StreamingResponse( stream_video(), media_type="video/mp4", headers={ "Content-Disposition": f'attachment; filename="video_{job_id}.mp4"' } ) @router.post("/job/{job_id}/cancel") async def cancel_job( job_id: str, req: Request, db: AsyncSession = Depends(get_db) ): """Cancel a queued job.""" user = req.state.user query = select(GeminiJob).where( GeminiJob.job_id == job_id, GeminiJob.user_id == user.id ) result = await db.execute(query) job = result.scalar_one_or_none() if not job: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) if job.status != "queued": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Cannot cancel job with status '{job.status}'. Only queued jobs can be cancelled." ) job.status = "cancelled" job.completed_at = datetime.utcnow() await db.commit() return { "success": True, "job_id": job.job_id, "status": "cancelled", "message": "Job cancelled successfully" } @router.delete("/job/{job_id}") async def delete_job( job_id: str, req: Request, db: AsyncSession = Depends(get_db) ): """Delete job with conditional credit refund.""" user = req.state.user from services.db_service import QueryService qs = QueryService(user, db) job = await qs.select().execute_one( select(GeminiJob).where(GeminiJob.job_id == job_id) ) if not job: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Job not found" ) refund_amount = 0 message = "Job deleted" if job.credits_reserved > 0 and not job.credits_refunded: from services.credit_service import CreditTransactionManager if not job.third_party_id: refund_amount = job.credits_reserved refund_reason = "Job deleted before execution" elif job.status == "queued": penalty = 2 refund_amount = max(0, job.credits_reserved - penalty) refund_reason = f"Job cancelled while queued (penalty: {penalty} credits)" elif job.status in ["processing", "completed"]: refund_amount = 0 refund_reason = None elif job.status == "failed": refund_amount = job.credits_reserved refund_reason = "Job failed - full refund" if refund_amount > 0: try: await CreditTransactionManager.add_credits( session=db, user=user, amount=refund_amount, source="job_deletion", reference_type="job", reference_id=job.job_id, reason=refund_reason, metadata={ "job_type": job.job_type, "original_cost": job.credits_reserved, "job_status": job.status } ) job.credits_refunded = True await db.commit() message = f"Job deleted. {refund_amount} credits refunded." except Exception as e: logger.error(f"Failed to refund credits for job {job_id}: {e}") message = "Job deleted (refund failed - contact support)" deleted = await qs.delete().soft_delete_one(job) if not deleted: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete job" ) return { "success": True, "message": message, "refund_amount": refund_amount, "new_credit_balance": user.credits } @router.get("/models") async def get_models(): """Get available model names.""" return {"models": MODELS}