File size: 2,890 Bytes
88083ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import codecs
import json

import requests
from fastapi import FastAPI, Request, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# --- Web application wiring -------------------------------------------------
# The application object; the route handlers further down attach to it
# through their decorators.
app = FastAPI()

# Expose the contents of the relative "static" directory under the /static
# URL prefix, registered under the route name "static".
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Jinja2 templates (such as index.html) are looked up in the relative
# "templates" directory.
templates = Jinja2Templates(
    directory="templates",
)

# --- Remote Gradio API --------------------------------------------------------
# A POST to BASE_URL returns an event id; GET <BASE_URL>/<event_id> then
# streams the outcome of that call.
BASE_URL = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"
HEADERS = {
    "Content-Type": "application/json",
}
# Positional argument list sent to the remote `infer` function. Slot 0 is the
# prompt and is filled in per request.
# NOTE(review): the other slots presumably are seed, randomize-seed flag,
# width, height, guidance scale and step count -- confirm on the Space's API page.
DEFAULT_PAYLOAD = {"data": ["", 0, True, 1024, 1024, 4, 60]}

# Helper: request an event_id for given prompt
def get_event_id(prompt: str, timeout: float = 30.0) -> str:
    """Submit ``prompt`` to the Gradio ``infer`` endpoint and return the event id.

    The request body is ``DEFAULT_PAYLOAD`` with slot 0 of its ``data`` list
    replaced by ``prompt``. The module-level template itself is never modified.

    Args:
        prompt: Text prompt placed in slot 0 of the ``data`` list.
        timeout: Seconds handed to ``requests.post`` so a stalled connection
            raises instead of blocking the caller indefinitely.

    Returns:
        The ``event_id`` value from the JSON response.

    Raises:
        requests.RequestException: On network errors, timeouts, or an error
            HTTP status (via ``raise_for_status``).
        RuntimeError: If the response does not contain a truthy ``event_id``.
    """
    # Build a brand-new ``data`` list: ``dict.copy()`` is shallow, so writing
    # into ``payload["data"]`` would overwrite the shared DEFAULT_PAYLOAD list
    # and leak one request's prompt into the template for the next.
    payload = {**DEFAULT_PAYLOAD, "data": [prompt, *DEFAULT_PAYLOAD["data"][1:]]}
    resp = requests.post(BASE_URL, headers=HEADERS, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    # Guard against a non-object JSON body so callers get the documented error.
    event_id = data.get("event_id") if isinstance(data, dict) else None
    if not event_id:
        raise RuntimeError(f"No event_id returned: {data}")
    return event_id

# Helper: stream SSE until 'complete' event, then extract URL
def _parse_sse_message(message: str):
    """Split one SSE message into ``(event, data)``.

    ``event`` is the value of the last ``event:`` line, or ``None``. ``data``
    joins all ``data:`` lines with newlines, or is ``None`` when there are none.
    """
    event = None
    data_lines = []
    # Split on "\n" only: str.splitlines() also breaks on characters such as
    # U+2028 that may legitimately appear inside a JSON data payload.
    for line in message.split("\n"):
        if line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data_lines.append(line[len("data:"):].strip())
    return event, ("\n".join(data_lines) if data_lines else None)


def _url_from_complete(data: str) -> str:
    """Return the ``url`` of the first item in a ``complete`` event's JSON list.

    Raises RuntimeError if the payload is not a non-empty list whose first
    element is a dict with a truthy ``url``.
    """
    parsed = json.loads(data)
    first = parsed[0] if isinstance(parsed, list) and parsed else None
    url = first.get("url") if isinstance(first, dict) else None
    if not url:
        raise RuntimeError(f"complete event carried no image url: {data}")
    return url


def stream_until_complete(event_id: str, timeout: float = 300.0) -> str:
    """Follow the SSE stream for ``event_id`` and return the generated image URL.

    Args:
        event_id: Id obtained from the initial POST (see ``get_event_id``).
        timeout: Passed to ``requests.get``. With ``stream=True`` this bounds
            each wait for bytes from the server, not the total stream duration.

    Returns:
        The ``url`` field of the first item in the ``complete`` event's data.

    Raises:
        requests.RequestException: On network errors, timeouts, or an error
            HTTP status.
        RuntimeError: If the server sends an ``error`` event, the ``complete``
            payload has no URL, or the stream ends without a ``complete`` event.
    """
    url = f"{BASE_URL}/{event_id}"
    with requests.get(url, headers=HEADERS, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        # Decode the raw bytes as UTF-8 ourselves. requests' decode_unicode
        # depends on resp.encoding: it yields raw bytes when that is None, and
        # it uses a latin-1 guess for charset-less text types. The incremental
        # decoder also handles multi-byte characters split across chunks.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        buffer = ""
        for chunk in resp.iter_content(chunk_size=None):
            # Normalise CRLF on the whole buffer (a "\r" may end one chunk and
            # its "\n" start the next) so "\n\n" always marks a message end.
            buffer = (buffer + decoder.decode(chunk)).replace("\r\n", "\n")
            # Every piece but the last is a complete message; keep the tail.
            *messages, buffer = buffer.split("\n\n")
            for message in messages:
                event, data = _parse_sse_message(message)
                if event == "error":
                    raise RuntimeError(f"Generation failed: {data}")
                if event == "complete" and data:
                    return _url_from_complete(data)
    raise RuntimeError("Stream ended without complete event")

# Home page: form for prompt
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``index.html`` (the prompt form) with no result or error.

    The response carries cache-busting headers so browsers and proxies do not
    serve a stored copy of the page.
    """
    # Send the no-cache headers explicitly; rendering alone sets none.
    return templates.TemplateResponse(
        "index.html",
        {"request": request},
        headers={"Cache-Control": "no-cache, no-store, must-revalidate"},
    )

# Blocking helper: submit the prompt, then wait for and return the image URL.
def _generate_image_url(prompt: str) -> str:
    """Run the blocking submit-then-stream round trip for ``prompt``."""
    event_id = get_event_id(prompt)
    return stream_until_complete(event_id)


# Generate endpoint: processes form and renders result
@app.post("/generate", response_class=HTMLResponse)
async def generate(request: Request, prompt: str = Form(...)):
    """Generate an image for the submitted ``prompt`` and re-render ``index.html``.

    On success the template receives ``image_url`` and ``prompt``. If any
    exception occurs, it receives ``error`` (the exception text) and ``prompt``
    instead, so the user sees the failure next to their input.
    """
    try:
        # The helpers use synchronous ``requests`` calls that last for the
        # whole generation. Running them in the threadpool keeps this async
        # endpoint from blocking the event loop, and every other request, meanwhile.
        image_url = await run_in_threadpool(_generate_image_url, prompt)
        return templates.TemplateResponse(
            "index.html", {"request": request, "image_url": image_url, "prompt": prompt}
        )
    except Exception as e:  # deliberately broad: any failure is shown on the page
        return templates.TemplateResponse(
            "index.html", {"request": request, "error": str(e), "prompt": prompt}
        )