"""Ad generation endpoints (single, batch, image, strategies, models)."""

import os
from typing import Literal, Optional
from urllib.parse import quote

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import FileResponse, Response as FastAPIResponse

from api.schemas import (
    GenerateRequest,
    GenerateResponse,
    GenerateBatchRequest,
    BatchResponse,
)
from services.generator import ad_generator
from services.auth_dependency import get_current_user
from config import settings

router = APIRouter(tags=["generate"])


def _resolve_output_path(relative_name: str) -> Optional[str]:
    """Return the absolute path of *relative_name* inside ``settings.output_dir``.

    Returns ``None`` when the name would resolve outside the output directory
    (``..`` segments, absolute paths pointing elsewhere), so request input can
    never be used to serve arbitrary files from the server.
    """
    base = os.path.abspath(settings.output_dir)
    candidate = os.path.abspath(os.path.join(base, relative_name))
    try:
        inside = os.path.commonpath([base, candidate]) == base
    except ValueError:
        # commonpath raises ValueError e.g. for paths on different Windows drives.
        return None
    return candidate if inside else None


def _content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition value for *filename*.

    The quoted ``filename`` parameter keeps only printable ASCII without
    ``"`` or ``\\``: header values must be latin-1 encodable, and a stray
    quote or CR/LF would corrupt the header. When that sanitizing changed the
    name, the original is also sent percent-encoded as ``filename*`` (RFC 6266).
    """
    ascii_name = "".join(
        ch for ch in filename if " " <= ch <= "~" and ch not in '"\\'
    ) or "image.png"
    value = f'attachment; filename="{ascii_name}"'
    if ascii_name != filename:
        value += f"; filename*=UTF-8''{quote(filename, safe='')}"
    return value


@router.post("/generate", response_model=GenerateResponse)
async def generate(
    request: GenerateRequest,
    username: str = Depends(get_current_user),
):
    """
    Generate a single ad creative.
    Requires authentication. Uses randomization for strategies, hooks, visuals.
    """
    try:
        return await ad_generator.generate_ad(
            niche=request.niche,
            num_images=request.num_images,
            image_model=request.image_model,
            username=username,
            target_audience=request.target_audience,
            offer=request.offer,
            use_trending=request.use_trending,
            trending_context=request.trending_context,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code instead of
        # being flattened into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/generate/batch", response_model=BatchResponse)
async def generate_batch(
    request: GenerateBatchRequest,
    username: str = Depends(get_current_user),
):
    """
    Generate multiple ad creatives in batch.
    Requires authentication. Each ad is unique due to randomization.
    """
    try:
        results = await ad_generator.generate_batch(
            niche=request.niche,
            count=request.count,
            images_per_ad=request.images_per_ad,
            image_model=request.image_model,
            username=username,
            method=request.method,
            target_audience=request.target_audience,
            offer=request.offer,
        )
    except HTTPException:
        # Same policy as `generate`: pass deliberate HTTP errors through.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"count": len(results), "ads": results}


@router.get("/image/{filename}")
async def get_image(filename: str):
    """Get a generated image by filename."""
    # Names that escape the output directory are reported exactly like
    # missing files, so the endpoint does not reveal which paths exist.
    filepath = _resolve_output_path(filename)
    if filepath is None or not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(filepath)


@router.get("/api/download-image")
async def download_image_proxy(
    image_url: Optional[str] = None,
    image_id: Optional[str] = None,
    username: str = Depends(get_current_user),
):
    """
    Proxy endpoint to download images, avoiding CORS.
    Can fetch from external URLs (R2, Replicate) or local files.
    """
    import httpx
    from services.database import db_service

    filename = None
    if image_id:
        ad = await db_service.get_ad_creative(image_id)
        if not ad:
            raise HTTPException(status_code=404, detail="Ad not found")
        if ad.get("username") != username:
            raise HTTPException(status_code=403, detail="Access denied")
        if not image_url:
            image_url = ad.get("r2_url") or ad.get("image_url")
            filename = ad.get("image_filename")
        else:
            # A stored null metadata value is treated like an empty dict.
            metadata = ad.get("metadata") or {}
            if (
                metadata.get("original_r2_url") == image_url
                or metadata.get("original_image_url") == image_url
            ):
                filename = metadata.get("original_image_filename")

    if not image_url:
        raise HTTPException(status_code=400, detail="No image URL provided")

    try:
        if not image_url.startswith(("http://", "https://")):
            # Local file: must stay inside the output directory.
            filepath = _resolve_output_path(image_url)
            if filepath is None or not os.path.isfile(filepath):
                raise HTTPException(status_code=404, detail="Image file not found")
            return FileResponse(filepath, filename=filename or os.path.basename(filepath))

        # NOTE(review): without image_id, image_url is fetched verbatim, so any
        # authenticated user can make the server request arbitrary URLs (SSRF).
        # Consider restricting hosts to the known storage/provider domains.
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(image_url)
            response.raise_for_status()

        content_type = response.headers.get("content-type", "image/png")
        if not filename:
            # Fall back to the last URL path segment, without the query string.
            filename = image_url.split("/")[-1].split("?")[0]
        if not filename or "." not in filename:
            filename = "image.png"

        return FastAPIResponse(
            content=response.content,
            media_type=content_type,
            headers={
                "Content-Disposition": _content_disposition(filename),
                "Cache-Control": "public, max-age=3600",
            },
        )
    except HTTPException:
        # Keep intentional statuses (e.g. the 404 above) instead of
        # converting them into the generic 500 below.
        raise
    except httpx.HTTPError as e:
        raise HTTPException(status_code=502, detail=f"Failed to fetch image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error downloading image: {e}") from e


@router.get("/api/models")
async def list_image_models():
    """
    List all available image generation models.
    """
    from services.image import MODEL_REGISTRY

    preferred_order = [
        "nano-banana",
        "nano-banana-pro",
        "z-image-turbo",
        "imagen-4-ultra",
        "recraft-v3",
        "ideogram-v3",
        "photon",
        "seedream-3",
    ]

    # Preferred models first (in the order above), then the rest in
    # registry order.
    ordered_keys = [key for key in preferred_order if key in MODEL_REGISTRY]
    ordered_keys += [key for key in MODEL_REGISTRY if key not in preferred_order]

    models = [
        {
            "key": key,
            "id": MODEL_REGISTRY[key]["id"],
            "uses_dimensions": MODEL_REGISTRY[key].get("uses_dimensions", False),
        }
        for key in ordered_keys
    ]

    # The OpenAI model is advertised in addition to the registry entries;
    # skip it if the registry already lists it, to avoid a duplicate key.
    if "gpt-image-1.5" not in MODEL_REGISTRY:
        models.append({
            "key": "gpt-image-1.5",
            "id": "openai/gpt-image-1.5",
            "uses_dimensions": True,
        })

    return {"models": models, "default": "nano-banana"}


@router.get("/strategies/{niche}")
async def get_strategies(niche: Literal["home_insurance", "glp1", "auto_insurance"]):
    """Get available psychological strategies for a niche."""
    from data import home_insurance, glp1, auto_insurance

    niche_modules = {
        "home_insurance": home_insurance,
        "auto_insurance": auto_insurance,
        "glp1": glp1,
    }
    # glp1 is the fallback, matching the original if/elif/else chain.
    data = niche_modules.get(niche, glp1).get_niche_data()

    strategies = {
        name: {
            "name": strategy["name"],
            "description": strategy["description"],
            "hook_count": len(strategy["hooks"]),
            "sample_hooks": strategy["hooks"][:3],
        }
        for name, strategy in data["strategies"].items()
    }

    return {
        "niche": niche,
        "total_strategies": len(strategies),
        "total_hooks": len(data["all_hooks"]),
        "strategies": strategies,
    }