# ace-step15-endpoint/services/pipeline_api.py
# Commit 8bdd018 (Andrew): Consolidate AF3/Qwen pipelines, endpoint templates, and setup docs
"""
Local orchestration API for AF3 captioning + ChatGPT cleanup pipeline.
"""
from __future__ import annotations
import tempfile
from pathlib import Path
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from af3_chatgpt_pipeline import (
DEFAULT_AF3_MODEL_ID,
DEFAULT_AF3_PROMPT,
DEFAULT_OPENAI_MODEL,
AF3EndpointClient,
AF3LocalClient,
run_af3_chatgpt_pipeline,
save_sidecar,
)
from utils.env_config import get_env, load_project_env
# Load the project's .env at import time so env-resolved defaults (tokens,
# endpoint URLs, model ids) are available to every handler below.
load_project_env()
def _resolve_token(name_upper: str, name_lower: str) -> str:
    """Look up a config value via its upper- or lower-case env var name."""
    value = get_env(name_upper, name_lower)
    return value
def _build_af3_client(
    backend: str,
    endpoint_url: str,
    hf_token: str,
    model_id: str,
    device: str,
    torch_dtype: str,
):
    """Build the AF3 client for the requested backend.

    ``backend == "hf_endpoint"`` yields an ``AF3EndpointClient`` (and requires a
    non-empty ``endpoint_url``); any other value yields a local in-process
    ``AF3LocalClient``. An empty ``model_id`` falls back to the default model.
    """
    resolved_model = model_id or DEFAULT_AF3_MODEL_ID
    if backend != "hf_endpoint":
        # Local inference path: load the model in-process.
        return AF3LocalClient(
            model_id=resolved_model,
            device=device,
            torch_dtype=torch_dtype,
        )
    if not endpoint_url:
        raise HTTPException(status_code=400, detail="AF3 endpoint backend requires endpoint_url")
    return AF3EndpointClient(
        endpoint_url=endpoint_url,
        token=hf_token,
        model_id=resolved_model,
    )
# FastAPI application plus permissive CORS so the local React UI (any origin)
# can call the API during development.
app = FastAPI(title="AF3 + ChatGPT Pipeline API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide-open CORS — fine locally, tighten for deployment
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the prebuilt React bundle (react-ui/dist) when it exists; the /assets
# mount is only added if the build directory is present so the API still starts
# without a frontend build.
FRONTEND_DIST = Path(__file__).resolve().parents[1] / "react-ui" / "dist"
FRONTEND_ASSETS = FRONTEND_DIST / "assets"
if FRONTEND_ASSETS.exists():
    app.mount("/assets", StaticFiles(directory=str(FRONTEND_ASSETS)), name="assets")
class PipelinePathRequest(BaseModel):
    """Request body for POST /api/pipeline/run-path.

    Runs the AF3 captioning + ChatGPT cleanup pipeline on an audio file that
    already exists on the server's filesystem. Empty-string credential fields
    fall back to environment variables in the handler.
    """
    audio_path: str  # server-side path to the audio file (must exist)
    backend: str = "hf_endpoint"  # "hf_endpoint" for remote inference, anything else for local
    endpoint_url: str = ""  # HF inference endpoint URL; env fallback when empty
    hf_token: str = ""  # HF API token; env fallback when empty
    model_id: str = DEFAULT_AF3_MODEL_ID
    device: str = "auto"  # only used by the local backend
    torch_dtype: str = "auto"  # only used by the local backend
    af3_prompt: str = DEFAULT_AF3_PROMPT
    af3_max_new_tokens: int = 1400
    af3_temperature: float = 0.1
    openai_api_key: str = ""  # env fallback when empty; handler 400s if still missing
    openai_model: str = DEFAULT_OPENAI_MODEL
    user_context: str = ""  # free-form extra context forwarded to the cleanup step
    artist_name: str = ""
    track_name: str = ""
    enable_web_search: bool = False
    output_json: str = ""  # sidecar output path; defaults to audio path with .json suffix
@app.get("/api/health")
def health():
    """Liveness probe: report that the service is up."""
    return dict(ok=True)
@app.get("/", include_in_schema=False)
def serve_root():
    """Serve the built React frontend's index.html, or a hint when no build exists."""
    index = FRONTEND_DIST / "index.html"
    if FRONTEND_DIST.exists() and index.exists():
        return FileResponse(index)
    # No frontend build on disk — tell the caller how to produce one.
    return JSONResponse(
        {
            "ok": True,
            "message": "Frontend build not found. Run `python af3_gui_app.py` or `npm --prefix react-ui run build`.",
        }
    )
@app.get("/api/config")
def config():
    """Expose pipeline defaults, resolved from the environment where applicable."""
    defaults = {
        "backend": "hf_endpoint",
        "endpoint_url": _resolve_token("HF_AF3_ENDPOINT_URL", "hf_af3_endpoint_url"),
        "model_id": _resolve_token("AF3_MODEL_ID", "af3_model_id") or DEFAULT_AF3_MODEL_ID,
        "openai_model": _resolve_token("OPENAI_MODEL", "openai_model") or DEFAULT_OPENAI_MODEL,
        "af3_prompt": DEFAULT_AF3_PROMPT,
    }
    return {"defaults": defaults}
@app.post("/api/pipeline/run-path")
def run_pipeline_path(req: PipelinePathRequest):
    """Run the AF3 + ChatGPT pipeline on an audio file already on the server.

    Resolves credentials from the request or the environment, builds the AF3
    client, runs the pipeline, and saves the resulting sidecar JSON next to the
    audio file (or to ``req.output_json`` when provided).

    Raises:
        HTTPException: 404 if the audio file is missing, 400 if no OpenAI key
            can be resolved (or the endpoint backend lacks a URL), 500 for any
            pipeline or save failure.
    """
    audio_path = Path(req.audio_path)
    if not audio_path.is_file():
        raise HTTPException(status_code=404, detail=f"Audio not found: {audio_path}")
    hf_token = req.hf_token or _resolve_token("HF_TOKEN", "hf_token")
    openai_key = req.openai_api_key or _resolve_token("OPENAI_API_KEY", "openai_api_key")
    if not openai_key:
        raise HTTPException(status_code=400, detail="OPENAI_API_KEY is required.")
    endpoint_url = req.endpoint_url or _resolve_token("HF_AF3_ENDPOINT_URL", "hf_af3_endpoint_url")
    af3_client = _build_af3_client(
        backend=req.backend,
        endpoint_url=endpoint_url,
        hf_token=hf_token,
        model_id=req.model_id,
        device=req.device,
        torch_dtype=req.torch_dtype,
    )
    # Keep result access and sidecar save inside the guarded region so save
    # failures also surface as a clean 500 (matches run_pipeline_upload).
    try:
        result = run_af3_chatgpt_pipeline(
            audio_path=str(audio_path),
            af3_client=af3_client,
            af3_prompt=req.af3_prompt,
            af3_max_new_tokens=req.af3_max_new_tokens,
            af3_temperature=req.af3_temperature,
            openai_api_key=openai_key,
            openai_model=req.openai_model,
            user_context=req.user_context,
            artist_name=req.artist_name,
            track_name=req.track_name,
            enable_web_search=req.enable_web_search,
        )
        output_json = req.output_json or str(audio_path.with_suffix(".json"))
        save_path = save_sidecar(result["sidecar"], output_json)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {
        "saved_to": save_path,
        "af3_analysis": result["af3_analysis"],
        "cleaned": result["cleaned"],
        "sidecar": result["sidecar"],
    }
@app.post("/api/pipeline/run-upload")
async def run_pipeline_upload(
    audio_file: UploadFile = File(...),
    backend: str = Form("hf_endpoint"),
    endpoint_url: str = Form(""),
    hf_token: str = Form(""),
    model_id: str = Form(DEFAULT_AF3_MODEL_ID),
    device: str = Form("auto"),
    torch_dtype: str = Form("auto"),
    af3_prompt: str = Form(DEFAULT_AF3_PROMPT),
    af3_max_new_tokens: int = Form(1400),
    af3_temperature: float = Form(0.1),
    openai_api_key: str = Form(""),
    openai_model: str = Form(DEFAULT_OPENAI_MODEL),
    user_context: str = Form(""),
    artist_name: str = Form(""),
    track_name: str = Form(""),
    enable_web_search: bool = Form(False),
    output_json: str = Form(""),
):
    """Run the AF3 + ChatGPT pipeline on an uploaded audio file.

    The upload is spooled to a temp file (deleted afterwards), the pipeline
    runs against it, and the sidecar JSON is saved to ``output_json`` or to
    ``outputs/af3_chatgpt/<stem>.json`` by default.

    Raises:
        HTTPException: 400 if no OpenAI key can be resolved (or the endpoint
            backend lacks a URL), 500 for any pipeline or save failure.
    """
    # Validate credentials before touching the disk so bad requests fail fast
    # without spooling the upload.
    hf_token_val = hf_token or _resolve_token("HF_TOKEN", "hf_token")
    openai_key = openai_api_key or _resolve_token("OPENAI_API_KEY", "openai_api_key")
    if not openai_key:
        raise HTTPException(status_code=400, detail="OPENAI_API_KEY is required.")
    endpoint_url_val = endpoint_url or _resolve_token("HF_AF3_ENDPOINT_URL", "hf_af3_endpoint_url")
    # Preserve the original extension so downstream audio loaders can sniff it.
    suffix = Path(audio_file.filename or "uploaded.wav").suffix or ".wav"
    content = await audio_file.read()
    # delete=False: the pipeline reads the file by path after the handle closes.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        temp_audio = Path(tmp.name)
    try:
        af3_client = _build_af3_client(
            backend=backend,
            endpoint_url=endpoint_url_val,
            hf_token=hf_token_val,
            model_id=model_id,
            device=device,
            torch_dtype=torch_dtype,
        )
        result = run_af3_chatgpt_pipeline(
            audio_path=str(temp_audio),
            af3_client=af3_client,
            af3_prompt=af3_prompt,
            af3_max_new_tokens=af3_max_new_tokens,
            af3_temperature=af3_temperature,
            openai_api_key=openai_key,
            openai_model=openai_model,
            user_context=user_context,
            artist_name=artist_name,
            track_name=track_name,
            enable_web_search=enable_web_search,
        )
        default_out = Path("outputs") / "af3_chatgpt" / (Path(audio_file.filename or "track").stem + ".json")
        save_path = save_sidecar(result["sidecar"], output_json or str(default_out))
        return {
            "saved_to": save_path,
            "af3_analysis": result["af3_analysis"],
            "cleaned": result["cleaned"],
            "sidecar": result["sidecar"],
        }
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        # Best-effort cleanup of the temp copy; never mask the real response/error.
        try:
            temp_audio.unlink(missing_ok=True)
        except Exception:
            pass