paint_match / server.py
szwendaczjakomaj's picture
fix: remove hardcoded LLAMA_SERVER URL + improve error visibility through Cloudflare
8cbae1e
Raw
History Blame Contribute Delete
2.9 kB
import os
import PIL.Image
import gradio as gr
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
from app import analyze as _analyze # existing generator: yields (status, "") then (raw, manifest_html)
FRONTEND_DIR = os.path.join(os.path.dirname(__file__), "frontend")
app = gr.Server(title="Paint Match")


@app.api(name="analyze", stream_every=0.5)
def analyze(image: PIL.Image.Image, format_id: str) -> tuple[str, str]:
    """Stream rotating status lines, then the final (field report, resupply manifest HTML).

    gr.Server delivers ``image`` as a Gradio FileData dict (the PIL type hint is
    not coerced). Accept both a dict (production: open the file at its
    ``"path"``) and a real PIL image (so unit tests can pass one directly);
    ``None`` falls through to ``app.analyze``'s guard.

    When this function opens the file itself, the ``with`` block spans the whole
    stream. The file handle is therefore released when the stream finishes,
    raises, or the generator is closed early, rather than whenever garbage
    collection gets to it. ``PIL.Image.open`` is still lazy here, so decode
    errors surface inside ``app.analyze`` just as they did before.

    Args:
        image: Gradio FileData dict, a PIL image, or None.
        format_id: Passed through unchanged to ``app.analyze``.

    Yields:
        Whatever ``app.analyze`` yields: ``(status, "")`` tuples, then the final
        ``(raw, manifest_html)`` tuple.
    """
    # NOTE(review): ``-> tuple[str, str]`` describes each yielded item, not the
    # generator itself. It is left unchanged in case gr.Server derives the API
    # schema from it; confirm before tightening.
    if isinstance(image, dict):
        # Keep the file open only for the duration of the stream, then close it.
        with PIL.Image.open(image["path"]) as opened:
            yield from _analyze(opened, format_id)
        return
    yield from _analyze(image, format_id)
# ---------------------------------------------------------------------------
# Frontend middleware — serves /, /app.js, /assets/* and /examples/* before
# Gradio's routes can claim them. Middleware runs before the router, so the
# order of route registration does not matter.
# ---------------------------------------------------------------------------
_ASSETS_DIR = os.path.join(FRONTEND_DIR, "assets")
_EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")

# URL prefix -> directory whose regular files may be served under that prefix.
_STATIC_MOUNTS = (
    ("/assets/", _ASSETS_DIR),
    ("/examples/", _EXAMPLES_DIR),
)


def _resolve_static_file(root, rel):
    """Return the real path of ``rel`` under ``root`` if it is a servable file, else None.

    Symlinks are resolved before the containment check, so ``..`` segments,
    absolute components and symlinks that escape ``root`` are all refused.
    Paths that cannot be resolved at all (e.g. an embedded NUL byte from a
    request such as ``/assets/%00``) are refused too, instead of raising.
    """
    real_root = os.path.realpath(root)
    try:
        full = os.path.realpath(os.path.join(real_root, rel))
    except (OSError, ValueError):
        # os.path.realpath raises ValueError on embedded NUL bytes.
        return None
    inside = full == real_root or full.startswith(real_root + os.sep)
    if inside and os.path.isfile(full):
        return full
    return None


class FrontendMiddleware(BaseHTTPMiddleware):
    """Serve the static frontend ahead of Gradio's own routes.

    - ``/`` -> ``FRONTEND_DIR/index.html``
    - ``/app.js`` -> ``FRONTEND_DIR/app.js`` (served as ``text/javascript``)
    - ``/assets/<rel>`` and ``/examples/<rel>`` -> the matching file inside the
      mounted directory, if it exists and does not escape that directory.

    Any other request, including a static path that does not resolve to a file,
    is passed to the wrapped app. Its response gets ``X-Accel-Buffering: no``,
    which asks nginx-style reverse proxies not to buffer it (relevant for the
    streamed ``analyze`` output).
    """

    async def dispatch(self, request, call_next):
        path = request.url.path
        if path in ("/", ""):
            return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))
        if path == "/app.js":
            return FileResponse(
                os.path.join(FRONTEND_DIR, "app.js"),
                media_type="text/javascript",
            )
        for prefix, root in _STATIC_MOUNTS:
            if path.startswith(prefix):
                full = _resolve_static_file(root, path[len(prefix):])
                if full is not None:
                    return FileResponse(full)
        response = await call_next(request)
        response.headers["X-Accel-Buffering"] = "no"
        return response


app.add_middleware(FrontendMiddleware)
if __name__ == "__main__":
    # The PORT environment variable overrides the listening port (default "7860").
    listen_port = int(os.environ.get("PORT", "7860"))
    app.launch(
        server_name="0.0.0.0",  # bind every interface, not just loopback
        server_port=listen_port,
        allowed_paths=[FRONTEND_DIR],
    )