triflix commited on
Commit
1e8540f
·
verified ·
1 Parent(s): 0d5e655

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +178 -0
main.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, Form
2
+ from fastapi.templating import Jinja2Templates
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import HTMLResponse, JSONResponse
5
+ import requests
6
+ import json
7
+ import os
8
+ import shutil
9
+ import uuid
10
+ import logging
11
+ from pathlib import Path
12
+
13
+ # Setup logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ app = FastAPI()
18
+
19
+ # Create directories if they don't exist
20
+ CACHE_DIR = Path("./cache")
21
+ CACHE_DIR.mkdir(exist_ok=True)
22
+
23
+ # Mount static files directory
24
+ app.mount("/static", StaticFiles(directory="static"), name="static")
25
+
26
+ # Set up templates
27
+ templates = Jinja2Templates(directory="templates")
28
+
29
+ # API endpoint for image generation
30
+ API_URL = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"
31
+
32
+ @app.get("/", response_class=HTMLResponse)
33
+ async def read_root(request: Request):
34
+ return templates.TemplateResponse("index.html", {"request": request})
35
+
36
+ @app.post("/generate")
37
+ async def generate_image(
38
+ prompt: str = Form(...),
39
+ width: int = Form(1024),
40
+ height: int = Form(1024),
41
+ steps: int = Form(4),
42
+ guidance_scale: int = Form(60),
43
+ negative_prompt: str = Form(""),
44
+ seed: int = Form(0),
45
+ use_random_seed: bool = Form(True)
46
+ ):
47
+ try:
48
+ # Define the payload for the API
49
+ payload = {
50
+ "data": [
51
+ prompt,
52
+ seed,
53
+ use_random_seed,
54
+ width,
55
+ height,
56
+ steps,
57
+ guidance_scale,
58
+ negative_prompt if negative_prompt else None
59
+ ]
60
+ }
61
+
62
+ # Filter out None values from payload data
63
+ payload["data"] = [item for item in payload["data"] if item is not None]
64
+
65
+ logger.info(f"Sending request with payload: {payload}")
66
+
67
+ # Make the initial POST request
68
+ response = requests.post(
69
+ API_URL,
70
+ headers={"Content-Type": "application/json"},
71
+ data=json.dumps(payload)
72
+ )
73
+
74
+ if response.status_code != 200:
75
+ logger.error(f"API request failed with status code: {response.status_code}")
76
+ logger.error(f"Response: {response.text}")
77
+ return JSONResponse(
78
+ status_code=500,
79
+ content={"error": f"API request failed: {response.text}"}
80
+ )
81
+
82
+ # Parse the JSON response to get the event ID
83
+ response_json = response.json()
84
+ event_id = response_json.get("event_id")
85
+
86
+ if not event_id:
87
+ logger.error(f"No event_id in response: {response_json}")
88
+ return JSONResponse(
89
+ status_code=500,
90
+ content={"error": "No event ID returned from API"}
91
+ )
92
+
93
+ # Return the event ID to the client for tracking
94
+ return JSONResponse(content={"event_id": event_id})
95
+
96
+ except Exception as e:
97
+ logger.error(f"Error generating image: {str(e)}")
98
+ return JSONResponse(
99
+ status_code=500,
100
+ content={"error": f"Error generating image: {str(e)}"}
101
+ )
102
+
103
+ @app.get("/poll/{event_id}")
104
+ async def poll_status(event_id: str):
105
+ try:
106
+ stream_url = f"{API_URL}/{event_id}"
107
+
108
+ # Make a non-streaming request to check status
109
+ response = requests.get(stream_url)
110
+
111
+ if response.status_code != 200:
112
+ return JSONResponse(
113
+ status_code=500,
114
+ content={"error": f"Failed to poll status: {response.text}"}
115
+ )
116
+
117
+ # Process the response content to extract image URLs
118
+ image_urls = []
119
+ complete_event = None
120
+
121
+ # Parse the response line by line
122
+ for line in response.text.splitlines():
123
+ if not line:
124
+ continue
125
+
126
+ # Find the event type and data
127
+ if "event: " in line:
128
+ event_type = line.split("event: ")[1].strip()
129
+ elif "data: " in line and line != "data: null":
130
+ try:
131
+ data = json.loads(line.split("data: ")[1])
132
+
133
+ # If this is a complete event, save it
134
+ if event_type == "complete":
135
+ complete_event = data
136
+
137
+ # Extract image URL
138
+ if isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict) and "url" in data[0]:
139
+ image_urls.append(data[0]["url"])
140
+ except json.JSONDecodeError:
141
+ pass
142
+
143
+ # Return the status information
144
+ return JSONResponse(content={
145
+ "status": "complete" if complete_event else "generating",
146
+ "image_urls": image_urls,
147
+ "final_image": image_urls[-1] if image_urls else None
148
+ })
149
+
150
+ except Exception as e:
151
+ logger.error(f"Error polling status: {str(e)}")
152
+ return JSONResponse(
153
+ status_code=500,
154
+ content={"error": f"Error polling status: {str(e)}"}
155
+ )
156
+
157
+ @app.post("/clear-cache")
158
+ async def clear_cache():
159
+ try:
160
+ # Clear local cache directory
161
+ for item in CACHE_DIR.iterdir():
162
+ if item.is_file():
163
+ item.unlink()
164
+ elif item.is_dir():
165
+ shutil.rmtree(item)
166
+
167
+ return JSONResponse(content={"message": "Cache cleared successfully"})
168
+ except Exception as e:
169
+ logger.error(f"Error clearing cache: {str(e)}")
170
+ return JSONResponse(
171
+ status_code=500,
172
+ content={"error": f"Error clearing cache: {str(e)}"}
173
+ )
174
+
175
+ # For development
176
+ if __name__ == "__main__":
177
+ import uvicorn
178
+ uvicorn.run(app, host="0.0.0.0", port=7860)