triflix commited on
Commit
88083ce
·
verified ·
1 Parent(s): 88b0391

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +75 -0
main.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, Form
2
+ from fastapi.responses import HTMLResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ import requests, json
6
+
7
+ # FastAPI setup
8
+ app = FastAPI()
9
+ # Serve static files (if needed)
10
+ app.mount("/static", StaticFiles(directory="static"), name="static")
11
+ # Templates directory
12
+ templates = Jinja2Templates(directory="templates")
13
+
14
+ # External API config
15
+ BASE_URL = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"
16
+ HEADERS = {"Content-Type": "application/json"}
17
+ DEFAULT_PAYLOAD = {
18
+ "data": ["", 0, True, 1024, 1024, 4, 60]
19
+ }
20
+
21
+ # Helper: request an event_id for given prompt
22
+ def get_event_id(prompt: str) -> str:
23
+ payload = DEFAULT_PAYLOAD.copy()
24
+ # insert prompt
25
+ payload["data"][0] = prompt
26
+ resp = requests.post(BASE_URL, headers=HEADERS, json=payload)
27
+ resp.raise_for_status()
28
+ data = resp.json()
29
+ event_id = data.get("event_id")
30
+ if not event_id:
31
+ raise RuntimeError(f"No event_id returned: {data}")
32
+ return event_id
33
+
34
+ # Helper: stream SSE until 'complete' event, then extract URL
35
+ def stream_until_complete(event_id: str) -> str:
36
+ url = f"{BASE_URL}/{event_id}"
37
+ with requests.get(url, headers=HEADERS, stream=True) as resp:
38
+ resp.raise_for_status()
39
+ buffer = ""
40
+ for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
41
+ buffer += chunk
42
+ while "\n\n" in buffer:
43
+ message, buffer = buffer.split("\n\n", 1)
44
+ evt = None
45
+ payload = None
46
+ for line in message.splitlines():
47
+ if line.startswith("event:"):
48
+ evt = line.split("event:",1)[1].strip()
49
+ elif line.startswith("data:"):
50
+ payload = line.split("data:",1)[1].strip()
51
+ if evt == "complete" and payload:
52
+ parsed = json.loads(payload)
53
+ file_info = parsed[0]
54
+ return file_info.get("url")
55
+ raise RuntimeError("Stream ended without complete event")
56
+
57
+ # Home page: form for prompt
58
+ @app.get("/", response_class=HTMLResponse)
59
+ async def home(request: Request):
60
+ # no-cache headers
61
+ return templates.TemplateResponse("index.html", {"request": request})
62
+
63
+ # Generate endpoint: processes form and renders result
64
+ @app.post("/generate", response_class=HTMLResponse)
65
+ async def generate(request: Request, prompt: str = Form(...)):
66
+ try:
67
+ event_id = get_event_id(prompt)
68
+ image_url = stream_until_complete(event_id)
69
+ return templates.TemplateResponse(
70
+ "index.html", {"request": request, "image_url": image_url, "prompt": prompt}
71
+ )
72
+ except Exception as e:
73
+ return templates.TemplateResponse(
74
+ "index.html", {"request": request, "error": str(e), "prompt": prompt}
75
+ )