Spaces:
Sleeping
Sleeping
import json

import requests
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Application instance; routes are registered on it below.
app = FastAPI()

# Static assets are served from the local "static/" directory under /static.
# NOTE(review): Starlette's StaticFiles verifies the directory exists by
# default, so this mount is not optional — confirm "static/" ships with the app.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Jinja2 templates are loaded from the local "templates/" directory.
templates = Jinja2Templates(directory="templates")
# Upstream Gradio "infer" endpoint of the FLUX.1-dev Space; a POST here returns
# an event id, and GET {BASE_URL}/{event_id} streams the result.
BASE_URL = (
    "https://black-forest-labs-flux-1-dev.hf.space"
    "/gradio_api/call/infer"
)
HEADERS = {"Content-Type": "application/json"}

# Positional arguments for the upstream call; slot 0 holds the prompt and is
# filled in per request.
# NOTE(review): the remaining slots presumably map to seed, randomize_seed,
# width, height, guidance_scale and num_inference_steps — confirm against the
# Space's API page.
DEFAULT_PAYLOAD = {
    "data": [
        "",  # prompt
        0,
        True,
        1024,
        1024,
        4,
        60,
    ],
}
def get_event_id(prompt: str, timeout: float = 30.0) -> str:
    """Submit *prompt* to the upstream ``infer`` endpoint and return its event id.

    Posts ``DEFAULT_PAYLOAD`` with the first ``data`` slot replaced by
    *prompt*. The returned id is what ``stream_until_complete`` follows.

    Args:
        prompt: Text prompt placed at index 0 of the ``data`` list.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled upstream cannot hang the caller indefinitely.

    Returns:
        The ``event_id`` string from the JSON response.

    Raises:
        requests.HTTPError: if the POST returns an error status.
        requests.Timeout: if the upstream does not respond within *timeout*.
        RuntimeError: if the response JSON carries no ``event_id``.
    """
    # DEFAULT_PAYLOAD.copy() would be shallow: assigning into its "data" list
    # would overwrite the shared module-level default (and, if calls overlap
    # in threads, leak one request's prompt into another). Build a fresh list.
    payload = {**DEFAULT_PAYLOAD, "data": [prompt, *DEFAULT_PAYLOAD["data"][1:]]}
    resp = requests.post(BASE_URL, headers=HEADERS, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    # Guard against a non-object JSON body so the error below stays descriptive.
    event_id = data.get("event_id") if isinstance(data, dict) else None
    if not event_id:
        raise RuntimeError(f"No event_id returned: {data}")
    return event_id
def _parse_sse_message(message: str) -> tuple:
    """Split one server-sent-event message into an ``(event, data)`` pair.

    ``event`` is the last ``event:`` value seen, or None. ``data`` joins every
    ``data:`` line with newlines, as the SSE format specifies, or is None when
    the message has no data lines. Other fields and ``:`` comments are ignored.
    """
    event = None
    data_lines = []
    for line in message.splitlines():
        field, _, value = line.partition(":")
        if field == "event":
            event = value.strip()
        elif field == "data":
            data_lines.append(value.strip())
    data = "\n".join(data_lines) if data_lines else None
    return event, data


def _extract_url(data: str) -> str:
    """Return the ``url`` of the first output in a ``complete`` event payload.

    Raises:
        ValueError: if *data* is not valid JSON.
        RuntimeError: if the payload does not contain a first output with a URL.
    """
    parsed = json.loads(data)
    first = parsed[0] if isinstance(parsed, list) and parsed else None
    url = first.get("url") if isinstance(first, dict) else None
    if not url:
        raise RuntimeError(f"Complete event carried no image URL: {data}")
    return url


def stream_until_complete(event_id: str, timeout: float = 300.0) -> str:
    """Follow the SSE stream for *event_id* and return the generated image URL.

    Reads ``GET {BASE_URL}/{event_id}`` as a server-sent-event stream and
    parses blank-line-delimited messages until a ``complete`` event arrives.
    It then returns ``url`` from the first element of that event's JSON data.

    Args:
        event_id: Identifier returned by ``get_event_id``.
        timeout: Seconds passed to ``requests.get`` as its timeout. For a
            streamed response this bounds the wait between received bytes,
            not the total duration of the stream.

    Returns:
        The ``url`` field of the first output in the ``complete`` payload.

    Raises:
        requests.HTTPError: if the GET returns an error status.
        requests.Timeout: if the server goes silent for longer than *timeout*.
        RuntimeError: if the stream sends an ``error`` event, the ``complete``
            payload has no URL, or the stream ends without completing.
    """
    url = f"{BASE_URL}/{event_id}"
    with requests.get(url, headers=HEADERS, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        # SSE is UTF-8 by definition. Without this, requests assumes ISO-8859-1
        # for text/* responses lacking a charset. If no encoding is known at
        # all, iter_content(decode_unicode=True) yields bytes, which would break
        # the string concatenation below.
        resp.encoding = "utf-8"
        buffer = ""
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            # Normalise CRLF so messages always split on "\n\n". Applying it to
            # the whole buffer also handles a "\r\n" split across two chunks.
            buffer = (buffer + chunk).replace("\r\n", "\n")
            while "\n\n" in buffer:
                message, buffer = buffer.split("\n\n", 1)
                event, data = _parse_sse_message(message)
                if event == "error":
                    # Fail fast instead of waiting for the stream to close.
                    raise RuntimeError(f"Upstream reported an error: {data}")
                if event == "complete" and data:
                    return _extract_url(data)
    raise RuntimeError("Stream ended without complete event")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the prompt form from ``templates/index.html``.

    The route decorator registers this handler on ``/``; without it FastAPI
    never serves the page. The response headers forbid caching, so browsers
    always fetch a fresh form.
    """
    return templates.TemplateResponse(
        "index.html",
        {"request": request},
        headers={"Cache-Control": "no-cache, no-store, must-revalidate"},
    )
# NOTE(review): the "/generate" path is assumed; confirm it matches the form's
# ``action`` attribute in templates/index.html.
@app.post("/generate", response_class=HTMLResponse)
async def generate(request: Request, prompt: str = Form(...)):
    """Generate an image for the submitted *prompt* and re-render the form.

    On success the template receives ``image_url``. On any failure it receives
    ``error`` (the exception text) instead. ``prompt`` is echoed back in both
    cases so the form keeps the user's input.
    """
    try:
        # Both helpers make blocking ``requests`` calls that may run for a long
        # time. Running them in a worker thread keeps the event loop free to
        # serve other requests meanwhile.
        event_id = await run_in_threadpool(get_event_id, prompt)
        image_url = await run_in_threadpool(stream_until_complete, event_id)
    except Exception as e:  # top-level boundary: surface any failure on the page
        return templates.TemplateResponse(
            "index.html", {"request": request, "error": str(e), "prompt": prompt}
        )
    return templates.TemplateResponse(
        "index.html", {"request": request, "image_url": image_url, "prompt": prompt}
    )