import json
import logging
import os
import shutil
import uuid
from pathlib import Path

import requests
from fastapi import FastAPI, Request, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
|
|
| |
# Upstream Gradio endpoint of the FLUX.1-dev Hugging Face Space.
API_URL = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"

# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Cache directory (a relative path, so it resolves against the process's
# working directory); created at import time if it does not exist yet.
CACHE_DIR = Path("./cache")
CACHE_DIR.mkdir(exist_ok=True)

# Static assets served under /static plus the Jinja2 page templates; both
# directories are relative paths resolved against the working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page by rendering the ``index.html`` template.

    The incoming request is exposed to the template under the ``"request"`` key.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
|
|
@app.post("/generate")
async def generate_image(
    prompt: str = Form(...),
    width: int = Form(1024),
    height: int = Form(1024),
    steps: int = Form(4),
    guidance_scale: int = Form(60),
    negative_prompt: str = Form(""),
    seed: int = Form(0),
    use_random_seed: bool = Form(True),
):
    """Queue an image-generation job on the upstream Gradio API.

    The form fields are sent as the ``data`` list of a JSON POST to
    ``API_URL``. ``negative_prompt`` is appended only when it is non-empty.

    NOTE(review): the upstream endpoint presumably binds ``data`` by position,
    so this order (in particular ``steps`` before ``guidance_scale``) must
    match the Space's ``infer`` signature. Confirm against the Space's API page.

    Returns:
        ``{"event_id": ...}`` on success; clients can pass the id to
        ``/poll/{event_id}``. Any failure (non-200 upstream status, missing
        event id, network error or timeout) produces a 500 response with an
        ``{"error": ...}`` body instead of raising.
    """
    data = [prompt, seed, use_random_seed, width, height, steps, guidance_scale]
    if negative_prompt:
        data.append(negative_prompt)
    payload = {"data": data}

    try:
        logger.info("Sending request with payload: %s", payload)

        # requests has no default timeout, so without one a stalled upstream
        # would hang this request forever. The blocking call runs in a worker
        # thread so it does not stall the event loop for other clients.
        response = await run_in_threadpool(
            requests.post,
            API_URL,
            json=payload,  # serialises the body and sets Content-Type: application/json
            timeout=(10, 60),  # (connect, read) seconds
        )

        if response.status_code != 200:
            logger.error("API request failed with status code: %s", response.status_code)
            logger.error("Response: %s", response.text)
            return JSONResponse(
                status_code=500,
                content={"error": f"API request failed: {response.text}"},
            )

        response_json = response.json()
        # A non-object JSON body (e.g. a list) has no .get(); treat it as
        # "no event id" rather than surfacing an AttributeError message.
        event_id = (
            response_json.get("event_id") if isinstance(response_json, dict) else None
        )

        if not event_id:
            logger.error("No event_id in response: %s", response_json)
            return JSONResponse(
                status_code=500,
                content={"error": "No event ID returned from API"},
            )

        return JSONResponse(content={"event_id": event_id})

    except Exception as e:
        # Top-level request boundary: keep the traceback in the logs and
        # report a 500 to the client.
        logger.exception("Error generating image: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Error generating image: {str(e)}"},
        )
|
|
def _parse_gradio_stream(text):
    """Extract image URLs and completion state from a Gradio SSE response body.

    The body is read line by line as Server-Sent Events. An ``event: <name>``
    line sets the current event type, and each ``data: <json>`` line is decoded
    as JSON. Any decoded list whose first element is a dict with a ``"url"`` key
    contributes that URL, in stream order.

    Returns:
        ``(image_urls, is_complete)``. ``is_complete`` is True once a non-null
        JSON payload has been seen under a ``complete`` event.
    """
    image_urls = []
    is_complete = False
    # Initialised so that a data line arriving before any event line cannot
    # raise UnboundLocalError.
    event_type = None

    for line in text.splitlines():
        # Split on the first colon only, so JSON payloads that themselves
        # contain "data: " or "event: " are neither truncated nor misread.
        field, sep, value = line.partition(":")
        if not sep:
            continue
        # Per SSE, a single leading space after the colon is not part of the value.
        if value.startswith(" "):
            value = value[1:]

        if field == "event":
            event_type = value.strip()
            continue
        if field != "data" or value == "null":
            continue

        try:
            data = json.loads(value)
        except json.JSONDecodeError:
            # Non-JSON data lines are skipped.
            continue

        # Compare against None (not truthiness) so a valid but falsy
        # completion payload such as [] still counts as complete.
        if event_type == "complete" and data is not None:
            is_complete = True

        if isinstance(data, list) and data and isinstance(data[0], dict) and "url" in data[0]:
            image_urls.append(data[0]["url"])

    return image_urls, is_complete


@app.get("/poll/{event_id}")
async def poll_status(event_id: str):
    """Fetch the upstream event stream for ``event_id`` and summarise it.

    Returns:
        ``{"status", "image_urls", "final_image"}``. ``status`` is
        ``"complete"`` once a completion payload has been received and
        ``"generating"`` otherwise. ``final_image`` is the last URL seen, or
        None. Upstream or network failures yield a 500 response with an
        ``{"error": ...}`` body.
    """
    stream_url = f"{API_URL}/{event_id}"
    try:
        # The GET reads the whole response body. It runs in a worker thread
        # so the event loop is not blocked. The read timeout bounds the gap
        # between received bytes, not the total duration.
        response = await run_in_threadpool(
            requests.get, stream_url, timeout=(10, 300)
        )

        if response.status_code != 200:
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to poll status: {response.text}"},
            )

        image_urls, is_complete = _parse_gradio_stream(response.text)

        return JSONResponse(content={
            "status": "complete" if is_complete else "generating",
            "image_urls": image_urls,
            "final_image": image_urls[-1] if image_urls else None,
        })

    except Exception as e:
        # Top-level request boundary: log with traceback, report a 500.
        logger.exception("Error polling status: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Error polling status: {str(e)}"},
        )
|
|
def _empty_directory(directory: Path) -> None:
    """Remove every entry inside *directory*, keeping the directory itself.

    Real sub-directories are removed recursively. Everything else is unlinked
    rather than followed, including symlinks to directories, dangling
    symlinks and other non-directory entries. A missing *directory* is
    created, so the caller always ends up with an existing, empty directory.
    """
    directory.mkdir(exist_ok=True)
    for item in directory.iterdir():
        # is_dir() follows symlinks and shutil.rmtree refuses to operate on a
        # symlink, so links must be routed to unlink() explicitly.
        if item.is_dir() and not item.is_symlink():
            shutil.rmtree(item)
        else:
            item.unlink()


@app.post("/clear-cache")
async def clear_cache():
    """Empty ``CACHE_DIR``.

    Returns:
        ``{"message": "Cache cleared successfully"}`` on success, otherwise a
        500 response with an ``{"error": ...}`` body.
    """
    try:
        # Filesystem work can be slow for large caches; keep it off the event loop.
        await run_in_threadpool(_empty_directory, CACHE_DIR)
        return JSONResponse(content={"message": "Cache cleared successfully"})
    except Exception as e:
        logger.exception("Error clearing cache: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Error clearing cache: {str(e)}"},
        )
|
|
| |
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    # uvicorn is imported here, so it is only required when run directly.
    import uvicorn

    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)