diff --git "a/main.py" "b/main.py"
--- "a/main.py"
+++ "b/main.py"
@@ -1,96 +1,71 @@
"""
PsyAdGenesis - FastAPI Application
-Design ads that stop the scroll. Generate high-converting ad creatives for Home Insurance and GLP-1 niches
-Saves all ads to Neon PostgreSQL database with image URLs
+Design ads that stop the scroll. Generate high-converting ad creatives for Home Insurance and GLP-1 niches.
+Saves all ads to Neon PostgreSQL database with image URLs.
"""
+import os
from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException, Request, Response, Depends, BackgroundTasks
+import httpx
+from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
-from fastapi.responses import FileResponse, StreamingResponse, Response as FastAPIResponse
-from pydantic import BaseModel, Field
-from typing import Optional, List, Literal, Any, Dict
-from datetime import datetime
-import asyncio
-import os
-import logging
-import time
-import random
-import uuid
+from fastapi.responses import Response as FastAPIResponse
from starlette.middleware.gzip import GZipMiddleware
-import httpx
from starlette.requests import Request as StarletteRequest
-from services.generator import ad_generator
-from services.matrix import matrix_service
from services.database import db_service
-from services.correction import correction_service
-from services.image import image_service
-from services.auth import auth_service
-from services.auth_dependency import get_current_user
-from services.motivator import generate_motivators as motivator_generate
-from services.trend_monitor import trend_monitor
-from services.export_service import export_service
from config import settings
+from api.routers import get_all_routers
+
+
# Configure logging for API
+import logging
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
)
-api_logger = logging.getLogger("api")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Startup and shutdown events."""
- # Startup: Connect to database
print("Starting PsyAdGenesis...")
await db_service.connect()
yield
- # Shutdown: Disconnect from database
print("Shutting down...")
await db_service.disconnect()
-# Create FastAPI app with lifespan for database connection
app = FastAPI(
title="PsyAdGenesis",
- description="Design ads that stop the scroll. Generate high-converting ad creatives for Home Insurance and GLP-1 niches using psychological triggers and AI-powered image generation.",
+ description="Design ads that stop the scroll. Generate high-converting ad creatives using psychological triggers and AI-powered image generation.",
version="2.0.0",
lifespan=lifespan,
)
-# Compression middleware (gzip responses)
+# Middleware
app.add_middleware(GZipMiddleware, minimum_size=1000)
-# CORS middleware
-# Allow localhost for development and Hugging Face Spaces for deployment
cors_origins = [
"http://localhost:3000",
"http://127.0.0.1:3000",
]
-
-# Add custom origins from environment variable
if os.getenv("CORS_ORIGINS"):
- cors_origins.extend([origin.strip() for origin in os.getenv("CORS_ORIGINS").split(",")])
-
-# For Hugging Face Spaces, use regex to match any .hf.space domain
-# Note: If deploying to HF Spaces, add your Space URL to CORS_ORIGINS env var
-# Example: CORS_ORIGINS=https://your-username-psyadgenesis.hf.space
+ cors_origins.extend([o.strip() for o in os.getenv("CORS_ORIGINS").split(",")])
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
- allow_origin_regex=r"https://.*\.hf\.space", # Match any Hugging Face Space
+ allow_origin_regex=r"https://.*\.hf\.space",
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
-# Add cache headers for static files
+
@app.middleware("http")
async def add_cache_headers(request: Request, call_next):
response = await call_next(request)
@@ -98,2625 +73,61 @@ async def add_cache_headers(request: Request, call_next):
response.headers["Cache-Control"] = "public, max-age=31536000, immutable"
return response
-# Serve generated images
+
+# Static files
os.makedirs(settings.output_dir, exist_ok=True)
app.mount("/images", StaticFiles(directory=settings.output_dir), name="images")
-# Serve Next.js static files directly (better performance than proxying)
frontend_static_path = os.path.join(os.path.dirname(__file__), "frontend", ".next", "static")
if os.path.exists(frontend_static_path):
app.mount("/_next/static", StaticFiles(directory=frontend_static_path), name="nextjs_static")
-# Request/Response schemas
-class GenerateRequest(BaseModel):
- """Request schema for ad generation."""
- niche: Literal["home_insurance", "glp1", "auto_insurance"] = Field(
- description="Target niche: home_insurance, glp1, or auto_insurance"
- )
- num_images: int = Field(
- default=1,
- ge=1,
- le=10,
- description="Number of images to generate (1-10)"
- )
- image_model: Optional[str] = Field(
- default=None,
- description="Image generation model to use (e.g., 'z-image-turbo', 'nano-banana', 'nano-banana-pro', 'imagen-4-ultra', 'recraft-v3', 'ideogram-v3', 'photon', 'seedream-3')"
- )
- target_audience: Optional[str] = Field(
- default=None,
- description="Optional target audience description (e.g., 'US people over 50+ age')"
- )
- offer: Optional[str] = Field(
- default=None,
- description="Optional offer to run (e.g., 'Don't overpay your insurance')"
- )
- use_trending: bool = Field(
- default=False,
- description="Whether to incorporate current trending topics from Google News"
- )
- trending_context: Optional[str] = Field(
- default=None,
- description="Specific trending context to use (auto-fetched if not provided when use_trending=True)"
- )
-
-
-class GenerateBatchRequest(BaseModel):
- """Request schema for batch ad generation."""
- niche: Literal["home_insurance", "glp1", "auto_insurance"] = Field(
- description="Target niche: home_insurance, glp1, or auto_insurance"
- )
- count: int = Field(
- default=5,
- ge=1,
- le=100,
- description="Number of ads to generate (1-100)"
- )
- images_per_ad: int = Field(
- default=1,
- ge=1,
- le=3,
- description="Images per ad (1-3)"
- )
- image_model: Optional[str] = Field(
- default=None,
- description="Image generation model to use (e.g., 'z-image-turbo', 'nano-banana', 'nano-banana-pro', 'imagen-4-ultra', 'recraft-v3', 'ideogram-v3', 'photon', 'seedream-3')"
- )
- method: Optional[Literal["standard", "matrix"]] = Field(
- default=None,
- description="Generation method: 'standard' for standard method only, 'matrix' for matrix method only, None for mixed (50/50)"
- )
- target_audience: Optional[str] = Field(
- default=None,
- description="Optional target audience description (e.g., 'US people over 50+ age')"
- )
- offer: Optional[str] = Field(
- default=None,
- description="Optional offer to run (e.g., 'Don't overpay your insurance')"
- )
-
-
-class ImageResult(BaseModel):
- """Image result schema."""
- filename: Optional[str] = None
- filepath: Optional[str] = None
- image_url: Optional[str] = Field(default=None, description="Direct URL to the image (Replicate hosted)")
- model_used: Optional[str] = None
- seed: Optional[int] = None
- error: Optional[str] = None
-
-
-class AdMetadata(BaseModel):
- """Metadata about the generation."""
- strategies_used: List[str]
- creative_direction: str
- visual_mood: str
- framework: Optional[str] = None
- camera_angle: Optional[str] = None
- lighting: Optional[str] = None
- composition: Optional[str] = None
- hooks_inspiration: List[str]
- visual_styles: List[str]
-
-
-class GenerateResponse(BaseModel):
- """Response schema for ad generation."""
- id: str
- niche: str
- created_at: str
- title: Optional[str] = Field(default=None, description="Short punchy ad title (3-5 words) - only for extensive flow")
- headline: str
- primary_text: str
- description: str
- body_story: str = Field(description="Compelling 8-12 sentence story that hooks emotionally")
- cta: str
- psychological_angle: str
- why_it_works: Optional[str] = None
- images: List[ImageResult]
- metadata: AdMetadata
-
-
-class BatchResponse(BaseModel):
- """Response schema for batch generation."""
- count: int
- ads: List[GenerateResponse]
-
-
-# Matrix-based schemas
-class MatrixGenerateRequest(BaseModel):
- """Request for angle × concept matrix generation."""
- niche: Literal["home_insurance", "glp1", "auto_insurance"] = Field(
- description="Target niche"
- )
- angle_key: Optional[str] = Field(
- default=None,
- description="Specific angle key (random if not provided)"
- )
- concept_key: Optional[str] = Field(
- default=None,
- description="Specific concept key (random if not provided)"
- )
- custom_angle: Optional[str] = Field(
- default=None,
- description="Custom angle text (AI will structure it properly). Used when angle_key is 'custom'"
- )
- custom_concept: Optional[str] = Field(
- default=None,
- description="Custom concept text (AI will structure it properly). Used when concept_key is 'custom'"
- )
- num_images: int = Field(
- default=1,
- ge=1,
- le=5,
- description="Number of images to generate"
- )
- image_model: Optional[str] = Field(
- default=None,
- description="Image generation model to use (e.g., 'z-image-turbo', 'nano-banana', 'nano-banana-pro', 'imagen-4-ultra', 'recraft-v3', 'ideogram-v3', 'photon', 'seedream-3')"
- )
- target_audience: Optional[str] = Field(
- default=None,
- description="Optional target audience description (e.g., 'US people over 50+ age')"
- )
- offer: Optional[str] = Field(
- default=None,
- description="Optional offer to run (e.g., 'Don't overpay your insurance')"
- )
- core_motivator: Optional[str] = Field(
- default=None,
- description="Optional motivator selected by user to guide ad generation"
- )
-
-
-class RefineCustomRequest(BaseModel):
- """Request to refine custom angle or concept text using AI."""
- text: str = Field(
- description="The raw custom text from user"
- )
- type: Literal["angle", "concept"] = Field(
- description="Whether this is an angle or concept"
- )
- niche: Literal["home_insurance", "glp1", "auto_insurance"] = Field(
- description="Target niche for context"
- )
- goal: Optional[str] = Field(
- default=None,
- description="Optional user goal or context"
- )
-
-
-class RefinedAngleResponse(BaseModel):
- """Response for refined angle."""
- key: str = Field(default="custom")
- name: str
- trigger: str
- example: str
- category: str = Field(default="Custom")
- original_text: str
-
-
-class RefinedConceptResponse(BaseModel):
- """Response for refined concept."""
- key: str = Field(default="custom")
- name: str
- structure: str
- visual: str
- category: str = Field(default="Custom")
- original_text: str
-
-
-class RefineCustomResponse(BaseModel):
- """Response for refined custom angle or concept."""
- status: str
- type: Literal["angle", "concept"]
- refined: Optional[dict] = None
- error: Optional[str] = None
-
-
-# Motivator generation (angle + concept context)
-class MotivatorGenerateRequest(BaseModel):
- """Request to generate motivators from niche + angle + concept."""
- niche: Literal["home_insurance", "glp1", "auto_insurance"] = Field(description="Target niche")
- angle: Dict[str, Any] = Field(
- description="Angle context: name, trigger, example (and optional key, category)"
- )
- concept: Dict[str, Any] = Field(
- description="Concept context: name, structure, visual (and optional key, category)"
- )
- target_audience: Optional[str] = Field(default=None, description="Optional target audience")
- offer: Optional[str] = Field(default=None, description="Optional offer")
- count: int = Field(default=6, ge=3, le=10, description="Number of motivators to generate")
-
-
-class MotivatorGenerateResponse(BaseModel):
- """Response with generated motivators."""
- motivators: List[str]
-
-
-class MatrixBatchRequest(BaseModel):
- """Request for batch matrix generation."""
- niche: Literal["home_insurance", "glp1"] = Field(
- description="Target niche"
- )
- angle_count: int = Field(
- default=6,
- ge=1,
- le=10,
- description="Number of angles to test"
- )
- concept_count: int = Field(
- default=5,
- ge=1,
- le=10,
- description="Number of concepts per angle"
- )
- strategy: Literal["balanced", "top_performers", "diverse"] = Field(
- default="balanced",
- description="Selection strategy"
- )
-
-
-class AngleInfo(BaseModel):
- """Angle information."""
- key: str
- name: str
- trigger: str
- category: str
-
-
-class ConceptInfo(BaseModel):
- """Concept information."""
- key: str
- name: str
- structure: str
- visual: str
- category: str
-
-
-class MatrixMetadata(BaseModel):
- """Matrix generation metadata."""
- generation_method: str = "angle_concept_matrix"
-
-
-class MatrixResult(BaseModel):
- """Result from matrix-based generation."""
- angle: AngleInfo
- concept: ConceptInfo
-
-
-class MatrixGenerateResponse(BaseModel):
- """Response for matrix-based ad generation."""
- id: str
- niche: str
- created_at: str
- title: Optional[str] = Field(default=None, description="Short punchy ad title (3-5 words) - not used in matrix flow")
- headline: str
- primary_text: str
- description: str
- body_story: str = Field(description="Compelling 8-12 sentence story that hooks emotionally")
- cta: str
- psychological_angle: str
- why_it_works: Optional[str] = None
- images: List[ImageResult]
- matrix: MatrixResult
- metadata: MatrixMetadata
-
-
-class CombinationInfo(BaseModel):
- """Info about a single angle × concept combination."""
- combination_id: str
- angle: AngleInfo
- concept: ConceptInfo
- compatibility_score: float
- prompt_guidance: str
-
-
-class MatrixSummary(BaseModel):
- """Summary of a testing matrix."""
- total_combinations: int
- unique_angles: int
- unique_concepts: int
- average_compatibility: float
- angles_used: List[str]
- concepts_used: List[str]
-
-
-class TestingMatrixResponse(BaseModel):
- """Response for testing matrix generation."""
- niche: str
- strategy: str
- summary: MatrixSummary
- combinations: List[CombinationInfo]
-
-
-# Endpoints
-@app.get("/api/info")
-async def api_info():
- """API info endpoint."""
- return {
- "name": "PsyAdGenesis",
- "version": "2.0.0",
- "description": "Design ads that stop the scroll. Generate high-converting ads using Angle × Concept matrix system",
- "endpoints": {
- "POST /generate": "Generate single ad (original mode)",
- "POST /generate/batch": "Generate multiple ads (original mode)",
- "POST /matrix/generate": "Generate ad using Angle × Concept matrix",
- "POST /matrix/testing": "Generate testing matrix (30 combinations)",
- "GET /matrix/angles": "List all 100 angles",
- "GET /matrix/concepts": "List all 100 concepts",
- "GET /matrix/angle/{key}": "Get specific angle details",
- "GET /matrix/concept/{key}": "Get specific concept details",
- "GET /matrix/compatible/{angle_key}": "Get compatible concepts for angle",
- "POST /extensive/generate": "Generate ad using extensive (researcher → creative director → designer → copywriter)",
- "POST /api/motivator/generate": "Generate motivators from niche + angle + concept (Matrix mode)",
- "POST /api/correct": "Correct image for spelling mistakes and visual issues (requires image_id)",
- "POST /api/regenerate": "Regenerate image with optional model selection (requires image_id)",
- "GET /api/models": "List all available image generation models",
- "POST /api/creative/upload": "Upload a creative image for analysis",
- "POST /api/creative/analyze": "Analyze a creative image with AI vision (via URL)",
- "POST /api/creative/analyze/upload": "Analyze a creative image with AI vision (via file upload)",
- "POST /api/creative/modify": "Modify a creative with new angle/concept",
- "GET /api/trends/{niche}": "Get current trending topics from Google News",
- "GET /api/trends/angles/{niche}": "Get auto-generated angles from trending topics",
- "GET /health": "Health check",
- },
- "supported_niches": ["home_insurance", "glp1"],
- "matrix_system": {
- "total_angles": 100,
- "total_concepts": 100,
- "possible_combinations": 10000,
- "formula": "1 Offer → 5-8 Angles → 3-5 Concepts per angle",
- },
- }
-
-
-# Root route - proxy to Next.js frontend
-@app.get("/")
-async def root():
- """Proxy root to Next.js frontend."""
- try:
- async with httpx.AsyncClient(timeout=30.0) as client:
- response = await client.get("http://localhost:3000/")
- # Return full response (not streaming) for better HF compatibility
- return FastAPIResponse(
- content=response.content,
- status_code=response.status_code,
- headers={k: v for k, v in response.headers.items() if k.lower() not in ['content-encoding', 'transfer-encoding', 'content-length']},
- media_type=response.headers.get("content-type"),
- )
- except httpx.RequestError:
- # Return a simple HTML page if frontend is not ready yet
- return FastAPIResponse(
- content="
Loading...
",
- status_code=200,
- media_type="text/html"
- )
-
-
-@app.get("/health")
-async def health():
- """Health check endpoint for Hugging Face Spaces."""
- return {"status": "ok"}
-
-
-# =============================================================================
-# TRENDING TOPICS ENDPOINTS
-# =============================================================================
-
-@app.get("/api/trends/{niche}")
-async def get_trends(
- niche: Literal["home_insurance", "glp1", "auto_insurance"],
- username: str = Depends(get_current_user)
-):
- """
- Get current trending topics for a niche from Google News.
-
- Requires authentication.
-
- 🚧 COMING SOON - This feature is currently under development.
-
- Returns top 5 most relevant news articles with context for ad generation.
- Articles are scored by relevance, recency, and emotional triggers.
-
- Results are cached for 1 hour to avoid rate limits.
- """
- # Feature temporarily disabled - coming soon
- return {
- "status": "coming_soon",
- "message": "🔥 Trending Topics feature is coming soon! Stay tuned.",
- "niche": niche,
- "trends": [],
- "count": 0,
- "available_soon": True
- }
-
- # Original implementation (commented out for later)
- # try:
- # trends = await trend_monitor.fetch_trends(niche)
- # return {
- # "niche": niche,
- # "trends": trends,
- # "count": len(trends),
- # "fetched_at": datetime.now().isoformat()
- # }
- # except Exception as e:
- # raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.get("/api/trends/angles/{niche}")
-async def get_trending_angles(
- niche: Literal["home_insurance", "glp1", "auto_insurance"],
- username: str = Depends(get_current_user)
-):
- """
- Get auto-generated angle suggestions based on current trends.
-
- Requires authentication.
-
- 🚧 COMING SOON - This feature is currently under development.
-
- These trending angles can be used in matrix generation like regular angles.
- Each angle is generated from a real news article with:
- - Detected psychological trigger
- - Relevance score
- - Expiry date (7 days)
- """
- # Feature temporarily disabled - coming soon
- return {
- "status": "coming_soon",
- "message": "🔥 Trending Topics feature is coming soon! Stay tuned.",
- "niche": niche,
- "trending_angles": [],
- "count": 0,
- "available_soon": True
- }
-
- # Original implementation (commented out for later)
- # try:
- # angles = await trend_monitor.get_trending_angles(niche)
- # return {
- # "niche": niche,
- # "trending_angles": angles,
- # "count": len(angles),
- # "fetched_at": datetime.now().isoformat()
- # }
- # except Exception as e:
- # raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# AUTHENTICATION ENDPOINTS
-# =============================================================================
-
-class LoginRequest(BaseModel):
- """Login request schema."""
- username: str = Field(description="Username")
- password: str = Field(description="Password")
-
-
-class LoginResponse(BaseModel):
- """Login response schema."""
- token: str
- username: str
- message: str = "Login successful"
-
-
-@app.post("/auth/login", response_model=LoginResponse)
-async def login(request: LoginRequest):
- """
- Authenticate a user and return a JWT token.
-
- Credentials must be created manually using the create_user.py script.
- """
- # Get user from database
- user = await db_service.get_user(request.username)
- if not user:
- raise HTTPException(
- status_code=401,
- detail="Invalid username or password"
- )
-
- # Verify password
- hashed_password = user.get("hashed_password")
- if not hashed_password:
- raise HTTPException(
- status_code=500,
- detail="User data corrupted"
- )
-
- if not auth_service.verify_password(request.password, hashed_password):
- raise HTTPException(
- status_code=401,
- detail="Invalid username or password"
- )
-
- # Create access token
- token = auth_service.create_access_token(request.username)
-
- return {
- "token": token,
- "username": request.username,
- "message": "Login successful"
- }
-
-
-@app.post("/generate", response_model=GenerateResponse)
-async def generate(
- request: GenerateRequest,
- username: str = Depends(get_current_user)
-):
- """
- Generate a single ad creative.
-
- Requires authentication. Users can only see their own generated ads.
-
- Uses maximum randomization to ensure different results every time:
- - Random psychological strategies (2-3 combined)
- - Random hooks and angles
- - Random visual styles and moods
- - Random seed for image generation
-
- Supports niches:
- - home_insurance: Fear, urgency, savings, authority, guilt strategies
- - glp1: Shame, transformation, FOMO, authority, simplicity strategies
-
- Trending Topics Integration:
- - Set use_trending=True to incorporate current Google News trends
- - Optionally provide trending_context, or it will be auto-fetched
- """
- try:
- result = await ad_generator.generate_ad(
- niche=request.niche,
- num_images=request.num_images,
- image_model=request.image_model,
- username=username, # Pass current user
- target_audience=request.target_audience,
- offer=request.offer,
- use_trending=request.use_trending,
- trending_context=request.trending_context,
- )
- return result
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.post("/generate/batch", response_model=BatchResponse)
-async def generate_batch(
- request: GenerateBatchRequest,
- username: str = Depends(get_current_user)
-):
- """
- Generate multiple ad creatives in batch.
-
- Requires authentication. Users can only see their own generated ads.
-
- Each ad will be unique due to randomization:
- - Different strategy combinations
- - Different hooks and angles
- - Different visual styles
- - Different random seeds
- """
- try:
- results = await ad_generator.generate_batch(
- niche=request.niche,
- count=request.count,
- images_per_ad=request.images_per_ad,
- image_model=request.image_model,
- username=username, # Pass current user
- method=request.method, # Pass method parameter
- target_audience=request.target_audience,
- offer=request.offer,
- )
- return {
- "count": len(results),
- "ads": results,
- }
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.get("/image/{filename}")
-async def get_image(filename: str):
- """Get a generated image by filename."""
- filepath = os.path.join(settings.output_dir, filename)
- if not os.path.exists(filepath):
- raise HTTPException(status_code=404, detail="Image not found")
- return FileResponse(filepath)
-
-
-@app.get("/api/download-image")
-async def download_image_proxy(
- image_url: Optional[str] = None,
- image_id: Optional[str] = None,
- username: str = Depends(get_current_user)
-):
- """
- Proxy endpoint to download images, avoiding CORS issues.
- Can fetch from external URLs (R2, Replicate) or local files.
- """
- import httpx
-
- filename = None
-
- # If image_id is provided, verify ownership
- if image_id:
- ad = await db_service.get_ad_creative(image_id)
- if not ad:
- raise HTTPException(status_code=404, detail="Ad not found")
-
- # Verify ownership
- if ad.get("username") != username:
- raise HTTPException(status_code=403, detail="Access denied")
-
- # Only use ad's image URL if image_url was not explicitly provided
- # This allows downloading original images from metadata
- if not image_url:
- image_url = ad.get("r2_url") or ad.get("image_url")
- filename = ad.get("image_filename")
- else:
- # If image_url is provided, try to get filename from metadata if it's an original image
- metadata = ad.get("metadata", {})
- if metadata.get("original_r2_url") == image_url or metadata.get("original_image_url") == image_url:
- filename = metadata.get("original_image_filename")
-
- if not image_url:
- raise HTTPException(status_code=400, detail="No image URL provided")
-
- try:
- # Check if it's a local file
- if not image_url.startswith(("http://", "https://")):
- # Local file
- filepath = os.path.join(settings.output_dir, image_url)
- if os.path.exists(filepath):
- return FileResponse(filepath, filename=filename or os.path.basename(filepath))
- raise HTTPException(status_code=404, detail="Image file not found")
-
- # External URL - fetch and proxy
- async with httpx.AsyncClient(timeout=30.0) as client:
- response = await client.get(image_url)
- response.raise_for_status()
-
- # Determine content type
- content_type = response.headers.get("content-type", "image/png")
-
- # Get filename from URL or use provided filename
- if not filename:
- # Try to extract from URL
- filename = image_url.split("/")[-1].split("?")[0]
- if not filename or "." not in filename:
- filename = "image.png"
-
- return FastAPIResponse(
- content=response.content,
- media_type=content_type,
- headers={
- "Content-Disposition": f'attachment; filename="{filename}"',
- "Cache-Control": "public, max-age=3600",
- }
- )
- except httpx.HTTPError as e:
- raise HTTPException(status_code=502, detail=f"Failed to fetch image: {str(e)}")
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"Error downloading image: {str(e)}")
-
-
-# =============================================================================
-# IMAGE CORRECTION ENDPOINTS
-# =============================================================================
-
-class ImageCorrectRequest(BaseModel):
- """Request schema for image correction."""
- image_id: str = Field(
- description="ID of existing ad creative in database, or 'temp-id' for images not in DB"
- )
- image_url: Optional[str] = Field(
- default=None,
- description="Optional image URL for images not in DB (required if image_id='temp-id')"
- )
- user_instructions: Optional[str] = Field(
- default=None,
- description="User-specified instructions for what to correct (e.g., 'Fix spelling in the headline', 'Adjust colors', 'Change text to X')"
- )
- auto_analyze: bool = Field(
- default=False,
- description="Whether to automatically analyze the image for issues (if user_instructions not provided)"
- )
-
-
-class SpellingCorrection(BaseModel):
- """Spelling correction entry."""
- detected: str
- corrected: str
- context: Optional[str] = None
-
-
-class VisualCorrection(BaseModel):
- """Visual correction entry."""
- issue: str
- suggestion: str
- priority: Optional[str] = None
-
-
-class CorrectionData(BaseModel):
- """Correction data structure."""
- spelling_corrections: List[SpellingCorrection]
- visual_corrections: List[VisualCorrection]
- corrected_prompt: str
-
-
-class CorrectedImageResult(BaseModel):
- """Corrected image result."""
- filename: Optional[str] = None
- filepath: Optional[str] = None
- image_url: Optional[str] = None
- r2_url: Optional[str] = None
- model_used: Optional[str] = None
- corrected_prompt: Optional[str] = None
-
-
-class ImageCorrectResponse(BaseModel):
- """Response schema for image correction."""
- status: str
- analysis: Optional[str] = None
- corrections: Optional[CorrectionData] = None
- corrected_image: Optional[CorrectedImageResult] = None
- error: Optional[str] = None
-
-
-@app.post("/api/correct", response_model=ImageCorrectResponse)
-async def correct_image(
- request: ImageCorrectRequest,
- username: str = Depends(get_current_user)
-):
- """
- Correct an image by analyzing it for spelling mistakes and visual issues,
- then regenerating a corrected version using nano-banana-pro model.
-
- Requires authentication. Users can only correct their own ads.
-
- The service will automatically fetch the image and metadata from the database
- using the provided image_id, then:
- 1. Analyze the image using GPT-4 Vision for text and visual issues
- 2. Generate a structured correction JSON with spelling and visual fixes
- 3. Regenerate the image using nano-banana-pro model with corrected prompt and original image
- 4. Return the corrected image along with analysis and corrections
- """
- api_start_time = time.time()
- api_logger.info("=" * 80)
- api_logger.info(f"API: Correction request received")
- api_logger.info(f"User: {username}")
- api_logger.info(f"Image ID: {request.image_id}")
- api_logger.info(f"Auto-analyze: {request.auto_analyze}")
- api_logger.info(f"User instructions: {request.user_instructions or 'None'}")
-
- try:
- # Fetch ad from database to get image and metadata (only if it belongs to current user)
- image_url = request.image_url
- ad = None
-
- if request.image_id != "temp-id":
- api_logger.info(f"Fetching ad creative from database...")
- ad = await db_service.get_ad_creative(request.image_id, username=username)
- if not ad:
- api_logger.error(f"Ad creative {request.image_id} not found or access denied for user {username}")
- raise HTTPException(status_code=404, detail=f"Ad creative with ID {request.image_id} not found or access denied")
-
- api_logger.info(f"Ad creative found: {ad.get('title', 'N/A')} (niche: {ad.get('niche', 'N/A')})")
-
- # Get image URL from ad if not provided in request
- if not image_url:
- image_url = ad.get("r2_url") or ad.get("image_url")
-
- if not image_url:
- api_logger.error(f"Image URL not found for request")
- raise HTTPException(
- status_code=400,
- detail="Image URL must be provided for images not in database, or found in database for provided ID"
- )
-
- api_logger.info(f"Image URL: {image_url}")
-
- # Load image bytes for analysis (needed for vision API)
- api_logger.info("Loading image bytes for analysis...")
- image_bytes = await image_service.load_image(
- image_id=request.image_id if request.image_id != "temp-id" else None,
- image_url=image_url,
- image_bytes=None,
- filepath=None,
- )
-
- if not image_bytes:
- api_logger.error(f"Failed to load image bytes for request")
- raise HTTPException(
- status_code=404,
- detail="Image not found for analysis. Please ensure the URL is accessible."
- )
-
- api_logger.info(f"Image bytes loaded: {len(image_bytes)} bytes")
-
- # Get original prompt from ad metadata if available
- original_prompt = ad.get("image_prompt") or None if ad else None
- if original_prompt:
- api_logger.info(f"Original prompt available: {len(original_prompt)} characters")
-
- # Perform correction
- api_logger.info("Starting correction workflow...")
- result = await correction_service.correct_image(
- image_bytes=image_bytes,
- image_url=image_url, # Pass URL for image-to-image generation
- original_prompt=original_prompt,
- width=1024,
- height=1024,
- niche=ad.get("niche") if ad else "others",
- user_instructions=request.user_instructions,
- auto_analyze=request.auto_analyze,
- )
-
- # Format response
- api_logger.info(f"Correction workflow status: {result.get('status')}")
- response_data = {
- "status": result["status"],
- "analysis": result.get("analysis"),
- "corrections": None,
- "corrected_image": None,
- "error": result.get("error"),
- }
-
- # Format corrections if available
- if result.get("corrections"):
- corrections = result["corrections"]
- api_logger.info(f"Corrections generated: {len(corrections.get('spelling_corrections', []))} spelling, {len(corrections.get('visual_corrections', []))} visual")
- response_data["corrections"] = {
- "spelling_corrections": corrections.get("spelling_corrections", []),
- "visual_corrections": corrections.get("visual_corrections", []),
- "corrected_prompt": corrections.get("corrected_prompt", ""),
- }
-
- # Format corrected image if available
- if result.get("corrected_image"):
- corrected_img = result["corrected_image"]
- api_logger.info(f"Corrected image generated: {corrected_img.get('filename')}")
- api_logger.info(f"Corrected image URL: {corrected_img.get('image_url')}")
- response_data["corrected_image"] = {
- "filename": corrected_img.get("filename"),
- "filepath": corrected_img.get("filepath"),
- "image_url": corrected_img.get("image_url"),
- "r2_url": corrected_img.get("r2_url"),
- "model_used": corrected_img.get("model_used"),
- "corrected_prompt": corrected_img.get("corrected_prompt"),
- }
-
- # Update original ad with corrected image (instead of creating new one)
- if result.get("status") == "success" and result.get("_db_metadata"):
- db_metadata = result["_db_metadata"]
- try:
- api_logger.info("Updating original ad with corrected image...")
-
- # Store old image data in metadata before updating
- old_image_url = ad.get("r2_url") or ad.get("image_url")
- old_r2_url = ad.get("r2_url")
- old_image_filename = ad.get("image_filename")
- old_image_model = ad.get("image_model")
- old_image_prompt = ad.get("image_prompt")
-
- # Prepare metadata with old image info and corrections
- # Only include fields that have values
- correction_metadata = {
- "is_corrected": True,
- "correction_date": datetime.utcnow().isoformat() + "Z",
- }
-
- # Add original image data only if it exists
- if old_image_url:
- correction_metadata["original_image_url"] = old_image_url
- if old_r2_url:
- correction_metadata["original_r2_url"] = old_r2_url
- if old_image_filename:
- correction_metadata["original_image_filename"] = old_image_filename
- if old_image_model:
- correction_metadata["original_image_model"] = old_image_model
- if old_image_prompt:
- correction_metadata["original_image_prompt"] = old_image_prompt
- if result.get("corrections"):
- correction_metadata["corrections"] = result.get("corrections")
-
- api_logger.info(f"Prepared correction metadata: {correction_metadata}")
- api_logger.info(f"Old image URL: {old_image_url}")
- api_logger.info(f"Old R2 URL: {old_r2_url}")
-
- # Update the original ad with corrected image
- # Also update r2_url if available
- update_kwargs = {
- "image_url": db_metadata.get("image_url"),
- "image_filename": db_metadata.get("filename"),
- "image_model": db_metadata.get("model_used"),
- "image_prompt": db_metadata.get("corrected_prompt"),
- }
- if db_metadata.get("r2_url"):
- update_kwargs["r2_url"] = db_metadata.get("r2_url")
-
- api_logger.info(f"Updating ad {request.image_id} with metadata: {correction_metadata}")
- update_success = await db_service.update_ad_creative(
- ad_id=request.image_id,
- username=username,
- metadata=correction_metadata,
- **update_kwargs
- )
- api_logger.info(f"Update success: {update_success}")
-
- if update_success:
- api_logger.info(f"✓ Original ad updated with corrected image (ID: {request.image_id})")
- # Add the updated ad ID to the response; ensure corrected_image is a dict
- if not response_data.get("corrected_image"):
- response_data["corrected_image"] = {}
- response_data["corrected_image"]["ad_id"] = request.image_id
- else:
- api_logger.warning("Failed to update ad with corrected image (update returned False)")
- except Exception as e:
- api_logger.error(f"Failed to update ad with corrected image: {e}")
- api_logger.exception("Database update error details:")
- # Don't fail the request if database update fails
-
- total_api_time = time.time() - api_start_time
- api_logger.info("=" * 80)
- if result.get("status") == "success":
- api_logger.info(f"✓ API: Correction request completed successfully in {total_api_time:.2f}s")
- else:
- api_logger.error(f"✗ API: Correction request failed after {total_api_time:.2f}s")
- api_logger.error(f"Error: {result.get('error', 'Unknown error')}")
- api_logger.info("=" * 80)
-
- return response_data
-
- except HTTPException:
- total_api_time = time.time() - api_start_time
- api_logger.error(f"✗ API: Correction request failed with HTTPException after {total_api_time:.2f}s")
- raise
- except Exception as e:
- total_api_time = time.time() - api_start_time
- api_logger.error(f"✗ API: Correction request failed with exception after {total_api_time:.2f}s: {str(e)}")
- api_logger.exception("Full exception traceback:")
- raise HTTPException(status_code=500, detail=str(e))
-
-
-
-
-# =============================================================================
-# IMAGE REGENERATION ENDPOINTS
-# =============================================================================
-
-class ImageRegenerateRequest(BaseModel):
- """Request schema for image regeneration."""
- image_id: str = Field(
- description="ID of existing ad creative in database"
- )
- image_model: Optional[str] = Field(
- default=None,
- description="Image generation model to use (e.g., 'z-image-turbo', 'nano-banana', 'nano-banana-pro', 'imagen-4-ultra', 'recraft-v3', 'ideogram-v3', 'photon', 'seedream-3'). If not provided, uses the original model."
- )
- preview_only: bool = Field(
- default=True,
- description="If True, generates preview without updating database. User can then confirm selection."
- )
-
-
-class RegeneratedImageResult(BaseModel):
- """Regenerated image result."""
- filename: Optional[str] = None
- filepath: Optional[str] = None
- image_url: Optional[str] = None
- r2_url: Optional[str] = None
- model_used: Optional[str] = None
- prompt_used: Optional[str] = None
- seed_used: Optional[int] = None
-
-
-class ImageRegenerateResponse(BaseModel):
- """Response schema for image regeneration."""
- status: str
- regenerated_image: Optional[RegeneratedImageResult] = None
- original_image_url: Optional[str] = None
- original_preserved: bool = Field(default=True, description="Whether original image info was preserved in metadata")
- is_preview: bool = Field(default=False, description="Whether this is a preview (not yet saved)")
- error: Optional[str] = None
-
-
-class ImageSelectionRequest(BaseModel):
- """Request schema for confirming image selection after regeneration."""
- image_id: str = Field(description="ID of existing ad creative in database")
- selection: str = Field(description="Which image to keep: 'new' or 'original'")
- new_image_url: Optional[str] = Field(default=None, description="URL of the new image (required if selection='new')")
- new_r2_url: Optional[str] = Field(default=None, description="R2 URL of the new image")
- new_filename: Optional[str] = Field(default=None, description="Filename of the new image")
- new_model: Optional[str] = Field(default=None, description="Model used for the new image")
- new_seed: Optional[int] = Field(default=None, description="Seed used for the new image")
-
-
-@app.post("/api/regenerate", response_model=ImageRegenerateResponse)
-async def regenerate_image(
- request: ImageRegenerateRequest,
- username: str = Depends(get_current_user)
-):
- """
- Regenerate an image for an existing ad creative with an optional new model.
-
- Requires authentication. Users can only regenerate their own ads.
-
- If preview_only=True (default):
- - Generates a new image and uploads to storage
- - Returns both old and new image URLs for comparison
- - Does NOT update the database yet
-
- If preview_only=False:
- - Generates and immediately saves to database (legacy behavior)
- """
- api_start_time = time.time()
- api_logger.info("=" * 80)
- api_logger.info(f"API: Regeneration request received")
- api_logger.info(f"User: {username}")
- api_logger.info(f"Image ID: {request.image_id}")
- api_logger.info(f"Requested model: {request.image_model or 'Use original'}")
- api_logger.info(f"Preview only: {request.preview_only}")
-
- try:
- # Fetch ad from database (only if it belongs to current user)
- api_logger.info(f"Fetching ad creative from database...")
- ad = await db_service.get_ad_creative(request.image_id, username=username)
- if not ad:
- api_logger.error(f"Ad creative {request.image_id} not found or access denied for user {username}")
- raise HTTPException(status_code=404, detail=f"Ad creative with ID {request.image_id} not found or access denied")
-
- api_logger.info(f"Ad creative found: {ad.get('title', 'N/A')} (niche: {ad.get('niche', 'N/A')})")
-
- # Get the image prompt
- image_prompt = ad.get("image_prompt")
- if not image_prompt:
- api_logger.error(f"No image prompt found for ad {request.image_id}")
- raise HTTPException(
- status_code=400,
- detail="No image prompt found for this ad creative. Cannot regenerate without a prompt."
- )
-
- api_logger.info(f"Image prompt found: {len(image_prompt)} characters")
-
- # Determine which model to use
- model_to_use = request.image_model or ad.get("image_model") or settings.image_model
- api_logger.info(f"Using model: {model_to_use}")
-
- # Generate a new random seed for variety
- seed = random.randint(1, 2147483647)
- api_logger.info(f"Using seed: {seed}")
-
- # Generate the new image
- api_logger.info("Generating new image...")
- try:
- image_bytes, model_used, generated_url = await image_service.generate(
- prompt=image_prompt,
- width=1024,
- height=1024,
- seed=seed,
- model_key=model_to_use,
- )
- except Exception as e:
- api_logger.error(f"Image generation failed: {e}")
- raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
-
- api_logger.info(f"Image generated successfully with model: {model_used}")
-
- # Generate filename for the new image
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- unique_id = uuid.uuid4().hex[:8]
- niche = ad.get("niche", "unknown").replace(" ", "_")
- filename = f"regen_{niche}_{timestamp}_{unique_id}.png"
-
- # Try to upload to R2
- r2_url = None
- try:
- from services.r2_storage import get_r2_storage
- r2_storage = get_r2_storage()
- if r2_storage and image_bytes:
- r2_url = r2_storage.upload_image(
- image_bytes=image_bytes,
- filename=filename,
- niche=niche,
- )
- api_logger.info(f"Uploaded to R2: {r2_url}")
- except Exception as e:
- api_logger.warning(f"R2 upload failed: {e}")
-
- # Save locally as fallback
- local_path = None
- if not r2_url and image_bytes:
- local_path = os.path.join(settings.output_dir, filename)
- os.makedirs(os.path.dirname(local_path), exist_ok=True)
- with open(local_path, "wb") as f:
- f.write(image_bytes)
- api_logger.info(f"Saved locally: {local_path}")
-
- # Get original image URL for comparison
- original_image_url = ad.get("r2_url") or ad.get("image_url")
-
- # Determine the new image URL
- new_image_url = r2_url or generated_url or f"/images/{filename}"
-
- # If preview_only, return without updating database
- if request.preview_only:
- total_api_time = time.time() - api_start_time
- api_logger.info("=" * 80)
- api_logger.info(f"✓ API: Regeneration preview completed in {total_api_time:.2f}s (not saved to DB)")
- api_logger.info("=" * 80)
-
- return {
- "status": "success",
- "regenerated_image": {
- "filename": filename,
- "filepath": local_path,
- "image_url": new_image_url,
- "r2_url": r2_url,
- "model_used": model_used,
- "prompt_used": image_prompt,
- "seed_used": seed,
- },
- "original_image_url": original_image_url,
- "original_preserved": True,
- "is_preview": True,
- }
-
- # If not preview_only, update the database immediately (legacy behavior)
- old_r2_url = ad.get("r2_url")
- old_image_filename = ad.get("image_filename")
- old_image_model = ad.get("image_model")
- old_seed = ad.get("image_seed")
-
- # Build metadata with original image info
- regeneration_metadata = {
- "is_regenerated": True,
- "regeneration_date": datetime.utcnow().isoformat() + "Z",
- "regeneration_seed": seed,
- }
-
- if original_image_url:
- regeneration_metadata["original_image_url"] = original_image_url
- if old_r2_url:
- regeneration_metadata["original_r2_url"] = old_r2_url
- if old_image_filename:
- regeneration_metadata["original_image_filename"] = old_image_filename
- if old_image_model:
- regeneration_metadata["original_image_model"] = old_image_model
- if old_seed:
- regeneration_metadata["original_seed"] = old_seed
-
- # Update the ad with new image
- update_kwargs = {
- "image_filename": filename,
- "image_model": model_used,
- "image_seed": seed,
- }
- if r2_url:
- update_kwargs["image_url"] = r2_url
- update_kwargs["r2_url"] = r2_url
- elif generated_url:
- update_kwargs["image_url"] = generated_url
- elif local_path:
- update_kwargs["image_url"] = f"/images/{filename}"
-
- api_logger.info(f"Updating ad {request.image_id} with new image...")
- update_success = await db_service.update_ad_creative(
- ad_id=request.image_id,
- username=username,
- metadata=regeneration_metadata,
- **update_kwargs
- )
-
- if update_success:
- api_logger.info(f"✓ Ad updated with regenerated image (ID: {request.image_id})")
- else:
- api_logger.warning("Failed to update ad with regenerated image")
-
- total_api_time = time.time() - api_start_time
- api_logger.info("=" * 80)
- api_logger.info(f"✓ API: Regeneration completed successfully in {total_api_time:.2f}s")
- api_logger.info("=" * 80)
-
- return {
- "status": "success",
- "regenerated_image": {
- "filename": filename,
- "filepath": local_path,
- "image_url": new_image_url,
- "r2_url": r2_url,
- "model_used": model_used,
- "prompt_used": image_prompt,
- "seed_used": seed,
- },
- "original_image_url": original_image_url,
- "original_preserved": True,
- "is_preview": False,
- }
-
- except HTTPException:
- total_api_time = time.time() - api_start_time
- api_logger.error(f"✗ API: Regeneration request failed with HTTPException after {total_api_time:.2f}s")
- raise
- except Exception as e:
- total_api_time = time.time() - api_start_time
- api_logger.error(f"✗ API: Regeneration request failed with exception after {total_api_time:.2f}s: {str(e)}")
- api_logger.exception("Full exception traceback:")
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.post("/api/regenerate/confirm")
-async def confirm_image_selection(
- request: ImageSelectionRequest,
- username: str = Depends(get_current_user)
-):
- """
- Confirm the user's image selection after regeneration preview.
-
- If selection='new': Updates the ad with the new regenerated image
- If selection='original': Keeps the original image (no database update needed)
- """
- api_start_time = time.time()
- api_logger.info("=" * 80)
- api_logger.info(f"API: Image selection confirmation received")
- api_logger.info(f"User: {username}")
- api_logger.info(f"Image ID: {request.image_id}")
- api_logger.info(f"Selection: {request.selection}")
-
- try:
- # Validate selection value
- if request.selection not in ["new", "original"]:
- raise HTTPException(status_code=400, detail="Selection must be 'new' or 'original'")
-
- # Fetch ad from database (only if it belongs to current user)
- ad = await db_service.get_ad_creative(request.image_id, username=username)
- if not ad:
- api_logger.error(f"Ad creative {request.image_id} not found or access denied for user {username}")
- raise HTTPException(status_code=404, detail=f"Ad creative with ID {request.image_id} not found or access denied")
-
- if request.selection == "original":
- # User chose to keep the original - no database update needed
- api_logger.info("User chose to keep original image - no update needed")
- total_api_time = time.time() - api_start_time
- api_logger.info(f"✓ API: Selection confirmed (original kept) in {total_api_time:.2f}s")
- return {
- "status": "success",
- "message": "Original image kept",
- "selection": "original",
- }
-
- # User chose the new image - update the database
- if not request.new_image_url:
- raise HTTPException(status_code=400, detail="new_image_url is required when selection='new'")
-
- # Get original image info before updating
- original_image_url = ad.get("r2_url") or ad.get("image_url")
- original_r2_url = ad.get("r2_url")
- original_filename = ad.get("image_filename")
- original_model = ad.get("image_model")
- original_seed = ad.get("image_seed")
-
- # Build metadata with original image info
- regeneration_metadata = {
- "is_regenerated": True,
- "regeneration_date": datetime.utcnow().isoformat() + "Z",
- "regeneration_seed": request.new_seed,
- }
-
- if original_image_url:
- regeneration_metadata["original_image_url"] = original_image_url
- if original_r2_url:
- regeneration_metadata["original_r2_url"] = original_r2_url
- if original_filename:
- regeneration_metadata["original_image_filename"] = original_filename
- if original_model:
- regeneration_metadata["original_image_model"] = original_model
- if original_seed:
- regeneration_metadata["original_seed"] = original_seed
-
- # Update the ad with new image
- update_kwargs = {}
- if request.new_filename:
- update_kwargs["image_filename"] = request.new_filename
- if request.new_model:
- update_kwargs["image_model"] = request.new_model
- if request.new_seed:
- update_kwargs["image_seed"] = request.new_seed
- if request.new_r2_url:
- update_kwargs["image_url"] = request.new_r2_url
- update_kwargs["r2_url"] = request.new_r2_url
- elif request.new_image_url:
- update_kwargs["image_url"] = request.new_image_url
-
- api_logger.info(f"Updating ad {request.image_id} with new image...")
- update_success = await db_service.update_ad_creative(
- ad_id=request.image_id,
- username=username,
- metadata=regeneration_metadata,
- **update_kwargs
- )
-
- if update_success:
- api_logger.info(f"✓ Ad updated with selected new image (ID: {request.image_id})")
- else:
- api_logger.warning("Failed to update ad with new image")
- raise HTTPException(status_code=500, detail="Failed to update ad with new image")
-
- total_api_time = time.time() - api_start_time
- api_logger.info("=" * 80)
- api_logger.info(f"✓ API: Selection confirmed (new image saved) in {total_api_time:.2f}s")
- api_logger.info("=" * 80)
-
- return {
- "status": "success",
- "message": "New image saved",
- "selection": "new",
- "new_image_url": request.new_image_url,
- }
-
- except HTTPException:
- total_api_time = time.time() - api_start_time
- api_logger.error(f"✗ API: Selection confirmation failed with HTTPException after {total_api_time:.2f}s")
- raise
- except Exception as e:
- total_api_time = time.time() - api_start_time
- api_logger.error(f"✗ API: Selection confirmation failed with exception after {total_api_time:.2f}s: {str(e)}")
- api_logger.exception("Full exception traceback:")
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.get("/api/models")
-async def list_image_models():
- """
- List all available image generation models.
-
- Returns model keys and their descriptions for use in image generation/regeneration.
- Default for regeneration is nano-banana (best quality for affiliate marketing).
- """
- from services.image import MODEL_REGISTRY
-
- # Order models with nano-banana first (recommended for regeneration)
- preferred_order = ["nano-banana", "nano-banana-pro", "z-image-turbo", "imagen-4-ultra", "recraft-v3", "ideogram-v3", "photon", "seedream-3"]
-
- models = []
- # Add models in preferred order first
- for key in preferred_order:
- if key in MODEL_REGISTRY:
- config = MODEL_REGISTRY[key]
- models.append({
- "key": key,
- "id": config["id"],
- "uses_dimensions": config.get("uses_dimensions", False),
- })
-
- # Add any remaining models not in preferred order
- for key, config in MODEL_REGISTRY.items():
- if key not in preferred_order:
- models.append({
- "key": key,
- "id": config["id"],
- "uses_dimensions": config.get("uses_dimensions", False),
- })
-
- # Add OpenAI model at the end
- models.append({
- "key": "gpt-image-1.5",
- "id": "openai/gpt-image-1.5",
- "uses_dimensions": True,
- })
-
- return {
- "models": models,
- "default": "nano-banana", # Best for affiliate marketing regeneration
- }
-
-
-@app.get("/strategies/{niche}")
-async def get_strategies(niche: Literal["home_insurance", "glp1", "auto_insurance"]):
- """
- Get available psychological strategies for a niche.
-
- Useful for understanding what strategies will be used.
- """
- from data import home_insurance, glp1
-
- if niche == "home_insurance":
- data = home_insurance.get_niche_data()
- else:
- data = glp1.get_niche_data()
-
- strategies = {}
- for name, strategy in data["strategies"].items():
- strategies[name] = {
- "name": strategy["name"],
- "description": strategy["description"],
- "hook_count": len(strategy["hooks"]),
- "sample_hooks": strategy["hooks"][:3],
- }
-
- return {
- "niche": niche,
- "total_strategies": len(strategies),
- "total_hooks": len(data["all_hooks"]),
- "strategies": strategies,
- }
-
+# Include all API routers
+for router in get_all_routers():
+ app.include_router(router)
-# =============================================================================
-# ANGLE × CONCEPT MATRIX ENDPOINTS
-# =============================================================================
-@app.post("/matrix/generate", response_model=MatrixGenerateResponse)
-async def generate_with_matrix(
- request: MatrixGenerateRequest,
- username: str = Depends(get_current_user)
-):
- """
- Generate ad using the Angle × Concept matrix approach.
-
- Requires authentication. Users can only see their own generated ads.
-
- This provides systematic ad generation with explicit control over:
- - ANGLE: The psychological WHY (100 angles in 10 categories)
- - CONCEPT: The visual HOW (100 concepts in 10 categories)
-
- If angle_key and concept_key are not provided, a compatible
- combination will be selected automatically based on the niche.
-
- Supports custom angles/concepts:
- - Set angle_key='custom' and provide custom_angle text
- - Set concept_key='custom' and provide custom_concept text
- """
- try:
- result = await ad_generator.generate_ad_with_matrix(
- niche=request.niche,
- angle_key=request.angle_key,
- concept_key=request.concept_key,
- custom_angle=request.custom_angle,
- custom_concept=request.custom_concept,
- num_images=request.num_images,
- image_model=request.image_model,
- username=username,
- core_motivator=request.core_motivator,
- target_audience=request.target_audience,
- offer=request.offer,
- )
- return result
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.post("/matrix/testing", response_model=TestingMatrixResponse)
-async def generate_testing_matrix(request: MatrixBatchRequest):
- """
- Generate a testing matrix for systematic ad testing.
-
- Implements the scaling formula:
- - 1 Offer → 5-8 Angles → 3-5 Concepts per angle
- - Default: 6 angles × 5 concepts = 30 combinations
-
- Strategies:
- - balanced: Mix of top performers and diverse selection
- - top_performers: Focus on proven winning angles/concepts
- - diverse: Maximum variety across categories
-
- Returns combinations WITHOUT generating images (for planning).
- """
- try:
- combinations = matrix_service.generate_testing_matrix(
- niche=request.niche,
- angle_count=request.angle_count,
- concept_count=request.concept_count,
- strategy=request.strategy,
- )
-
- summary = matrix_service.get_matrix_summary(combinations)
-
- return {
- "niche": request.niche,
- "strategy": request.strategy,
- "summary": summary,
- "combinations": combinations,
- }
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.get("/matrix/angles")
-async def list_angles():
- """
- List all available angles (100 total, 10 categories).
-
- Angles answer: "Why should I care?" - the psychological WHY.
- """
- from data.angles import ANGLES, get_all_angles, AngleCategory
-
- categories = {}
- for cat_key, cat_data in ANGLES.items():
- categories[cat_key.value] = {
- "name": cat_data["name"],
- "angle_count": len(cat_data["angles"]),
- "angles": [
- {
- "key": a["key"],
- "name": a["name"],
- "trigger": a["trigger"],
- "example": a["example"],
- }
- for a in cat_data["angles"]
- ],
- }
-
- return {
- "total_angles": len(get_all_angles()),
- "categories": categories,
- }
-
-
-@app.get("/matrix/concepts")
-async def list_concepts():
- """
- List all available concepts (100 total, 10 categories).
-
- Concepts answer: "How do I show it?" - the visual HOW.
- """
- from data.concepts import CONCEPTS, get_all_concepts
-
- categories = {}
- for cat_key, cat_data in CONCEPTS.items():
- categories[cat_key.value] = {
- "name": cat_data["name"],
- "concept_count": len(cat_data["concepts"]),
- "concepts": [
- {
- "key": c["key"],
- "name": c["name"],
- "structure": c["structure"],
- "visual": c["visual"],
- }
- for c in cat_data["concepts"]
- ],
- }
-
- return {
- "total_concepts": len(get_all_concepts()),
- "categories": categories,
- }
-
-
-@app.get("/matrix/angle/{angle_key}")
-async def get_angle(angle_key: str):
- """Get details for a specific angle by key."""
- from data.angles import get_angle_by_key
-
- angle = get_angle_by_key(angle_key)
- if not angle:
- raise HTTPException(status_code=404, detail=f"Angle '{angle_key}' not found")
-
- return angle
-
-
-@app.get("/matrix/concept/{concept_key}")
-async def get_concept(concept_key: str):
- """Get details for a specific concept by key."""
- from data.concepts import get_concept_by_key
-
- concept = get_concept_by_key(concept_key)
- if not concept:
- raise HTTPException(status_code=404, detail=f"Concept '{concept_key}' not found")
-
- return concept
-
-
-@app.get("/matrix/compatible/{angle_key}")
-async def get_compatible_concepts(angle_key: str):
- """
- Get concepts compatible with a specific angle.
-
- Compatibility is based on psychological trigger matching.
- """
- from data.angles import get_angle_by_key
- from data.concepts import get_compatible_concepts as get_compatible
-
- angle = get_angle_by_key(angle_key)
- if not angle:
- raise HTTPException(status_code=404, detail=f"Angle '{angle_key}' not found")
-
- compatible = get_compatible(angle.get("trigger", ""))
-
- return {
- "angle": {
- "key": angle["key"],
- "name": angle["name"],
- "trigger": angle["trigger"],
- },
- "compatible_concepts": [
- {
- "key": c["key"],
- "name": c["name"],
- "structure": c["structure"],
- }
- for c in compatible
- ],
- }
-
-
-@app.post("/matrix/refine-custom", response_model=RefineCustomResponse)
-async def refine_custom_angle_or_concept(request: RefineCustomRequest):
- """
- Refine a custom angle or concept text using AI.
-
- This endpoint takes raw user input and structures it properly
- according to the angle/concept framework used in ad generation.
-
- For angles, it extracts:
- - name: Short descriptive name
- - trigger: Psychological trigger (e.g., Fear, Hope, Pride)
- - example: Example hook text
-
- For concepts, it extracts:
- - name: Short descriptive name
- - structure: How to structure the visual/copy
- - visual: Visual guidance for the image
- """
- try:
- result = await ad_generator.refine_custom_angle_or_concept(
- text=request.text,
- type=request.type,
- niche=request.niche,
- goal=request.goal,
- )
- return {
- "status": "success",
- "type": request.type,
- "refined": result,
- }
- except Exception as e:
- return {
- "status": "error",
- "type": request.type,
- "refined": None,
- "error": str(e),
- }
-
-
-# =============================================================================
-# MOTIVATOR ENDPOINTS
-# =============================================================================
-
-@app.post("/api/motivator/generate", response_model=MotivatorGenerateResponse)
-async def motivator_generate_endpoint(
- request: MotivatorGenerateRequest,
- username: str = Depends(get_current_user),
-):
- """
- Generate multiple motivators from niche + angle + concept context.
-
- Used in Matrix mode: user selects angle and concept, then generates motivators
- to review and pick one for ad generation.
- """
- try:
- motivators = await motivator_generate(
- niche=request.niche,
- angle=request.angle,
- concept=request.concept,
- target_audience=request.target_audience,
- offer=request.offer,
- count=request.count,
- )
- return MotivatorGenerateResponse(motivators=motivators)
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# EXTENSIVE ENDPOINTS (async job pattern to avoid connection timeout on HF Spaces)
-# =============================================================================
-
-# In-memory job store: job_id -> { status, result?, error?, username }
-_extensive_jobs: Dict[str, Dict[str, Any]] = {}
-
-class ExtensiveGenerateRequest(BaseModel):
- """Request for extensive generation."""
- niche: str = Field(
- description="Target niche: home_insurance, glp1, auto_insurance, or others (use custom_niche when others)"
- )
- custom_niche: Optional[str] = Field(
- default=None,
- description="Custom niche name when 'others' is selected"
- )
- target_audience: Optional[str] = Field(
- default=None,
- description="Optional target audience description (e.g., 'US people over 50+ age')"
- )
- offer: Optional[str] = Field(
- default=None,
- description="Optional offer to run (e.g., 'Don't overpay your insurance')"
- )
- num_images: int = Field(
- default=1,
- ge=1,
- le=3,
- description="Number of images to generate per strategy (1-3)"
- )
- image_model: Optional[str] = Field(
- default=None,
- description="Image generation model to use"
- )
- num_strategies: int = Field(
- default=5,
- ge=1,
- le=10,
- description="Number of creative strategies to generate (1-10)"
- )
-
-
-class ExtensiveJobResponse(BaseModel):
- """Response when extensive generation is started (202 Accepted)."""
- job_id: str
- message: str = "Extensive generation started. Poll /extensive/status/{job_id} for progress."
-
-
-async def _run_extensive_job_async(
- job_id: str,
- username: str,
- effective_niche: str,
- target_audience: Optional[str],
- offer: Optional[str],
- num_images: int,
- image_model: Optional[str],
- num_strategies: int,
-):
- """Run extensive generation on the main event loop so DB and other async code use the same loop."""
- try:
- results = await ad_generator.generate_ad_extensive(
- niche=effective_niche,
- target_audience=target_audience,
- offer=offer,
- num_images=num_images,
- image_model=image_model,
- num_strategies=num_strategies,
- username=username,
- )
- _extensive_jobs[job_id]["status"] = "completed"
- _extensive_jobs[job_id]["result"] = BatchResponse(count=len(results), ads=results)
- except Exception as e:
- api_logger.exception("Extensive job %s failed", job_id)
- _extensive_jobs[job_id]["status"] = "failed"
- _extensive_jobs[job_id]["error"] = str(e)
-
-
-@app.post("/extensive/generate", status_code=202)
-async def generate_extensive(
- request: ExtensiveGenerateRequest,
- username: str = Depends(get_current_user)
-):
- """
- Start extensive ad generation (researcher → creative director → designer → copywriter).
- Returns 202 with job_id. Poll GET /extensive/status/{job_id} then GET /extensive/result/{job_id}.
- Runs on the main event loop so DB (MongoDB) and other async code stay on the same loop.
- """
- if request.niche == "others":
- if not request.custom_niche or not request.custom_niche.strip():
- raise HTTPException(
- status_code=400,
- detail="custom_niche is required when niche is 'others'"
- )
- effective_niche = request.custom_niche.strip()
- else:
- effective_niche = request.niche
-
- job_id = str(uuid.uuid4())
- _extensive_jobs[job_id] = {
- "status": "running",
- "result": None,
- "error": None,
- "username": username,
- }
-
- asyncio.create_task(
- _run_extensive_job_async(
- job_id,
- username,
- effective_niche,
- request.target_audience,
- request.offer,
- request.num_images,
- request.image_model,
- request.num_strategies,
- )
- )
- return ExtensiveJobResponse(job_id=job_id)
-
-
-@app.get("/extensive/status/{job_id}")
-async def extensive_job_status(
- job_id: str,
- username: str = Depends(get_current_user)
-):
- """Get status of an extensive generation job."""
- if job_id not in _extensive_jobs:
- raise HTTPException(status_code=404, detail="Job not found")
- job = _extensive_jobs[job_id]
- if job["username"] != username:
- raise HTTPException(status_code=404, detail="Job not found")
- return {
- "job_id": job_id,
- "status": job["status"],
- "error": job.get("error") if job["status"] == "failed" else None,
- }
-
-
-@app.get("/extensive/result/{job_id}", response_model=BatchResponse)
-async def extensive_job_result(
- job_id: str,
- username: str = Depends(get_current_user)
-):
- """Get result of a completed extensive generation job. Returns 404 if not found, 425 if still running."""
- if job_id not in _extensive_jobs:
- raise HTTPException(status_code=404, detail="Job not found")
- job = _extensive_jobs[job_id]
- if job["username"] != username:
- raise HTTPException(status_code=404, detail="Job not found")
- if job["status"] == "running":
- raise HTTPException(status_code=425, detail="Generation still in progress")
- if job["status"] == "failed":
- raise HTTPException(status_code=500, detail=job.get("error", "Generation failed"))
- return job["result"]
-
-
-# =============================================================================
-# CREATIVE MODIFIER ENDPOINTS
-# =============================================================================
-
-from fastapi import File, UploadFile
-from services.creative_modifier import creative_modifier_service
-
-
-class CreativeAnalysisData(BaseModel):
- """Structured analysis of a creative."""
- visual_style: str
- color_palette: List[str]
- mood: str
- composition: str
- subject_matter: str
- text_content: Optional[str] = None
- current_angle: Optional[str] = None
- current_concept: Optional[str] = None
- target_audience: Optional[str] = None
- strengths: List[str]
- areas_for_improvement: List[str]
-
-
-class CreativeAnalyzeRequest(BaseModel):
- """Request for creative analysis."""
- image_url: Optional[str] = Field(
- default=None,
- description="URL of the image to analyze (alternative to file upload)"
- )
-
-
-class CreativeAnalysisResponse(BaseModel):
- """Response for creative analysis."""
- status: str
- analysis: Optional[CreativeAnalysisData] = None
- suggested_angles: Optional[List[str]] = None
- suggested_concepts: Optional[List[str]] = None
- error: Optional[str] = None
-
-
-class CreativeModifyRequest(BaseModel):
- """Request for creative modification."""
- image_url: str = Field(
- description="URL of the original image"
- )
- analysis: Optional[Dict[str, Any]] = Field(
- default=None,
- description="Previous analysis data (optional)"
- )
- angle: Optional[str] = Field(
- default=None,
- description="Angle to apply to the creative"
- )
- concept: Optional[str] = Field(
- default=None,
- description="Concept to apply to the creative"
- )
- mode: Literal["modify", "inspired"] = Field(
- default="modify",
- description="Modification mode: 'modify' for image-to-image, 'inspired' for new generation"
- )
- image_model: Optional[str] = Field(
- default=None,
- description="Image generation model to use"
- )
- user_prompt: Optional[str] = Field(
- default=None,
- description="Optional custom user prompt/instructions for modification"
- )
-
-
-class ModifiedImageResult(BaseModel):
- """Result of creative modification."""
- filename: Optional[str] = None
- filepath: Optional[str] = None
- image_url: Optional[str] = None
- r2_url: Optional[str] = None
- model_used: Optional[str] = None
- mode: Optional[str] = None
- applied_angle: Optional[str] = None
- applied_concept: Optional[str] = None
-
-
-class CreativeModifyResponse(BaseModel):
- """Response for creative modification."""
- status: str
- prompt: Optional[str] = None
- image: Optional[ModifiedImageResult] = None
- error: Optional[str] = None
-
-
-class FileUploadResponse(BaseModel):
- """Response for file upload."""
- status: str
- image_url: Optional[str] = None
- filename: Optional[str] = None
- error: Optional[str] = None
-
-
-@app.post("/api/creative/upload", response_model=FileUploadResponse)
-async def upload_creative(
- file: UploadFile = File(...),
- username: str = Depends(get_current_user)
-):
- """
- Upload a creative image for analysis and modification.
-
- Accepts PNG, JPG, JPEG, WebP files.
- Returns the uploaded image URL that can be used for subsequent analysis/modification.
- """
- # Validate file type
- allowed_types = ["image/png", "image/jpeg", "image/jpg", "image/webp"]
- if file.content_type not in allowed_types:
- raise HTTPException(
- status_code=400,
- detail=f"Invalid file type. Allowed: PNG, JPG, JPEG, WebP. Got: {file.content_type}"
- )
-
- # Check file size (max 10MB)
- contents = await file.read()
- if len(contents) > 10 * 1024 * 1024:
- raise HTTPException(
- status_code=400,
- detail="File too large. Maximum size is 10MB."
- )
-
- try:
- # Generate filename
- from datetime import datetime
- import uuid
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- unique_id = uuid.uuid4().hex[:8]
- ext = file.filename.split(".")[-1] if file.filename else "png"
- filename = f"upload_{username}_{timestamp}_{unique_id}.{ext}"
-
- # Try to upload to R2
- r2_url = None
- try:
- from services.r2_storage import get_r2_storage
- r2_storage = get_r2_storage()
- if r2_storage:
- r2_url = r2_storage.upload_image(
- image_bytes=contents,
- filename=filename,
- niche="uploads",
- )
- except Exception as e:
- api_logger.warning(f"R2 upload failed: {e}")
-
- # Save locally as fallback
- local_path = None
- if not r2_url:
- local_path = os.path.join(settings.output_dir, filename)
- os.makedirs(os.path.dirname(local_path), exist_ok=True)
- with open(local_path, "wb") as f:
- f.write(contents)
- # Construct local URL
- r2_url = f"/images/{filename}"
-
- return {
- "status": "success",
- "image_url": r2_url,
- "filename": filename,
- }
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.post("/api/creative/analyze", response_model=CreativeAnalysisResponse)
-async def analyze_creative(
- request: CreativeAnalyzeRequest,
- username: str = Depends(get_current_user)
-):
- """
- Analyze a creative image using AI vision (via URL).
-
- Accepts image_url in request body.
-
- Returns detailed analysis including visual style, mood, current angle/concept,
- and suggestions for new angles and concepts.
- """
- if not request.image_url:
- raise HTTPException(
- status_code=400,
- detail="image_url must be provided"
- )
-
- # Fetch image from URL
- try:
- image_bytes = await image_service.load_image(image_url=request.image_url)
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {e}")
-
- if not image_bytes:
- raise HTTPException(status_code=400, detail="Failed to load image")
-
- try:
- result = await creative_modifier_service.analyze_creative(image_bytes)
-
- if result["status"] != "success":
- return CreativeAnalysisResponse(
- status="error",
- error=result.get("error", "Analysis failed")
- )
-
- return CreativeAnalysisResponse(
- status="success",
- analysis=result.get("analysis"),
- suggested_angles=result.get("suggested_angles"),
- suggested_concepts=result.get("suggested_concepts"),
- )
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.post("/api/creative/analyze/upload", response_model=CreativeAnalysisResponse)
-async def analyze_creative_upload(
- file: UploadFile = File(...),
- username: str = Depends(get_current_user)
-):
- """
- Analyze a creative image using AI vision (via file upload).
-
- Accepts file upload via multipart form.
-
- Returns detailed analysis including visual style, mood, current angle/concept,
- and suggestions for new angles and concepts.
- """
- # Validate file type
- allowed_types = ["image/png", "image/jpeg", "image/jpg", "image/webp"]
- if file.content_type not in allowed_types:
- raise HTTPException(
- status_code=400,
- detail=f"Invalid file type. Allowed: PNG, JPG, JPEG, WebP. Got: {file.content_type}"
- )
-
- image_bytes = await file.read()
-
- if not image_bytes:
- raise HTTPException(status_code=400, detail="Failed to load image")
-
- try:
- result = await creative_modifier_service.analyze_creative(image_bytes)
-
- if result["status"] != "success":
- return CreativeAnalysisResponse(
- status="error",
- error=result.get("error", "Analysis failed")
- )
-
- return CreativeAnalysisResponse(
- status="success",
- analysis=result.get("analysis"),
- suggested_angles=result.get("suggested_angles"),
- suggested_concepts=result.get("suggested_concepts"),
- )
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.post("/api/creative/modify", response_model=CreativeModifyResponse)
-async def modify_creative(
- request: CreativeModifyRequest,
- username: str = Depends(get_current_user)
-):
- """
- Modify a creative based on user-provided angle and/or concept.
-
- Modes:
- - 'modify': Uses image-to-image to make targeted changes while preserving most of the original
- - 'inspired': Generates a completely new image inspired by the original
-
- At least one of angle or concept must be provided.
- """
- if not request.angle and not request.concept:
- raise HTTPException(
- status_code=400,
- detail="At least one of 'angle' or 'concept' must be provided"
- )
-
- # If no analysis provided, we need to analyze first
- analysis = request.analysis
- if not analysis:
- # Fetch and analyze the image
- try:
- image_bytes = await image_service.load_image(image_url=request.image_url)
- if not image_bytes:
- raise HTTPException(status_code=400, detail="Failed to load image from URL")
-
- analysis_result = await creative_modifier_service.analyze_creative(image_bytes)
- if analysis_result["status"] != "success":
- raise HTTPException(
- status_code=500,
- detail=f"Failed to analyze image: {analysis_result.get('error')}"
- )
- analysis = analysis_result.get("analysis", {})
- except HTTPException:
- raise
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to analyze image: {e}")
-
- try:
- result = await creative_modifier_service.modify_creative(
- image_url=request.image_url,
- analysis=analysis,
- user_angle=request.angle,
- user_concept=request.concept,
- mode=request.mode,
- image_model=request.image_model,
- user_prompt=request.user_prompt,
- )
-
- if result["status"] != "success":
- return CreativeModifyResponse(
- status="error",
- error=result.get("error", "Modification failed")
- )
-
- return CreativeModifyResponse(
- status="success",
- prompt=result.get("prompt"),
- image=result.get("image"),
- )
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# DATABASE ENDPOINTS
-# =============================================================================
-
-class AdCreativeDB(BaseModel):
- """Ad creative from database."""
- id: str
- niche: str
- title: Optional[str] = None
- headline: str
- primary_text: Optional[str] = None
- description: Optional[str] = None
- body_story: Optional[str] = None
- cta: Optional[str] = None
- psychological_angle: Optional[str] = None
- why_it_works: Optional[str] = None
- image_url: Optional[str] = None
- image_filename: Optional[str] = None
- image_model: Optional[str] = None
- image_seed: Optional[int] = None
- angle_key: Optional[str] = None
- angle_name: Optional[str] = None
- concept_key: Optional[str] = None
- concept_name: Optional[str] = None
- generation_method: Optional[str] = None
- created_at: Optional[str] = None
-
-
-class DbStatsResponse(BaseModel):
- """Database statistics response."""
- connected: bool
- total_ads: Optional[int] = None
- by_niche: Optional[Dict[str, int]] = None
- by_method: Optional[Dict[str, int]] = None
- error: Optional[str] = None
-
-
-@app.get("/db/stats", response_model=DbStatsResponse)
-async def get_database_stats(username: str = Depends(get_current_user)):
- """
- Get statistics about stored ad creatives for the current user.
-
- Requires authentication. Shows only the current user's ads.
-
- Shows total ads, breakdown by niche, and breakdown by generation method.
- """
- stats = await db_service.get_stats(username=username)
- return stats
-
-
-@app.get("/db/ads")
-async def list_stored_ads(
- niche: Optional[str] = None,
- generation_method: Optional[str] = None,
- limit: int = 50,
- offset: int = 0,
- username: str = Depends(get_current_user),
-):
- """
- List ad creatives stored in the database for the current user.
-
- Requires authentication. Users can only see their own ads.
-
- - Filter by niche (optional)
- - Filter by generation_method (optional)
- - Paginate with limit and offset
- - Returns ads with direct image URLs
- """
- ads, total = await db_service.list_ad_creatives(
- username=username,
- niche=niche,
- generation_method=generation_method,
- limit=limit,
- offset=offset,
- )
-
- return {
- "total": total,
- "limit": limit,
- "offset": offset,
- "ads": [
- {
- "id": str(ad.get("id", "")),
- "niche": ad.get("niche", ""),
- "title": ad.get("title"),
- "headline": ad.get("headline", ""),
- "primary_text": ad.get("primary_text"),
- "description": ad.get("description"),
- "body_story": ad.get("body_story"),
- "cta": ad.get("cta", ""),
- "psychological_angle": ad.get("psychological_angle", ""),
- "image_url": ad.get("image_url"),
- "r2_url": ad.get("r2_url"),
- "image_filename": ad.get("image_filename"),
- "image_model": ad.get("image_model"),
- "angle_key": ad.get("angle_key"),
- "concept_key": ad.get("concept_key"),
- "generation_method": ad.get("generation_method", "standard"),
- "created_at": ad.get("created_at"),
- }
- for ad in ads
- ],
- }
-
-
-@app.get("/db/ad/{ad_id}")
-async def get_stored_ad(ad_id: str):
- """
- Get a specific ad creative by ID.
-
- Returns full ad data including image URL that can be used directly in documents.
- """
- ad = await db_service.get_ad_creative(ad_id)
-
- if not ad:
- raise HTTPException(status_code=404, detail=f"Ad '{ad_id}' not found")
-
- return {
- "id": str(ad.get("id", "")),
- "niche": ad.get("niche", ""),
- "title": ad.get("title"),
- "headline": ad.get("headline", ""),
- "primary_text": ad.get("primary_text"),
- "description": ad.get("description"),
- "body_story": ad.get("body_story"),
- "cta": ad.get("cta", ""),
- "psychological_angle": ad.get("psychological_angle", ""),
- "why_it_works": ad.get("why_it_works"),
- "image_url": ad.get("image_url"),
- "image_filename": ad.get("image_filename"),
- "image_model": ad.get("image_model"),
- "image_seed": ad.get("image_seed"),
- "r2_url": ad.get("r2_url"), # Include r2_url in response
- "angle_key": ad.get("angle_key"),
- "angle_name": ad.get("angle_name"),
- "angle_trigger": ad.get("angle_trigger"),
- "angle_category": ad.get("angle_category"),
- "concept_key": ad.get("concept_key"),
- "concept_name": ad.get("concept_name"),
- "concept_structure": ad.get("concept_structure"),
- "concept_visual": ad.get("concept_visual"),
- "concept_category": ad.get("concept_category"),
- "generation_method": ad.get("generation_method", "standard"),
- "metadata": ad.get("metadata"),
- "created_at": ad.get("created_at"),
- "updated_at": ad.get("updated_at"),
- }
-
-
-@app.delete("/db/ad/{ad_id}")
-async def delete_stored_ad(ad_id: str, username: str = Depends(get_current_user)):
- """
- Delete an ad creative from the database.
-
- Requires authentication. Users can only delete their own ads.
- """
- success = await db_service.delete_ad_creative(ad_id, username=username)
-
- if not success:
- raise HTTPException(status_code=404, detail=f"Ad '{ad_id}' not found or could not be deleted")
-
- return {"success": True, "deleted_id": ad_id}
-
-
-class EditAdCopyRequest(BaseModel):
- """Request for editing ad copy."""
- ad_id: str = Field(description="ID of the ad to edit")
- field: Literal["title", "headline", "primary_text", "description", "body_story", "cta"] = Field(
- description="Field to edit"
- )
- value: str = Field(description="New value for the field (for manual edit) or current value (for AI edit)")
- mode: Literal["manual", "ai"] = Field(description="Edit mode: manual or AI")
- user_suggestion: Optional[str] = Field(
- default=None,
- description="User suggestion for AI editing (optional)"
- )
-
-
-@app.post("/db/ad/edit")
-async def edit_ad_copy(
- request: EditAdCopyRequest,
- username: str = Depends(get_current_user)
-):
- """
- Edit ad copy fields with manual or AI assistance.
-
- Requires authentication. Users can only edit their own ads.
-
- Modes:
- - manual: Directly update the field with the provided value
- - ai: Generate an improved version using AI, optionally with user suggestions
- """
- from services.llm import LLMService
-
- # Verify ad exists and belongs to user
- ad = await db_service.get_ad_creative(request.ad_id)
- if not ad:
- raise HTTPException(status_code=404, detail=f"Ad '{request.ad_id}' not found")
-
- if ad.get("username") != username:
- raise HTTPException(status_code=403, detail="You can only edit your own ads")
-
- if request.mode == "manual":
- # Direct update
- update_data = {request.field: request.value}
- success = await db_service.update_ad_creative(
- ad_id=request.ad_id,
- username=username,
- **update_data
- )
-
- if not success:
- raise HTTPException(status_code=500, detail="Failed to update ad")
-
- return {
- "edited_value": request.value,
- "success": True
- }
-
- else: # AI mode
- # Generate improved version using AI
- llm_service = LLMService()
-
- # Build context for AI
- field_labels = {
- "title": "title",
- "headline": "headline",
- "primary_text": "primary text",
- "description": "description",
- "body_story": "body story",
- "cta": "call to action"
- }
-
- field_label = field_labels.get(request.field, request.field)
- current_value = request.value
- niche = ad.get("niche", "general")
-
- # Build prompt
- system_prompt = f"""You are an expert copywriter specializing in high-converting ad copy for {niche.replace('_', ' ')}.
-Your task is to improve the {field_label} while maintaining its core message and emotional impact.
-Keep the same tone and style, but make it more compelling, clear, and effective."""
-
- user_prompt = f"""Current {field_label}:
-{current_value}
-
-"""
-
- if request.user_suggestion:
- user_prompt += f"""User's suggestion: {request.user_suggestion}
-
-"""
-
- user_prompt += f"""Please provide an improved version of this {field_label} that:
-1. Maintains the core message and emotional impact
-2. Is more compelling and engaging
-3. Follows best practices for {field_label} in ad copy
-4. {"Incorporates the user's suggestion" if request.user_suggestion else "Is optimized for conversion"}
-
-Return ONLY the improved {field_label} text, without any explanations or additional text."""
-
- try:
- edited_value = await llm_service.generate(
- prompt=user_prompt,
- system_prompt=system_prompt,
- temperature=0.7
- )
-
- # Clean up the response (remove quotes if wrapped)
- edited_value = edited_value.strip().strip('"').strip("'")
-
- return {
- "edited_value": edited_value,
- "success": True
- }
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to generate AI edit: {str(e)}")
-
-
-# =============================================================================
-# BULK EXPORT ENDPOINTS
-# =============================================================================
-
-class BulkExportRequest(BaseModel):
- """Request schema for bulk export."""
- ad_ids: List[str] = Field(
- description="List of ad IDs to export",
- min_items=1,
- max_items=50
- )
-
-
-class BulkExportResponse(BaseModel):
- """Response schema for bulk export."""
- status: str
- message: str
- filename: str
-
-
-@app.post("/api/export/bulk")
-async def export_bulk_ads(
- request: BulkExportRequest,
- background_tasks: BackgroundTasks,
- username: str = Depends(get_current_user)
-):
- """
- Export multiple ad creatives as a ZIP package.
-
- Requires authentication. Users can only export their own ads.
-
- Creates a ZIP file containing:
- - /creatives/ folder with renamed images (nomenclature: {niche}_{concept}_{angle}_{date}_{version}.png)
- - ad_copy_data.xlsx with core fields (Headline, Title, Description, CTA, Psychological Angle, Image Filename, Image URL)
-
- Maximum 50 ads per export.
- """
- try:
- # Validate number of ads
- if len(request.ad_ids) > 50:
- raise HTTPException(
- status_code=400,
- detail="Maximum 50 ads can be exported at once"
- )
-
- # Fetch all ads and verify ownership
- ads = []
- for ad_id in request.ad_ids:
- ad = await db_service.get_ad_creative(ad_id, username=username)
- if not ad:
- raise HTTPException(
- status_code=404,
- detail=f"Ad '{ad_id}' not found or access denied"
- )
- ads.append(ad)
-
- # Create export package
- api_logger.info(f"Creating export package for {len(ads)} ads (user: {username})")
- zip_path = await export_service.create_export_package(ads)
-
- # Schedule cleanup after response is sent
- background_tasks.add_task(export_service.cleanup_zip, zip_path)
-
- return FileResponse(
- zip_path,
- media_type="application/zip",
- filename=os.path.basename(zip_path)
- )
-
- except HTTPException:
- raise
- except Exception as e:
- api_logger.error(f"Bulk export failed: {e}")
- raise HTTPException(status_code=500, detail=f"Export failed: {str(e)}")
-
-
-# Frontend proxy - forward non-API requests to Next.js
-# This must be LAST so it doesn't intercept API routes
+# Frontend proxy - must be last so it doesn't intercept API routes
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def frontend_proxy(path: str, request: StarletteRequest):
"""
Proxy frontend requests to Next.js server.
Smart routing based on path AND HTTP method.
"""
- # Exact API-only routes (never frontend)
- # Note: /health has its own explicit route, so not listed here
api_only_routes = [
- "auth/login", "api/correct", "api/download-image", "api/export/bulk",
- "db/stats", "db/ads", "strategies", "extensive/generate", "extensive/status", "extensive/result"
+ "auth/login", "api/correct", "api/download-image", "api/export/bulk",
+ "db/stats", "db/ads", "strategies", "extensive/generate", "extensive/status", "extensive/result",
]
-
- # Routes that are API for POST but frontend for GET
- # GET /generate -> frontend page, POST /generate -> API
api_post_routes = [
- "generate", "generate/batch", "matrix/generate", "matrix/testing"
+ "generate", "generate/batch", "matrix/generate", "matrix/testing",
]
-
- # Routes that are API for GET (data endpoints)
api_get_routes = [
- "matrix/angles", "matrix/concepts", "matrix/angle", "matrix/concept",
- "matrix/compatible", "db/ad"
- ]
-
- # Routes that are API for POST
- api_post_routes_additional = [
- "db/ad/edit"
+ "matrix/angles", "matrix/concepts", "matrix/angle", "matrix/concept",
+ "matrix/compatible", "db/ad",
]
-
- # Check if this is an API-only route
+ api_post_routes_additional = ["db/ad/edit"]
+
if any(path == route or path.startswith(f"{route}/") for route in api_only_routes):
raise HTTPException(status_code=404, detail="API endpoint not found")
-
- # Check if this is a POST to an API endpoint
if request.method == "POST" and any(path == route or path.startswith(f"{route}/") for route in api_post_routes):
raise HTTPException(status_code=404, detail="API endpoint not found")
-
- # Check additional POST routes
if request.method == "POST" and any(path == route or path.startswith(f"{route}/") for route in api_post_routes_additional):
raise HTTPException(status_code=404, detail="API endpoint not found")
-
- # Check if path starts with image serving routes
if path.startswith("image/") or path.startswith("images/"):
raise HTTPException(status_code=404, detail="API endpoint not found")
-
- # Skip static files - they're mounted directly
if path.startswith("_next/static/"):
raise HTTPException(status_code=404, detail="Static file not found")
-
- # Check if this is a GET for data endpoints (with specific patterns)
if request.method == "GET":
for route in api_get_routes:
- # Exact match or with path parameter (e.g., matrix/angle/fear)
- if path == route or (path.startswith(f"{route}/") and "/" in path[len(route)+1:] == False):
+ if path == route or (path.startswith(f"{route}/") and "/" not in path[len(route) + 1:]):
raise HTTPException(status_code=404, detail="API endpoint not found")
-
- # Everything else goes to Next.js frontend
+
try:
async with httpx.AsyncClient(timeout=30.0) as client:
- # Forward the request to Next.js
nextjs_url = f"http://localhost:3000/{path}"
-
- # Forward query parameters
if request.url.query:
nextjs_url += f"?{request.url.query}"
-
- # Forward request
response = await client.request(
method=request.method,
url=nextjs_url,
@@ -2724,24 +135,19 @@ async def frontend_proxy(path: str, request: StarletteRequest):
content=await request.body() if request.method in ["POST", "PUT", "PATCH"] else None,
follow_redirects=True,
)
-
- # Return full response (not streaming) for better compatibility
return FastAPIResponse(
content=response.content,
status_code=response.status_code,
- headers={k: v for k, v in response.headers.items() if k.lower() not in ['content-encoding', 'transfer-encoding', 'content-length']},
+ headers={k: v for k, v in response.headers.items() if k.lower() not in ["content-encoding", "transfer-encoding", "content-length"]},
media_type=response.headers.get("content-type"),
)
except httpx.RequestError:
- # If Next.js is not running, return a helpful error
raise HTTPException(
status_code=503,
- detail="Frontend server is not available. Please ensure Next.js is running on port 3000."
+ detail="Frontend server is not available. Please ensure Next.js is running on port 3000.",
)
-# Run with: uvicorn main:app --reload
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
-