# acd23's picture
# Upload src/api/main.py with huggingface_hub
# 4cd3de5 verified
"""
FastAPI Application - AI Reel Creator Platform API
RESTful endpoints for reel generation pipeline.
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional
from uuid import UUID

from fastapi import FastAPI, HTTPException, BackgroundTasks, File, UploadFile, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from ..models.schemas import (
    ReelQueryRequest,
    ReelQueryResponse,
    ReelRequest,
    RequestStatus,
    DurationTarget,
    Platform,
    Tone,
    AspectRatio,
    ReelScript,
    ReelManifest,
    AssetSwapRequest,
    RegenerateRequest,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# ============================================================
# APP LIFECYCLE
# ============================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log one message when the API starts and another when it shuts down.

    The statement before ``yield`` runs on startup and the statement after it
    runs on shutdown. The ``app`` argument is accepted but not used.
    """
    startup_note = "AI Reel Creator Platform API starting..."
    shutdown_note = "AI Reel Creator Platform API shutting down..."

    logger.info(startup_note)
    yield
    logger.info(shutdown_note)
# The ASGI application object; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(
    title="AI Reel Creator Platform",
    description="End-to-end reel generation from raw assets to editable video output",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. When it is unset or empty, every origin is allowed ("*").
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_all = "*" in _cors_origins

# Credentialed requests (cookies, Authorization headers) are enabled only for
# an explicit origin list. Combining the "*" wildcard with credentials would
# let any website issue authenticated cross-origin requests to this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if _cors_allow_all else _cors_origins,
    allow_credentials=not _cors_allow_all,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================================================
# REQUEST/RESPONSE MODELS
# ============================================================
class CreateReelRequest(BaseModel):
    """Payload describing the reel a client wants generated."""

    # Required free-text brief, constrained to 5-2000 characters.
    user_query: str = Field(
        ...,
        min_length=5,
        max_length=2000,
        description="Description of the desired reel",
    )

    # The three fields below are plain strings here; the create endpoint
    # converts them to their enum types.
    duration_target: str = Field(
        "20s",
        description="Target duration: 10s, 20s, 30s, 60s, or custom",
    )
    platform: str = Field(
        "instagram_reels",
        description="Target platform: instagram_reels, tiktok, youtube_shorts, etc.",
    )
    tone: str = Field(
        "sporty",
        description="Reel tone: sporty, elegant, technical, luxury, adventure, minimal, dynamic, serene",
    )

    # Optional settings; each defaults to None when omitted.
    aspect_ratio: Optional[str] = Field(
        None,
        description="Aspect ratio: 9:16, 16:9, 1:1, 4:5",
    )
    brand_config_id: Optional[str] = Field(
        None,
        description="UUID of brand config to use",
    )
    additional_constraints: Optional[dict] = Field(
        None,
        description="Extra constraints: must_include, avoid_parts, etc.",
    )
class ReelStatusResponse(BaseModel):
    """Progress snapshot of a reel generation request."""

    request_id: str
    status: str

    # Everything below is optional and defaults to None.
    progress_percent: Optional[int] = Field(default=None)
    message: Optional[str] = Field(default=None)
    reel_script: Optional[dict] = Field(default=None)
    reel_manifest: Optional[dict] = Field(default=None)
    render_url: Optional[str] = Field(default=None)
class ReelManifestResponse(BaseModel):
    """A request's reel manifest together with preview availability details."""

    request_id: str
    # The manifest serialized as a plain dictionary.
    manifest: dict

    # No preview is advertised unless these are set explicitly.
    preview_available: bool = Field(default=False)
    preview_url: Optional[str] = Field(default=None)
class SwapAssetRequest(BaseModel):
    """Instruction to replace the asset used by one beat of a reel."""

    request_id: str
    # Required; values below 1 are rejected.
    beat_number: int = Field(
        ...,
        description="Beat number to modify",
        ge=1,
    )
    # Required.
    new_asset_id: str = Field(
        ...,
        description="UUID of the replacement asset",
    )
    # Optional; defaults to None.
    new_video_event_id: Optional[str] = Field(
        None,
        description="UUID of replacement video event (for video assets)",
    )
class SwapAssetResponse(BaseModel):
    """Confirmation returned once a beat's asset has been swapped."""

    request_id: str
    beat_number: int
    new_asset_id: str
    new_file_path: str
    # Reported as True unless the caller sets it otherwise.
    manifest_updated: bool = Field(default=True)
class OrchestratorStatus(BaseModel):
    """Service health summary with request counters."""

    status: str
    version: str = Field(default="1.0.0")

    # Request counters; each starts at zero.
    active_requests: int = Field(default=0)
    completed_requests: int = Field(default=0)
    failed_requests: int = Field(default=0)
# ============================================================
# IN-MEMORY STATE (replace with DB in production)
# ============================================================
# Module-level dictionaries, indexed by the request id as a string.
_requests: dict[str, ReelRequest] = dict()
_scripts: dict[str, ReelScript] = dict()
_manifests: dict[str, ReelManifest] = dict()

# Running counters surfaced by the /health endpoint.
_stats = dict(active=0, completed=0, failed=0)
# ============================================================
# ENDPOINTS
# ============================================================
@app.get("/health", response_model=OrchestratorStatus)
async def health_check():
    """Report a healthy status together with the current request counters."""
    # Map each counter in _stats onto the matching "<name>_requests" field.
    counters = {
        f"{name}_requests": _stats[name]
        for name in ("active", "completed", "failed")
    }
    return OrchestratorStatus(status="healthy", **counters)
@app.post("/api/v1/reels/create", response_model=ReelQueryResponse)
async def create_reel(request: CreateReelRequest, background_tasks: BackgroundTasks):
    """
    Create a new reel generation request.

    Converts the string fields of the payload into their enum types, builds a
    ReelRequest through QueryInterface, stores it in the in-memory registry and
    returns its id so the client can poll for progress.

    Args:
        request: Validated creation payload.
        background_tasks: Reserved for dispatching the pipeline; not used yet.

    Returns:
        ReelQueryResponse carrying the new request id and PLANNING status.

    Raises:
        HTTPException: 400 for any ValueError (e.g. an unknown enum value or a
            malformed brand config UUID); 500 for any other failure.
    """
    try:
        from ..pipelines.query_interface import QueryInterface

        # Convert to ReelQueryRequest
        query = ReelQueryRequest(
            user_query=request.user_query,
            duration_target=DurationTarget(request.duration_target),
            platform=Platform(request.platform),
            tone=Tone(request.tone),
            aspect_ratio=AspectRatio(request.aspect_ratio) if request.aspect_ratio else None,
            brand_config_id=request.brand_config_id,
            additional_constraints=request.additional_constraints,
        )
        interface = QueryInterface()
        brand_uuid = UUID(request.brand_config_id) if request.brand_config_id else None
        reel_request = interface.create_reel_request(query, brand_config_id=brand_uuid)

        # Store for tracking
        _requests[str(reel_request.id)] = reel_request
        _stats["active"] += 1

        # In production: dispatch to orchestrator pipeline
        # background_tasks.add_task(orchestrator_pipeline, reel_request)
        return ReelQueryResponse(
            request_id=str(reel_request.id),
            status=RequestStatus.PLANNING,
            message="Reel request accepted and planning has begun",
            estimated_completion_seconds=30,
        )
    except ValueError as e:
        # Bad client input; keep the original error as the cause for debugging.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Log the full traceback, but do not leak details to the client.
        logger.exception("Error creating reel request")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@app.get("/api/v1/reels/{request_id}/status", response_model=ReelStatusResponse)
async def get_reel_status(request_id: str):
    """
    Report where a reel generation request currently stands.

    Clients poll this endpoint while the request moves through:
    PENDING -> PLANNING -> GENERATING -> REVIEW -> COMPLETED
    Responds with 404 when the request id is unknown.
    """
    tracked = _requests.get(request_id)
    if not tracked:
        raise HTTPException(status_code=404, detail="Request not found")

    # Coarse completion percentage per status; anything unlisted counts as 0.
    percent_by_status = {
        RequestStatus.PENDING: 5,
        RequestStatus.PLANNING: 25,
        RequestStatus.GENERATING: 50,
        RequestStatus.REVIEW: 75,
        RequestStatus.COMPLETED: 100,
        RequestStatus.FAILED: 0,
    }
    percent = percent_by_status.get(tracked.status, 0)

    payload = ReelStatusResponse(
        request_id=request_id,
        status=tracked.status.value,
        progress_percent=percent,
        message=f"Current status: {tracked.status.value}",
    )

    # Attach the script and the manifest when they exist for this request.
    for store, field_name in ((_scripts, "reel_script"), (_manifests, "reel_manifest")):
        if request_id in store:
            setattr(payload, field_name, store[request_id].model_dump())
    return payload
@app.get("/api/v1/reels/{request_id}/script", response_model=ReelStatusResponse)
async def get_reel_script(request_id: str):
    """Return the stored reel script for a request, or 404 when none is stored."""
    if request_id in _scripts:
        script_payload = _scripts[request_id].model_dump()
        return ReelStatusResponse(
            request_id=request_id,
            status="script_ready",
            reel_script=script_payload,
        )
    raise HTTPException(status_code=404, detail="Script not yet generated or request not found")
@app.get("/api/v1/reels/{request_id}/manifest", response_model=ReelManifestResponse)
async def get_reel_manifest(request_id: str):
    """Return the stored reel manifest for a request, or 404 when none is stored."""
    if request_id in _manifests:
        response_fields = {
            "request_id": request_id,
            "manifest": _manifests[request_id].model_dump(),
            # Previews are never advertised by this handler.
            "preview_available": False,
        }
        return ReelManifestResponse(**response_fields)
    raise HTTPException(status_code=404, detail="Manifest not yet generated or request not found")
@app.post("/api/v1/reels/{request_id}/swap-asset", response_model=SwapAssetResponse)
async def swap_asset(request_id: str, swap: SwapAssetRequest):
    """
    Swap an asset for a specific beat in the reel manifest.

    Allows user to override the LLM-selected asset with one of the
    top-k candidates or any available asset.

    Args:
        request_id: Id of the request whose manifest is modified (path value).
        swap: Beat number and replacement asset / video event identifiers.

    Returns:
        SwapAssetResponse describing the updated beat.

    Raises:
        HTTPException: 404 when no manifest is stored for ``request_id``;
            400 when an identifier is not a valid UUID or the beat number is
            not present in the manifest.
    """
    if request_id not in _manifests:
        raise HTTPException(status_code=404, detail="Manifest not found")

    # Parse both identifiers before touching the manifest, so that a malformed
    # value yields a 400 and can never leave the beat half-updated.
    try:
        replacement_asset = UUID(swap.new_asset_id)
        replacement_event = (
            UUID(swap.new_video_event_id) if swap.new_video_event_id else None
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid UUID: {e}") from e

    manifest = _manifests[request_id]

    # Find the beat to swap (first one with a matching number).
    beat = next(
        (candidate for candidate in manifest.beats
         if candidate.beat_number == swap.beat_number),
        None,
    )
    if beat is None:
        raise HTTPException(status_code=400, detail=f"Beat {swap.beat_number} not found in manifest")

    beat.asset_id = replacement_asset
    if replacement_event is not None:
        beat.video_event_id = replacement_event
    new_file_path = beat.file_path  # Would update from asset DB

    return SwapAssetResponse(
        request_id=request_id,
        beat_number=swap.beat_number,
        new_asset_id=swap.new_asset_id,
        new_file_path=new_file_path,
        manifest_updated=True,
    )
@app.post("/api/v1/reels/{request_id}/regenerate", response_model=ReelQueryResponse)
async def regenerate_reel(request_id: str, request: CreateReelRequest):
    """
    Regenerate a reel with modified parameters.

    Builds a new request from the supplied payload, falling back to the
    original request's value for any field that is empty, and reuses the
    original request's brand config. The new request is stored under its own
    id; the original request is left untouched.

    Args:
        request_id: Id of the original request (path value).
        request: Parameters for the regenerated reel.

    Returns:
        ReelQueryResponse carrying the new request id and PLANNING status.

    Raises:
        HTTPException: 404 when the original request is unknown; 400 for any
            ValueError while building the new request (e.g. an unknown enum
            value), matching the create endpoint.
    """
    old_request = _requests.get(request_id)
    if not old_request:
        raise HTTPException(status_code=404, detail="Original request not found")

    # Create new request with updated parameters
    from ..pipelines.query_interface import QueryInterface

    try:
        new_query = ReelQueryRequest(
            user_query=request.user_query or old_request.user_query,
            duration_target=DurationTarget(request.duration_target) if request.duration_target else old_request.duration_target,
            platform=Platform(request.platform) if request.platform else old_request.platform,
            tone=Tone(request.tone) if request.tone else old_request.tone,
            aspect_ratio=AspectRatio(request.aspect_ratio) if request.aspect_ratio else old_request.aspect_ratio,
            # Forward caller-supplied constraints, as the create endpoint does.
            additional_constraints=request.additional_constraints,
        )
        interface = QueryInterface()
        new_reel_request = interface.create_reel_request(
            new_query,
            brand_config_id=old_request.brand_config_id,
        )
    except ValueError as e:
        # Bad client input must not surface as an unhandled 500.
        raise HTTPException(status_code=400, detail=str(e)) from e

    _requests[str(new_reel_request.id)] = new_reel_request
    _stats["active"] += 1

    return ReelQueryResponse(
        request_id=str(new_reel_request.id),
        status=RequestStatus.PLANNING,
        message="Regeneration request accepted",
        estimated_completion_seconds=30,
    )
@app.post("/api/v1/reels/{request_id}/render")
async def start_render(request_id: str, background_tasks: BackgroundTasks):
    """
    Acknowledge a render request for a reel whose manifest exists.

    Responds with 404 when no manifest is stored for the request. No render
    job is actually dispatched yet; ``background_tasks`` is not used.
    """
    manifest_missing = request_id not in _manifests
    if manifest_missing:
        raise HTTPException(status_code=404, detail="Manifest not found - generate manifest first")

    # In production: background_tasks.add_task(render_pipeline, request_id)
    acknowledgement = dict(
        request_id=request_id,
        render_status="queued",
        message="Render job queued for processing",
        estimated_time_seconds=120,
    )
    return acknowledgement
@app.get("/api/v1/reels/{request_id}/render-status")
async def get_render_status(request_id: str):
    """Return the render job status for a request (currently a fixed placeholder)."""
    # render_status is one of: pending, rendering, completed, failed
    job_state = dict(
        request_id=request_id,
        render_status="pending",
        progress_percent=0,
        output_url=None,
    )
    return job_state
@app.get("/api/v1/reels/{request_id}/export/remotion")
async def export_remotion(request_id: str):
    """
    Describe the Remotion project export for a reel (Phase 5).

    Returns a JSON descriptor with the download URL of the project ZIP and the
    composition file path. Responds with 404 when no manifest is stored for
    the request.
    """
    if request_id in _manifests:
        # In production: generate Remotion project and return download URL
        archive_url = f"/downloads/{request_id}/remotion_project.zip"
        return dict(
            request_id=request_id,
            export_status="ready",
            download_url=archive_url,
            composition_file="src/Composition.tsx",
        )
    raise HTTPException(status_code=404, detail="Manifest not found")
# ============================================================
# ASSET MANAGEMENT ENDPOINTS
# ============================================================
@app.post("/api/v1/assets/upload")
async def upload_asset(
    file: UploadFile = File(...),
    asset_type: str = "video",
    source: Optional[str] = None,
):
    """
    Accept a raw asset upload and return a placeholder receipt.

    The file content is not stored yet, and ``source`` is not used.
    """
    # In production: save to S3/R2, extract metadata, store in DB
    placeholder_id = UUID(int=0)  # placeholder
    receipt = {"asset_id": str(placeholder_id)}
    receipt["filename"] = file.filename
    receipt["asset_type"] = asset_type
    receipt["status"] = "uploaded"
    receipt["metadata_extraction"] = "queued"
    return receipt
@app.get("/api/v1/assets/search")
async def search_assets(
    query: str,
    asset_type: Optional[str] = None,
    limit: int = 10,
    subject_part: Optional[str] = None,
    mood: Optional[str] = None,
):
    """
    Search assets by semantic similarity (intended to use CLIP embeddings).

    Currently a stub: it echoes the query back with an empty result list and
    ignores every filter parameter.
    """
    # In production: call embedding_store.search_assets(query_embedding, ...)
    matches = []
    return dict(query=query, results=matches, total=0)
# ============================================================
# BROCHURE MANAGEMENT ENDPOINTS
# ============================================================
@app.post("/api/v1/brochures/upload")
async def upload_brochure(
    file: UploadFile = File(...),
    brand_name: Optional[str] = None,
    background_tasks: BackgroundTasks = None,
):
    """
    Accept a brochure PDF upload and return a placeholder processing receipt.

    Intended pipeline: PDF -> nodes -> embeddings -> asset mapping. Nothing is
    queued yet; ``brand_name`` and ``background_tasks`` are not used.
    """
    # In production: queue brochure parsing pipeline
    placeholder_id = UUID(int=0)
    return dict(
        processing_id=str(placeholder_id),
        filename=file.filename,
        status="queued",
        estimated_nodes=None,
    )
@app.get("/api/v1/brochures/{brochure_id}/nodes")
async def get_brochure_nodes(brochure_id: str):
    """List the nodes of a parsed brochure (currently always an empty list)."""
    listing = {"brochure_id": brochure_id}
    listing["nodes"] = []
    listing["total_nodes"] = 0
    return listing
@app.get("/api/v1/brochures/{brochure_id}/mappings")
async def get_brochure_mappings(brochure_id: str):
    """Summarize brochure-to-asset mappings for review (currently always empty)."""
    summary = {"brochure_id": brochure_id, "mappings": [], "total_mappings": 0}
    # Review-state tallies, all zero for now.
    summary.update(dict.fromkeys(("approved", "pending", "rejected"), 0))
    return summary
# ============================================================
# RUN
# ============================================================
if __name__ == "__main__":
    import uvicorn

    # Serve the app on every network interface, port 8000.
    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind_options)