import os

import PIL.Image
import gradio as gr
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse

from app import analyze as _analyze  # existing generator: yields (status, "") then (raw, manifest_html)

FRONTEND_DIR = os.path.join(os.path.dirname(__file__), "frontend")

app = gr.Server(title="Paint Match")


@app.api(name="analyze", stream_every=0.5)
def analyze(image: PIL.Image.Image, format_id: str) -> tuple[str, str]:
    """Stream rotating status lines, then the final (field report, resupply manifest HTML).

    This is a generator that re-yields everything `app.analyze` yields; the
    `tuple[str, str]` annotation describes each yielded item.

    gr.Server delivers `image` as a Gradio FileData dict (the PIL type hint is
    not coerced). Accept both a dict (production: open the path) and a real PIL
    image (so unit tests can pass one directly); None falls through to
    app.analyze's guard.

    Args:
        image: FileData dict with a "path" key, a PIL image, or None.
        format_id: Passed through unchanged to `app.analyze`.
    """
    if isinstance(image, dict):
        # PIL.Image.open is lazy and keeps the uploaded file open until the
        # image is closed. The `with` guarantees the handle is released once
        # streaming ends, including when the client disconnects
        # (GeneratorExit) or app.analyze raises.
        with PIL.Image.open(image["path"]) as opened:
            yield from _analyze(opened, format_id)
        return
    yield from _analyze(image, format_id)


# ---------------------------------------------------------------------------
# Frontend middleware — intercepts /, /app.js, /assets/* BEFORE Gradio's
# routes win. Middleware runs before the router, so ordering of route
# registration does not matter.
# ---------------------------------------------------------------------------

_ASSETS_DIR = os.path.join(FRONTEND_DIR, "assets")
_EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")


def _resolve_under(base_dir, rel):
    """Return the real path of `rel` inside `base_dir` if it is an existing file, else None.

    Both paths go through os.path.realpath before the containment check, so
    `..` segments, absolute `rel` values, and symlinks that escape `base_dir`
    are rejected. A path the OS cannot represent (for example an embedded NUL
    from a `%00` in the request URL) makes realpath raise ValueError; that is
    treated as "not found" instead of surfacing as a 500.
    """
    root = os.path.realpath(base_dir)
    try:
        full = os.path.realpath(os.path.join(root, rel))
    except ValueError:
        return None
    inside = full == root or full.startswith(root + os.sep)
    if inside and os.path.isfile(full):
        return full
    return None


class FrontendMiddleware(BaseHTTPMiddleware):
    """Serve the static frontend ahead of Gradio's router.

    Handles `/` (index.html), `/app.js`, `/assets/*`, and `/examples/*`.
    A request under `/assets/` or `/examples/` whose file is missing (or
    escapes its directory) falls through to the wrapped app unchanged.
    """

    async def dispatch(self, request, call_next):
        path = request.url.path

        if path == "/" or path == "":
            return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))
        if path == "/app.js":
            return FileResponse(
                os.path.join(FRONTEND_DIR, "app.js"),
                media_type="text/javascript",
            )

        # URL prefix -> directory it maps onto. Directories are read at call
        # time so module-level overrides still take effect.
        for prefix, base_dir in (("/assets/", _ASSETS_DIR), ("/examples/", _EXAMPLES_DIR)):
            if path.startswith(prefix):
                full = _resolve_under(base_dir, path[len(prefix):])
                if full is not None:
                    return FileResponse(full)

        response = await call_next(request)
        # Tell nginx-style reverse proxies not to buffer, so streamed
        # analyze updates reach the browser as they are produced.
        response.headers["X-Accel-Buffering"] = "no"
        return response


app.add_middleware(FrontendMiddleware)


if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
        allowed_paths=[FRONTEND_DIR],
    )