apigateway / routers /gemini.py
jebin2's picture
credit issue fix
2dbfc89
"""
Gemini Router - API endpoints for Gemini AI services.
"""
import os
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from typing import Optional, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from core.database import get_db
from core.models import User, GeminiJob
from services.gemini_service import MODELS, DOWNLOADS_DIR
from datetime import datetime
# Router collecting every Gemini endpoint in this module; all routes are
# served under the "/gemini" prefix and tagged "gemini".
router = APIRouter(prefix="/gemini", tags=["gemini"])
class GenerateAnimationPromptRequest(BaseModel):
    # Body of POST /gemini/generate-animation-prompt: one image plus an
    # optional prompt override. The first two fields are required.
    base64_image: str = Field(
        default=...,
        description="Base64 encoded image data",
    )
    mime_type: str = Field(
        default=...,
        description="MIME type of the image (e.g., image/png)",
    )
    custom_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom prompt for analysis",
    )
class EditImageRequest(BaseModel):
    # Body of POST /gemini/edit-image: the image to edit and the edit
    # instructions. All three fields are required.
    base64_image: str = Field(
        default=...,
        description="Base64 encoded image data",
    )
    mime_type: str = Field(
        default=...,
        description="MIME type of the image",
    )
    prompt: str = Field(
        default=...,
        description="Edit instructions",
    )
class GenerateVideoRequest(BaseModel):
    # Body of POST /gemini/generate-video: a source image, a prompt and the
    # output options. Image, MIME type and prompt are required; the output
    # options fall back to 16:9, 720p and a single video.
    base64_image: str = Field(
        default=...,
        description="Base64 encoded image data",
    )
    mime_type: str = Field(
        default=...,
        description="MIME type of the image",
    )
    prompt: str = Field(
        default=...,
        description="Video generation prompt",
    )
    aspect_ratio: Literal["16:9", "9:16"] = Field(
        default="16:9",
        description="Video aspect ratio",
    )
    resolution: Literal["720p", "1080p"] = Field(
        default="720p",
        description="Video resolution",
    )
    # Bounded to 1..4 (inclusive) by the ge/le constraints.
    number_of_videos: int = Field(
        default=1,
        ge=1,
        le=4,
        description="Number of videos to generate",
    )
class GenerateTextRequest(BaseModel):
    # Body of POST /gemini/generate-text: a required prompt and an optional
    # model name (None when the caller does not pick one).
    prompt: str = Field(
        default=...,
        description="Text prompt",
    )
    model: Optional[str] = Field(
        default=None,
        description="Model to use (defaults to gemini-2.5-flash)",
    )
class AnalyzeImageRequest(BaseModel):
    # Body of POST /gemini/analyze-image: the image to analyse and the
    # analysis prompt. All three fields are required.
    base64_image: str = Field(
        default=...,
        description="Base64 encoded image data",
    )
    mime_type: str = Field(
        default=...,
        description="MIME type of the image",
    )
    prompt: str = Field(
        default=...,
        description="Analysis prompt",
    )
async def get_queue_position(db: AsyncSession, job_id: str) -> int:
    """Return the 1-based queue position of the job identified by *job_id*.

    The position is the number of jobs whose status is "queued" and whose
    ``created_at`` is strictly earlier than this job's ``created_at``, plus
    one. Only status and creation time are compared here.
    """
    # Scalar subquery yielding the creation time of the job being located.
    own_created_at = (
        select(GeminiJob.created_at)
        .where(GeminiJob.job_id == job_id)
        .scalar_subquery()
    )
    ahead_stmt = select(func.count()).where(
        GeminiJob.status == "queued",
        GeminiJob.created_at < own_created_at,
    )
    outcome = await db.execute(ahead_stmt)
    jobs_ahead = outcome.scalar()
    return jobs_ahead + 1
async def create_job(
    db: AsyncSession,
    user: User,
    job_type: str,
    input_data: dict,
    credits_reserved: int = 0
) -> GeminiJob:
    """Persist a new queued job for *user* and wake the worker pool.

    Args:
        db: Session used to insert, commit and refresh the job row.
        user: Owner of the job; only ``user.id`` is stored.
        job_type: Job type string; also used to look up the job's priority.
        input_data: Payload stored on the job as-is.
        credits_reserved: Credits recorded on the job row (default 0).

    Returns:
        The refreshed ``GeminiJob`` row, with status "queued".
    """
    # Imported locally, as in the rest of this module's job_processor usage.
    from services.gemini_service.job_processor import get_priority_for_job_type, get_pool

    # Public identifier: "job_" followed by 16 hex characters of a random UUID.
    new_job_id = "job_" + uuid.uuid4().hex[:16]
    job_priority = get_priority_for_job_type(job_type)

    record = GeminiJob(
        job_id=new_job_id,
        user_id=user.id,
        job_type=job_type,
        status="queued",
        priority=job_priority,
        input_data=input_data,
        credits_reserved=credits_reserved,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)

    # Notify the pool only after the row has been committed.
    worker_pool = get_pool()
    worker_pool.notify_new_job(job_priority)
    return record
@router.post("/generate-animation-prompt")
async def generate_animation_prompt(
    req: Request,
    request: GenerateAnimationPromptRequest,
    db: AsyncSession = Depends(get_db)
):
    """Enqueue an "animation_prompt" job built from the request body.

    The current user and the reserved credit amount are read from
    ``req.state``. Returns the new job id, its queue position and the user's
    credit balance.
    """
    state = req.state
    user = state.user
    reserved = state.credits_reserved

    payload = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "custom_prompt": request.custom_prompt,
    }
    queued = await create_job(
        db=db,
        user=user,
        job_type="animation_prompt",
        input_data=payload,
        credits_reserved=reserved,
    )
    place_in_queue = await get_queue_position(db, queued.job_id)

    return dict(
        success=True,
        job_id=queued.job_id,
        status="queued",
        position=place_in_queue,
        credits_remaining=user.credits,
    )
@router.post("/edit-image")
async def edit_image(
    req: Request,
    request: EditImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Enqueue an "image" job (image edit) built from the request body.

    The current user and the reserved credit amount are read from
    ``req.state``. Returns the new job id, its queue position and the user's
    credit balance.
    """
    state = req.state
    user = state.user
    reserved = state.credits_reserved

    payload = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "prompt": request.prompt,
    }
    queued = await create_job(
        db=db,
        user=user,
        job_type="image",
        input_data=payload,
        credits_reserved=reserved,
    )
    place_in_queue = await get_queue_position(db, queued.job_id)

    return dict(
        success=True,
        job_id=queued.job_id,
        status="queued",
        position=place_in_queue,
        credits_remaining=user.credits,
    )
@router.post("/generate-video")
async def generate_video(
    req: Request,
    request: GenerateVideoRequest,
    db: AsyncSession = Depends(get_db)
):
    """Enqueue a "video" job built from the request body.

    The current user and the reserved credit amount are read from
    ``req.state``. Returns the new job id, its queue position and the user's
    credit balance.
    """
    state = req.state
    user = state.user
    reserved = state.credits_reserved

    payload = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "prompt": request.prompt,
        "aspect_ratio": request.aspect_ratio,
        "resolution": request.resolution,
        "number_of_videos": request.number_of_videos,
    }
    queued = await create_job(
        db=db,
        user=user,
        job_type="video",
        input_data=payload,
        credits_reserved=reserved,
    )
    place_in_queue = await get_queue_position(db, queued.job_id)

    return dict(
        success=True,
        job_id=queued.job_id,
        status="queued",
        position=place_in_queue,
        credits_remaining=user.credits,
    )
@router.post("/generate-text")
async def generate_text(
    req: Request,
    request: GenerateTextRequest,
    db: AsyncSession = Depends(get_db)
):
    """Enqueue a "text" job built from the request body.

    The current user and the reserved credit amount are read from
    ``req.state``. Returns the new job id, its queue position and the user's
    credit balance.
    """
    state = req.state
    user = state.user
    reserved = state.credits_reserved

    payload = {
        "prompt": request.prompt,
        "model": request.model,
    }
    queued = await create_job(
        db=db,
        user=user,
        job_type="text",
        input_data=payload,
        credits_reserved=reserved,
    )
    place_in_queue = await get_queue_position(db, queued.job_id)

    return dict(
        success=True,
        job_id=queued.job_id,
        status="queued",
        position=place_in_queue,
        credits_remaining=user.credits,
    )
@router.post("/analyze-image")
async def analyze_image(
    req: Request,
    request: AnalyzeImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Enqueue an "analyze" job (image analysis) built from the request body.

    The current user and the reserved credit amount are read from
    ``req.state``. Returns the new job id, its queue position and the user's
    credit balance.
    """
    state = req.state
    user = state.user
    reserved = state.credits_reserved

    payload = {
        "base64_image": request.base64_image,
        "mime_type": request.mime_type,
        "prompt": request.prompt,
    }
    queued = await create_job(
        db=db,
        user=user,
        job_type="analyze",
        input_data=payload,
        credits_reserved=reserved,
    )
    place_in_queue = await get_queue_position(db, queued.job_id)

    return dict(
        success=True,
        job_id=queued.job_id,
        status="queued",
        position=place_in_queue,
        credits_remaining=user.credits,
    )
@router.get("/jobs")
async def get_jobs(
    req: Request,
    db: AsyncSession = Depends(get_db),
    page: int = 1,
    limit: int = 20
):
    """Return one page of the current user's jobs, newest first.

    Args:
        req: Incoming request; the current user is read from ``req.state.user``.
        db: Database session.
        page: 1-based page number.
        limit: Page size (number of jobs per page).

    Returns:
        A dict with ``success``, the serialized ``jobs`` of the page, the
        user's ``total_count`` of jobs, and the echoed ``page`` / ``limit``.

    Raises:
        HTTPException: 400 when ``page`` or ``limit`` is smaller than 1.
    """
    # Reject values that would otherwise turn into a negative OFFSET or a
    # non-positive LIMIT in the SQL below (a driver error or a
    # backend-dependent result, rather than a clear client error).
    if page < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="page must be greater than or equal to 1"
        )
    if limit < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit must be greater than or equal to 1"
        )

    user = req.state.user
    offset = (page - 1) * limit

    query = select(GeminiJob).where(
        GeminiJob.user_id == user.id
    ).order_by(GeminiJob.created_at.desc()).offset(offset).limit(limit)
    result = await db.execute(query)
    jobs = result.scalars().all()

    # Total across all pages, for client-side pagination.
    count_query = select(func.count()).where(GeminiJob.user_id == user.id)
    count_result = await db.execute(count_query)
    total_count = count_result.scalar()

    job_list = []
    for job in jobs:
        job_item = {
            "job_id": job.job_id,
            "job_type": job.job_type,
            "status": job.status,
            "created_at": job.created_at.isoformat() if job.created_at else None,
            "completed_at": job.completed_at.isoformat() if job.completed_at else None,
        }
        # The prompt is only exposed for video jobs.
        if job.job_type == "video" and job.input_data:
            job_item["prompt"] = job.input_data.get("prompt")
        if job.status == "failed":
            job_item["error"] = job.error_message
        # A download link is offered only for completed video jobs whose
        # output data carries a "filename" entry.
        if job.status == "completed" and job.job_type == "video" and job.output_data and job.output_data.get("filename"):
            job_item["download_url"] = f"/gemini/download/{job.job_id}"
        job_list.append(job_item)

    return {
        "success": True,
        "jobs": job_list,
        "total_count": total_count,
        "page": page,
        "limit": limit
    }
@router.get("/job/{job_id}")
async def get_job_status(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Report the state of one of the current user's jobs.

    A video job that is "processing" and has a ``third_party_id`` is first
    re-checked through ``GeminiJobProcessor.check_status`` and the result is
    committed. The response always carries the common fields; status-specific
    fields (position, started_at, completed_at, error, prompt) are added
    afterwards.

    Raises:
        HTTPException: 404 when no job with this id belongs to the user.
    """
    user = req.state.user

    lookup = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == user.id
    )
    job = (await db.execute(lookup)).scalar_one_or_none()
    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )

    needs_refresh = (
        job.status == "processing"
        and job.job_type == "video"
        and job.third_party_id
    )
    if needs_refresh:
        from services.gemini_service.job_processor import GeminiJobProcessor
        job = await GeminiJobProcessor().check_status(job, db)
        await db.commit()
        await db.refresh(job)

    def as_iso(moment):
        # ISO-8601 text for a timestamp, or None when it is unset.
        return moment.isoformat() if moment else None

    response = {
        "success": True,
        "job_id": job.job_id,
        "job_type": job.job_type,
        "status": job.status,
        "created_at": as_iso(job.created_at),
        "credits_remaining": user.credits
    }

    # Video jobs echo the prompt they were submitted with.
    if job.job_type == "video" and job.input_data:
        response["prompt"] = job.input_data.get("prompt")

    # The four statuses handled below are mutually exclusive.
    current = job.status
    if current == "queued":
        response["position"] = await get_queue_position(db, job.job_id)
    elif current == "processing":
        response["started_at"] = as_iso(job.started_at)
    elif current == "completed":
        response["completed_at"] = as_iso(job.completed_at)
        produced = job.output_data
        # A prompt present in the output replaces the one set above.
        if produced and "prompt" in produced:
            response["prompt"] = produced["prompt"]
    elif current == "failed":
        response["error"] = job.error_message
        response["completed_at"] = as_iso(job.completed_at)

    return response
@router.get("/download/{job_id}")
async def download_video(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Stream a completed video job's file from its source URL to the client.

    The upstream request is opened and its status is checked *before* the
    ``StreamingResponse`` is created. Raising ``HTTPException`` from inside
    the streaming generator cannot work: by the time the generator runs, the
    200 status line and headers have already been sent. Checking up front
    lets upstream failures reach the client as real 410/502/500 responses and
    lets the "expired" status be committed while the request is still being
    handled.

    Raises:
        HTTPException: 404 when no video job with this id belongs to the user;
            400 when the job is not completed, has no output data or has no
            ``video_url``; 410 when the source answers 401/403/404/410 (the
            job is then marked "expired"); 502 for any other upstream error
            status; 500 when the upstream request itself fails.
    """
    user = req.state.user
    from fastapi.responses import StreamingResponse
    import httpx

    query = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == user.id,
        GeminiJob.job_type == "video"
    )
    result = await db.execute(query)
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )
    if job.status != "completed" or not job.output_data:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Video not ready for download"
        )
    video_url = job.output_data.get("video_url")
    if not video_url:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No video URL available"
        )

    # The client must outlive this function (the body is streamed after we
    # return), so it is closed explicitly instead of via "async with".
    client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)
    upstream = None

    async def close_upstream():
        # Release the upstream response (if one was opened) and the client.
        if upstream is not None:
            await upstream.aclose()
        await client.aclose()

    try:
        upstream_request = client.build_request("GET", video_url)
        # stream=True returns once headers are received; the body is read
        # lazily by the generator below.
        upstream = await client.send(upstream_request, stream=True)
        upstream.raise_for_status()
    except httpx.HTTPStatusError as e:
        await close_upstream()
        if e.response.status_code in (401, 403, 404, 410):
            # The stored URL is no longer usable; record that on the job.
            job.status = "expired"
            await db.commit()
            raise HTTPException(
                status_code=status.HTTP_410_GONE,
                detail="Video URL has expired. Please generate a new video."
            ) from e
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch video from source: {e.response.status_code}"
        ) from e
    except Exception as e:
        await close_upstream()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to stream video: {str(e)}"
        ) from e

    async def stream_video():
        # Relay the upstream body in 8 KiB chunks; always release the
        # connection, whether the stream finishes, fails or is abandoned.
        try:
            async for chunk in upstream.aiter_bytes(chunk_size=8192):
                yield chunk
        finally:
            await close_upstream()

    return StreamingResponse(
        stream_video(),
        media_type="video/mp4",
        headers={
            "Content-Disposition": f'attachment; filename="video_{job_id}.mp4"'
        }
    )
@router.post("/job/{job_id}/cancel")
async def cancel_job(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Mark one of the current user's queued jobs as cancelled.

    Sets the job's status to "cancelled", stamps ``completed_at`` with the
    current UTC time and commits.

    Raises:
        HTTPException: 404 when no job with this id belongs to the user;
            400 when the job is in any status other than "queued".
    """
    owner = req.state.user

    lookup = select(GeminiJob).where(
        GeminiJob.job_id == job_id,
        GeminiJob.user_id == owner.id
    )
    job = (await db.execute(lookup)).scalar_one_or_none()

    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )

    still_queued = job.status == "queued"
    if not still_queued:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot cancel job with status '{job.status}'. Only queued jobs can be cancelled."
        )

    job.status = "cancelled"
    job.completed_at = datetime.utcnow()
    await db.commit()

    return dict(
        success=True,
        job_id=job.job_id,
        status="cancelled",
        message="Job cancelled successfully",
    )
@router.delete("/job/{job_id}")
async def delete_job(
    job_id: str,
    req: Request,
    db: AsyncSession = Depends(get_db)
):
    """Soft-delete a job and refund reserved credits where the rules allow.

    Refund rules, applied only when the job reserved credits that have not
    been refunded yet (first match wins):

    * no ``third_party_id``            -> full refund
    * status "queued"                  -> refund minus a 2-credit penalty
    * status "processing"/"completed"  -> no refund
    * status "failed"                  -> full refund
    * any other status                 -> no refund

    A refund failure is logged and reported in the response message; the job
    is still deleted.

    Returns:
        A dict with ``success``, a human-readable ``message``, the
        ``refund_amount`` and the user's ``new_credit_balance``.

    Raises:
        HTTPException: 404 when the job is not found; 500 when the soft
            delete reports failure.
    """
    # The module defines no logger, so one is created here; without it the
    # refund-failure path below raised NameError instead of logging.
    import logging
    logger = logging.getLogger(__name__)

    user = req.state.user
    from services.db_service import QueryService
    # NOTE(review): the lookup filters on job_id only; presumably QueryService
    # scopes queries to `user` - confirm against services.db_service.
    qs = QueryService(user, db)
    job = await qs.select().execute_one(
        select(GeminiJob).where(GeminiJob.job_id == job_id)
    )
    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Job not found"
        )

    refund_amount = 0
    refund_reason = None  # stays None whenever no refund is due
    message = "Job deleted"

    if job.credits_reserved > 0 and not job.credits_refunded:
        from services.credit_service import CreditTransactionManager
        # NOTE(review): this first branch ignores job.status, so a completed
        # job without a third_party_id is also fully refunded - confirm that
        # this is intended.
        if not job.third_party_id:
            refund_amount = job.credits_reserved
            refund_reason = "Job deleted before execution"
        elif job.status == "queued":
            penalty = 2
            refund_amount = max(0, job.credits_reserved - penalty)
            refund_reason = f"Job cancelled while queued (penalty: {penalty} credits)"
        elif job.status in ["processing", "completed"]:
            refund_amount = 0
            refund_reason = None
        elif job.status == "failed":
            refund_amount = job.credits_reserved
            refund_reason = "Job failed - full refund"

        if refund_amount > 0:
            try:
                await CreditTransactionManager.add_credits(
                    session=db,
                    user=user,
                    amount=refund_amount,
                    source="job_deletion",
                    reference_type="job",
                    reference_id=job.job_id,
                    reason=refund_reason,
                    metadata={
                        "job_type": job.job_type,
                        "original_cost": job.credits_reserved,
                        "job_status": job.status
                    }
                )
                # Flag the job so the same credits cannot be refunded twice.
                job.credits_refunded = True
                await db.commit()
                message = f"Job deleted. {refund_amount} credits refunded."
            except Exception as e:
                # Best effort: the deletion proceeds even if the refund fails.
                logger.error(f"Failed to refund credits for job {job_id}: {e}")
                message = "Job deleted (refund failed - contact support)"

    deleted = await qs.delete().soft_delete_one(job)
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete job"
        )

    return {
        "success": True,
        "message": message,
        "refund_amount": refund_amount,
        "new_credit_balance": user.credits
    }
@router.get("/models")
async def get_models():
    """Return the ``MODELS`` value from the Gemini service under the key "models"."""
    available = MODELS
    return dict(models=available)