|
|
"""Ad generation endpoints (single, batch, image, image download proxy, strategies, models)."""
|
|
|
|
|
import os |
|
|
from typing import Literal, Optional |
|
|
from fastapi import APIRouter, HTTPException, Depends |
|
|
from fastapi.responses import FileResponse, Response as FastAPIResponse |
|
|
|
|
|
from api.schemas import ( |
|
|
GenerateRequest, |
|
|
GenerateResponse, |
|
|
GenerateBatchRequest, |
|
|
BatchResponse, |
|
|
) |
|
|
from services.generator import ad_generator |
|
|
from services.auth_dependency import get_current_user |
|
|
from config import settings |
|
|
|
|
|
# Router for the ad-generation endpoints; every route in this module is
# registered on it and grouped under the "generate" OpenAPI tag.
router = APIRouter(tags=["generate"])
|
|
|
|
|
|
|
|
@router.post("/generate", response_model=GenerateResponse)
async def generate(
    request: GenerateRequest,
    username: str = Depends(get_current_user),
):
    """
    Generate a single ad creative.

    Requires authentication. Uses randomization for strategies, hooks, visuals.

    Forwards the request fields (plus the authenticated ``username``) to
    ``ad_generator.generate_ad`` and returns its result as the response body.

    Raises:
        HTTPException: re-raised unchanged if the generator raises one;
            any other exception is reported as a 500 whose detail is ``str(e)``.
    """
    try:
        return await ad_generator.generate_ad(
            niche=request.niche,
            num_images=request.num_images,
            image_model=request.image_model,
            username=username,
            target_audience=request.target_audience,
            offer=request.offer,
            use_trending=request.use_trending,
            trending_context=request.trending_context,
        )
    except HTTPException:
        # HTTPException is itself an Exception; without this clause a deliberate
        # 4xx from the generator would be masked as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
@router.post("/generate/batch", response_model=BatchResponse)
async def generate_batch(
    request: GenerateBatchRequest,
    username: str = Depends(get_current_user),
):
    """
    Generate multiple ad creatives in batch.

    Requires authentication. Each ad is unique due to randomization.

    Returns:
        dict: ``{"count": <number of ads returned>, "ads": <list from
        ad_generator.generate_batch>}``.

    Raises:
        HTTPException: re-raised unchanged if the generator raises one;
            any other exception is reported as a 500 whose detail is ``str(e)``.
    """
    try:
        results = await ad_generator.generate_batch(
            niche=request.niche,
            count=request.count,
            images_per_ad=request.images_per_ad,
            image_model=request.image_model,
            username=username,
            method=request.method,
            target_audience=request.target_audience,
            offer=request.offer,
        )
        # `count` reflects what was actually produced, which may differ from request.count.
        return {"count": len(results), "ads": results}
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 4xx) instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
@router.get("/image/{filename}")
async def get_image(filename: str):
    """
    Get a generated image by filename from ``settings.output_dir``.

    The requested name is resolved to an absolute path and must stay inside
    the output directory, so names such as ``..`` (or ``..\\secret`` on
    Windows) cannot reach files elsewhere on disk.

    Raises:
        HTTPException: 404 if the name does not resolve to a regular file
            inside the output directory.
    """
    base_dir = os.path.abspath(settings.output_dir)
    filepath = os.path.abspath(os.path.join(base_dir, filename))
    # os.path.join(base_dir, "") yields base_dir with a trailing separator, so a
    # sibling directory like "<output_dir>_other" does not pass the prefix test.
    inside_output_dir = filepath.startswith(os.path.join(base_dir, ""))
    # isfile (not exists) so directories are rejected with a clean 404.
    if not inside_output_dir or not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(filepath)
|
|
|
|
|
|
|
|
@router.get("/api/download-image")
async def download_image_proxy(
    image_url: Optional[str] = None,
    image_id: Optional[str] = None,
    username: str = Depends(get_current_user),
):
    """
    Proxy endpoint to download images, avoiding CORS.

    Can fetch from external URLs (R2, Replicate) or local files stored under
    ``settings.output_dir``.

    Args:
        image_url: Absolute ``http(s)`` URL, or a path relative to the output
            directory. May be omitted when ``image_id`` is given.
        image_id: Ad creative id. When given, the ad must belong to
            ``username``; its stored URL and filename are used as defaults.
        username: Authenticated user, injected by ``get_current_user``.

    Raises:
        HTTPException: 404 (ad or local file not found), 403 (ad belongs to
            another user), 400 (no URL available), 502 (remote fetch failed),
            500 (any other error).
    """
    import httpx

    from services.database import db_service

    filename = None
    if image_id:
        ad = await db_service.get_ad_creative(image_id)
        if not ad:
            raise HTTPException(status_code=404, detail="Ad not found")
        if ad.get("username") != username:
            raise HTTPException(status_code=403, detail="Access denied")
        if not image_url:
            image_url = ad.get("r2_url") or ad.get("image_url")
            filename = ad.get("image_filename")
        else:
            # A stored null metadata value must not crash the lookup.
            metadata = ad.get("metadata") or {}
            if image_url in (metadata.get("original_r2_url"), metadata.get("original_image_url")):
                filename = metadata.get("original_image_filename")

    if not image_url:
        raise HTTPException(status_code=400, detail="No image URL provided")

    try:
        if not image_url.startswith(("http://", "https://")):
            base_dir = os.path.abspath(settings.output_dir)
            filepath = os.path.abspath(os.path.join(base_dir, image_url))
            # Refuse absolute paths and ".." segments that escape the output
            # directory; os.path.join(base_dir, "") appends the separator.
            if filepath.startswith(os.path.join(base_dir, "")) and os.path.isfile(filepath):
                return FileResponse(filepath, filename=filename or os.path.basename(filepath))
            raise HTTPException(status_code=404, detail="Image file not found")

        # NOTE(review): any http(s) URL supplied by the caller is fetched
        # server-side (SSRF risk); consider restricting to known storage hosts.
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(image_url)
            response.raise_for_status()

            content_type = response.headers.get("content-type", "image/png")
            if not filename:
                filename = image_url.split("/")[-1].split("?")[0]
            if not filename or "." not in filename:
                filename = "image.png"
            # Keep only printable ASCII without quotes/backslashes so the quoted
            # Content-Disposition value is well-formed and latin-1 encodable.
            safe_filename = "".join(
                ch for ch in filename if 32 <= ord(ch) < 127 and ch not in '"\\'
            ) or "image.png"

            return FastAPIResponse(
                content=response.content,
                media_type=content_type,
                headers={
                    "Content-Disposition": f'attachment; filename="{safe_filename}"',
                    "Cache-Control": "public, max-age=3600",
                },
            )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) through unchanged
        # instead of turning them into 500s below.
        raise
    except httpx.HTTPError as e:
        raise HTTPException(status_code=502, detail=f"Failed to fetch image: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error downloading image: {str(e)}") from e
|
|
|
|
|
|
|
|
@router.get("/api/models")
async def list_image_models():
    """
    List all available image generation models.

    Ordering: keys from ``preferred_order`` that exist in ``MODEL_REGISTRY``
    come first, then the remaining registry keys in registry iteration order,
    then a hard-coded ``gpt-image-1.5`` entry (added only if the registry does
    not already list that key, so it never appears twice).

    Returns:
        dict: ``{"models": [{"key", "id", "uses_dimensions"}, ...],
        "default": "nano-banana"}``.
    """
    from services.image import MODEL_REGISTRY

    preferred_order = ["nano-banana", "nano-banana-pro", "z-image-turbo", "imagen-4-ultra", "recraft-v3", "ideogram-v3", "photon", "seedream-3"]
    preferred = set(preferred_order)

    ordered_keys = [key for key in preferred_order if key in MODEL_REGISTRY]
    ordered_keys += [key for key in MODEL_REGISTRY if key not in preferred]

    models = [
        {
            "key": key,
            "id": MODEL_REGISTRY[key]["id"],
            "uses_dimensions": MODEL_REGISTRY[key].get("uses_dimensions", False),
        }
        for key in ordered_keys
    ]

    # This entry is not taken from MODEL_REGISTRY here; guard against listing it
    # twice should the registry ever include the same key.
    if "gpt-image-1.5" not in MODEL_REGISTRY:
        models.append({
            "key": "gpt-image-1.5",
            "id": "openai/gpt-image-1.5",
            "uses_dimensions": True,
        })

    return {"models": models, "default": "nano-banana"}
|
|
|
|
|
|
|
|
@router.get("/strategies/{niche}")
async def get_strategies(niche: Literal["home_insurance", "glp1", "auto_insurance"]):
    """
    Summarize the psychological strategies defined for ``niche``.

    The response carries the niche, how many strategies and hooks it has
    overall (``len(all_hooks)``), and per-strategy name, description, hook
    count and the first three hooks as samples.
    """
    from data import home_insurance, glp1, auto_insurance

    # Anything other than the two insurance niches is served from the GLP-1 data.
    niche_sources = {
        "home_insurance": home_insurance,
        "auto_insurance": auto_insurance,
    }
    niche_data = niche_sources.get(niche, glp1).get_niche_data()

    summaries = {
        strategy_key: {
            "name": entry["name"],
            "description": entry["description"],
            "hook_count": len(entry["hooks"]),
            "sample_hooks": entry["hooks"][:3],
        }
        for strategy_key, entry in niche_data["strategies"].items()
    }

    response = {"niche": niche}
    response["total_strategies"] = len(summaries)
    response["total_hooks"] = len(niche_data["all_hooks"])
    response["strategies"] = summaries
    return response
|
|
|