# fluxxydev / main.py
# Commit 1e8540f by triflix: "Create main.py" (verified)
import json
import logging
import os
import shutil
import uuid
from pathlib import Path

import requests
from fastapi import FastAPI, Request, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Set up logging: configure the root logger at INFO so the request/error
# messages emitted by the route handlers below are printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application; every route below is registered on this instance.
app = FastAPI()
# Create the cache directory if it doesn't exist.
# NOTE(review): only ./cache is created here; ./static and ./templates must
# already exist. All three paths resolve against the process's working
# directory, not this file's location. In the visible code nothing writes to
# CACHE_DIR; only the /clear-cache endpoint touches it.
CACHE_DIR = Path("./cache")
CACHE_DIR.mkdir(exist_ok=True)
# Mount static files directory at /static.
# NOTE(review): StaticFiles checks the directory by default, so a missing
# ./static most likely fails at import time — confirm for the deployed Starlette.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Set up Jinja2 templates (index.html is rendered by the "/" route).
templates = Jinja2Templates(directory="templates")
# Upstream Gradio "call" endpoint (host name suggests the FLUX.1-dev Space).
# POST to this URL starts a job and returns an "event_id"; GET
# f"{API_URL}/{event_id}" returns that job's "event:"/"data:" line stream.
API_URL = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the single-page UI.

    Renders ``index.html`` from the configured template directory, passing the
    incoming request into the template context.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.post("/generate")
async def generate_image(
    prompt: str = Form(...),
    width: int = Form(1024),
    height: int = Form(1024),
    steps: int = Form(4),
    guidance_scale: int = Form(60),
    negative_prompt: str = Form(""),
    seed: int = Form(0),
    use_random_seed: bool = Form(True)
):
    """Start an image-generation job on the upstream Gradio API.

    The form fields are forwarded positionally to ``API_URL``; the upstream
    ``event_id`` is handed back so the client can poll ``/poll/{event_id}``.

    Returns:
        JSONResponse ``{"event_id": ...}`` on success, or status 500 with
        ``{"error": ...}`` when the upstream call fails, returns a non-200
        status, or returns no event ID.
    """
    try:
        # Positional argument list sent to the upstream endpoint.
        # NOTE(review): the position order (and guidance_scale being an int
        # defaulting to 60) is assumed to match the Space's API — confirm.
        data = [
            prompt,
            seed,
            use_random_seed,
            width,
            height,
            steps,
            guidance_scale,
        ]
        # The negative prompt is only sent when non-empty; it is the last
        # position, so omitting it does not shift the other arguments.
        if negative_prompt:
            data.append(negative_prompt)
        payload = {"data": data}

        logger.info("Sending request with payload: %s", payload)
        # requests is blocking: run it in a worker thread so the event loop
        # keeps serving other clients. (connect, read) timeout in seconds so a
        # stalled upstream cannot hang the request forever.
        response = await run_in_threadpool(
            requests.post,
            API_URL,
            json=payload,
            timeout=(10, 60),
        )

        if response.status_code != 200:
            logger.error("API request failed with status code: %s", response.status_code)
            logger.error("Response: %s", response.text)
            return JSONResponse(
                status_code=500,
                content={"error": f"API request failed: {response.text}"}
            )

        # The start call is expected to answer with {"event_id": ...}; guard
        # against other JSON shapes instead of crashing on .get().
        response_json = response.json()
        event_id = response_json.get("event_id") if isinstance(response_json, dict) else None
        if not event_id:
            logger.error("No event_id in response: %s", response_json)
            return JSONResponse(
                status_code=500,
                content={"error": "No event ID returned from API"}
            )

        # Return the event ID to the client for tracking.
        return JSONResponse(content={"event_id": event_id})
    except Exception as e:
        # Route boundary: log with traceback, report a 500 to the client.
        logger.exception("Error generating image: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Error generating image: {str(e)}"}
        )
def _parse_event_stream(text):
    """Summarise an upstream event-stream body.

    The body is expected to consist of ``event: <type>`` lines, each followed
    by ``data: <json>`` lines; blank and other lines are ignored.

    Returns:
        ``(image_urls, completed, error_seen, error_detail)``:

        - ``image_urls`` lists ``data[0]["url"]`` from every payload shaped
          like ``[{"url": ...}, ...]``, in stream order.
        - ``completed`` is True once a non-null payload of a ``complete``
          event was parsed.
        - ``error_seen`` is True if an ``error`` event appeared.
        - ``error_detail`` is the last non-null payload of an ``error`` event,
          or None.
    """
    image_urls = []
    completed = False
    error_seen = False
    error_detail = None
    # Initialised up front: a data line may arrive before any event line.
    event_type = None
    for line in text.splitlines():
        # Match prefixes, not substrings, so payload text containing
        # "event: " or "data: " is never mistaken for a field line.
        if line.startswith("event:"):
            event_type = line[len("event:"):].strip()
            error_seen = error_seen or event_type == "error"
            continue
        if not line.startswith("data:"):
            continue
        raw = line[len("data:"):].strip()
        if not raw or raw == "null":
            continue
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            # Best-effort parsing: ignore payloads that are not valid JSON.
            continue
        if event_type == "complete":
            completed = True
        elif event_type == "error":
            error_detail = data
        if (
            isinstance(data, list)
            and data
            and isinstance(data[0], dict)
            and "url" in data[0]
        ):
            image_urls.append(data[0]["url"])
    return image_urls, completed, error_seen, error_detail


@app.get("/poll/{event_id}")
async def poll_status(event_id: str):
    """Fetch the upstream event stream for ``event_id`` and summarise it.

    Returns:
        JSONResponse ``{"status", "image_urls", "final_image"}`` where
        ``status`` is "complete" once a ``complete`` event was seen and
        "generating" otherwise, ``image_urls`` are all image URLs found, and
        ``final_image`` is the last of them (or None). Status 500 with
        ``{"error": ...}`` on network/upstream failure, or when the stream
        reports an ``error`` event without completing.
    """
    try:
        stream_url = f"{API_URL}/{event_id}"
        # Non-streaming GET: requests reads the whole body before returning.
        # It is blocking, so run it off the event loop. (connect, read)
        # timeout: the read limit applies between received bytes, so it is
        # generous to tolerate slow generations without hanging forever.
        response = await run_in_threadpool(
            requests.get, stream_url, timeout=(10, 300)
        )
        if response.status_code != 200:
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to poll status: {response.text}"}
            )

        image_urls, completed, error_seen, error_detail = _parse_event_stream(response.text)

        # Surface upstream failures; otherwise a failed job would be reported
        # as "generating" indefinitely.
        if error_seen and not completed:
            logger.error("Upstream reported an error for event %s: %s", event_id, error_detail)
            message = "Image generation failed"
            if error_detail is not None:
                message = f"{message}: {error_detail}"
            return JSONResponse(status_code=500, content={"error": message})

        return JSONResponse(content={
            "status": "complete" if completed else "generating",
            "image_urls": image_urls,
            "final_image": image_urls[-1] if image_urls else None
        })
    except Exception as e:
        # Route boundary: log with traceback, report a 500 to the client.
        logger.exception("Error polling status: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Error polling status: {str(e)}"}
        )
@app.post("/clear-cache")
async def clear_cache():
    """Delete every entry inside ``CACHE_DIR``; the directory itself is kept.

    Returns:
        JSONResponse ``{"message": ...}`` on success, or status 500 with
        ``{"error": ...}`` if the cache could not be cleared.
    """
    # NOTE(review): this endpoint has no authentication, so any client can
    # wipe the cache — confirm that is intended.
    try:
        # Recreate the directory if it was removed externally, so iterdir()
        # below does not fail with FileNotFoundError.
        CACHE_DIR.mkdir(exist_ok=True)
        for item in CACHE_DIR.iterdir():
            # Symlinks (to directories, or dangling) are unlinked, never
            # followed: shutil.rmtree raises OSError on a symlink, which would
            # abort the whole clear. Files and other non-directories are
            # unlinked too; only real directories are removed recursively.
            if item.is_symlink() or not item.is_dir():
                item.unlink()
            else:
                shutil.rmtree(item)
        return JSONResponse(content={"message": "Cache cleared successfully"})
    except Exception as e:
        # Route boundary: log with traceback, report a 500 to the client.
        logger.exception("Error clearing cache: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Error clearing cache: {str(e)}"}
        )
# Development entry point: when this file is executed as a script (rather than
# imported by a server process), serve the app with uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)