"""Minimal FastAPI front end for the FLUX.1-dev Gradio Space.

The home page shows a prompt form. Generation goes through the Space's
``/gradio_api/call/infer`` endpoint: one POST queues the job, then the job's
server-sent-events stream is followed until the result arrives.
"""

import codecs
import json

import requests
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# FastAPI setup
app = FastAPI()

# Serve static files (if needed).
# NOTE: StaticFiles verifies that "static/" exists when it is constructed, so
# the directory must be present (even if empty) for the app to import.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Templates directory
templates = Jinja2Templates(directory="templates")

# External API config
BASE_URL = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"
HEADERS = {"Content-Type": "application/json"}
# Positional inputs for the Space's ``infer`` endpoint. Slot 0 is the prompt
# and is filled in per request. NOTE(review): the remaining slots presumably
# are seed, randomize-seed flag, width, height, guidance scale and step
# count; confirm against the Space's API page.
DEFAULT_PAYLOAD = {
    "data": ["", 0, True, 1024, 1024, 4, 60]
}
# (connect, read) timeouts in seconds for every call to the Space. The read
# timeout bounds the silence between received bytes, not the total time a
# generation may take.
REQUEST_TIMEOUT = (10, 300)


def get_event_id(prompt: str) -> str:
    """Queue a generation job for *prompt* and return its Gradio event id.

    Raises:
        requests.RequestException: on an HTTP error status, a connection
            failure or a timeout.
        RuntimeError: if the response JSON carries no ``event_id``.
    """
    # Build a fresh "data" list. dict.copy() is shallow, so assigning into
    # payload["data"] would mutate the shared DEFAULT_PAYLOAD.
    payload = {**DEFAULT_PAYLOAD, "data": [prompt, *DEFAULT_PAYLOAD["data"][1:]]}
    resp = requests.post(
        BASE_URL, headers=HEADERS, json=payload, timeout=REQUEST_TIMEOUT
    )
    resp.raise_for_status()
    data = resp.json()
    event_id = data.get("event_id")
    if not event_id:
        raise RuntimeError(f"No event_id returned: {data}")
    return event_id


def _iter_sse_events(resp):
    """Yield ``(event, data)`` pairs from a streaming server-sent-events response.

    Multiple ``data:`` lines in one message are joined with newlines. Either
    element is ``None`` when that field is absent, for example for comment-only
    heartbeat messages.
    """
    decoder = codecs.getincrementaldecoder("utf-8")()
    buffer = ""
    for chunk in resp.iter_content(chunk_size=None):
        # Decode here rather than via decode_unicode. SSE is always UTF-8, but
        # requests would assume ISO-8859-1 for a text/* response without a
        # charset. The incremental decoder also copes with multi-byte
        # characters split across chunks.
        buffer += decoder.decode(chunk)
        # Normalise CRLF so the blank-line message separator is always "\n\n".
        buffer = buffer.replace("\r\n", "\n")
        while "\n\n" in buffer:
            message, buffer = buffer.split("\n\n", 1)
            event = None
            data_lines = []
            for line in message.splitlines():
                if line.startswith("event:"):
                    event = line[len("event:"):].strip()
                elif line.startswith("data:"):
                    data_lines.append(line[len("data:"):].strip())
            yield event, ("\n".join(data_lines) if data_lines else None)


def stream_until_complete(event_id: str) -> str:
    """Follow the job's event stream and return the generated image URL.

    Raises:
        requests.RequestException: on an HTTP error status, a connection
            failure or a timeout.
        RuntimeError: if the Space sends an ``error`` event, the ``complete``
            payload has no URL, or the stream ends without ``complete``.
    """
    url = f"{BASE_URL}/{event_id}"
    with requests.get(
        url, headers=HEADERS, stream=True, timeout=REQUEST_TIMEOUT
    ) as resp:
        resp.raise_for_status()
        for event, data in _iter_sse_events(resp):
            if event == "error":
                # Fail fast instead of waiting for the server to close the stream.
                raise RuntimeError(f"Space reported an error: {data}")
            if event == "complete" and data:
                parsed = json.loads(data)
                # First output is the image file descriptor; it holds a "url".
                file_info = parsed[0]
                image_url = file_info.get("url")
                if not image_url:
                    raise RuntimeError(f"Complete event without an image URL: {data}")
                return image_url
    raise RuntimeError("Stream ended without complete event")


# Home page: form for prompt
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the prompt form with caching disabled."""
    response = templates.TemplateResponse("index.html", {"request": request})
    # no-cache headers: stop browsers from showing a stale copy of the form.
    response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
    return response
# Generate endpoint: processes form and renders result
@app.post("/generate", response_class=HTMLResponse)
def generate(request: Request, prompt: str = Form(...)):
    """Generate an image for *prompt* and re-render ``index.html``.

    The handler is a plain ``def`` (not ``async def``) so that FastAPI runs it
    in its worker threadpool. ``get_event_id`` and ``stream_until_complete``
    make blocking network calls that would otherwise stall the event loop for
    the whole generation.

    On success the template context gets ``image_url``. On failure it gets
    ``error`` holding the exception message. ``prompt`` is passed back to the
    template in both cases.
    """
    import logging

    context = {"request": request, "prompt": prompt}
    if not prompt.strip():
        # Don't spend a round trip to the Space on an empty prompt.
        context["error"] = "Please enter a prompt."
        return templates.TemplateResponse("index.html", context)
    try:
        event_id = get_event_id(prompt)
        image_url = stream_until_complete(event_id)
    except Exception as e:
        # Top-level boundary: keep the full traceback in the server log and
        # show the user the message instead of a 500 page.
        logging.getLogger(__name__).exception("Image generation failed")
        context["error"] = str(e)
    else:
        context["image_url"] = image_url
    return templates.TemplateResponse("index.html", context)