Spaces:
Sleeping
Sleeping
from __future__ import annotations

import inspect
import re
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack
from dataclasses import dataclass
from pathlib import Path

from fastapi import FastAPI, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse
import markdown

from api.classify.router import router as classify_router
from api.common.logging import log_json, setup_logging
from api.common.middleware import RequestIdMiddleware
from api.label_sets.router import router as label_sets_router
from api.label_sets.registry import LabelSetRegistry
from api.model.clip_store import ClipStore
from api.classify.service import TwoStageClassifier
# Module-wide logger, obtained from the project's logging setup helper and
# passed to log_json() when reporting errors.
logger = setup_logging()

# ASCII-art banner. create_app() embeds it, wrapped in a fenced code block,
# as the OpenAPI description shown on the docs page.
BANNER = r"""
____ _ _ ____ _
| _ \| |__ ___ | |_ ___ / ___| | __ _ ___ ___
| |_) | '_ \ / _ \| __/ _ \ | | | |/ _` / __/ __|
| __/| | | | (_) | || (_) | | |___| | (_| \__ \__ \
|_| |_| |_|\___/ \__\___/ \____|_|\__,_|___/___/
"""
@dataclass
class Resources:
    """Bundle of long-lived service objects shared by request handlers.

    An instance is stored on ``app.state.resources``, either by the lifespan
    or when injected through ``create_app(resources=...)``. It is built with
    keyword arguments, which is why this must be a dataclass: a plain class
    with bare annotations has no matching ``__init__``.
    """

    # CLIP model/embedding store; the default classifier is built on top of it.
    store: ClipStore
    # Two-stage classifier used by the classification endpoints.
    classifier: TwoStageClassifier
    # Registry of label sets (created empty by default).
    registry: LabelSetRegistry
| async def _maybe_aclose(obj) -> None: | |
| aclose = getattr(obj, "aclose", None) | |
| if callable(aclose): | |
| await aclose() | |
| return | |
| close = getattr(obj, "close", None) | |
| if callable(close): | |
| close() | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the application's shared resources for the server's lifetime.

    On startup this builds, in order:

    - a ``ClipStore``;
    - a ``TwoStageClassifier`` on top of that store;
    - an empty ``LabelSetRegistry``.

    The three are published as ``app.state.resources``. On shutdown they are
    closed in reverse order: registry, then classifier, then store.

    Each close hook is registered as soon as its resource exists. As a
    result, a failure while building a later resource still releases the
    earlier ones, and an exception from one close does not prevent the
    remaining closes from running (``AsyncExitStack`` unwinds everything,
    then re-raises).
    """
    async with AsyncExitStack() as stack:
        store = ClipStore()
        stack.push_async_callback(_maybe_aclose, store)
        classifier = TwoStageClassifier(store=store)
        stack.push_async_callback(_maybe_aclose, classifier)
        registry = LabelSetRegistry(banks={})
        stack.push_async_callback(_maybe_aclose, registry)
        app.state.resources = Resources(store=store, classifier=classifier, registry=registry)
        yield
# Leading "---" line, optional body, closing "---" line. Anchored at the start
# of the document (after optional whitespace). A "----" horizontal rule does
# not qualify.
_FRONT_MATTER_RE = re.compile(
    r"\A\s*---[ \t]*\r?\n(?:.*?\r?\n)?---[ \t]*(?:\r?\n|\Z)",
    re.DOTALL,
)
# Template placeholders such as {{TITLE}}.
_PLACEHOLDER_RE = re.compile(r"\{\{([A-Z_]+)\}\}")


def _strip_front_matter(md_text: str) -> str:
    """Drop a leading ``---``-delimited front-matter block, if there is one."""
    match = _FRONT_MATTER_RE.match(md_text)
    if match is None:
        return md_text
    return md_text[match.end():].lstrip()


def _fill_template(template: str, values: dict[str, str]) -> str:
    """Replace ``{{NAME}}`` placeholders in *template* in a single pass.

    Inserted values are never re-scanned, so literal placeholder text inside
    them (e.g. a README that documents ``{{TITLE}}``) is left untouched.
    Placeholders with no entry in *values* are kept verbatim.
    """
    return _PLACEHOLDER_RE.sub(lambda m: values.get(m.group(1), m.group(0)), template)


def render_page(md_path: Path, *, title: str) -> HTMLResponse:
    """Render a Markdown file inside the shared UI page template.

    Reads ``ui/page-banner.html`` and ``ui/page.html`` next to this module,
    plus *md_path*. Any leading front matter is stripped. The Markdown is
    converted to HTML5 with the ``fenced_code`` and ``tables`` extensions,
    and the template's ``{{HEADER}}``, ``{{CONTENT}}`` and ``{{TITLE}}``
    placeholders are filled in. ``{{TITLE}}`` is also filled in the banner
    fragment.

    Args:
        md_path: Markdown file to render.
        title: Plain-text page title substituted for ``{{TITLE}}``.

    Returns:
        The rendered page. If any input file cannot be read or decoded, the
        failure is logged and a plain 500 "internal server error" response is
        returned instead.
    """
    ui_dir = Path(__file__).resolve().parent / "ui"
    try:
        header_html = (ui_dir / "page-banner.html").read_text(encoding="utf-8")
        template_html = (ui_dir / "page.html").read_text(encoding="utf-8")
        md_text = md_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        log_json(
            logger,
            event="page.render_failed",
            request_id=None,
            error=str(exc),
            path=str(md_path),
        )
        return HTMLResponse(content="internal server error", status_code=500)

    content_html = markdown.markdown(
        _strip_front_matter(md_text),
        extensions=["fenced_code", "tables"],
        output_format="html5",
    )
    # The banner is a trusted template fragment and may reference the title.
    # The rendered Markdown is inserted last and is never itself scanned.
    header_html = _fill_template(header_html, {"TITLE": title})
    page_html = _fill_template(
        template_html,
        {"HEADER": header_html, "CONTENT": content_html, "TITLE": title},
    )
    return HTMLResponse(content=page_html)
def create_app(*, resources: Resources | None = None) -> FastAPI:
    """Build and configure the FastAPI application.

    Args:
        resources: Optional pre-built resources to serve with (for example
            from ``build_app`` or tests). When given, they are published on
            ``app.state.resources`` immediately, and the default ``lifespan``
            (which would build fresh ones) is replaced by one that only closes
            these on shutdown. When ``None``, ``lifespan`` creates them at
            startup.

    Returns:
        The configured application. It includes:

        - the request-id middleware;
        - the UI pages (favicon, splash, README, story);
        - a catch-all JSON 500 handler;
        - the label-set and classification routers.
    """
    app = FastAPI(
        lifespan=lifespan,
        title="Photo Classification API",
        version="1.0.0",
        description=f"```\n{BANNER.strip()}\n```",
        docs_url="/docs",
        redoc_url=None,
        openapi_url="/openapi.json",
    )
    app.add_middleware(RequestIdMiddleware)

    if resources is not None:
        app.state.resources = resources

        # Starlette enters ``router.lifespan_context(app)`` with ``async with``,
        # so this must be an async context manager. A bare async generator
        # assigned here would fail at startup.
        @asynccontextmanager
        async def _lifespan_override(_app: FastAPI):
            _app.state.resources = resources
            async with AsyncExitStack() as stack:
                # Callbacks run LIFO: registry, then classifier, then store.
                # A failing close does not skip the remaining ones.
                stack.push_async_callback(_maybe_aclose, resources.store)
                stack.push_async_callback(_maybe_aclose, resources.classifier)
                stack.push_async_callback(_maybe_aclose, resources.registry)
                yield

        app.router.lifespan_context = _lifespan_override

    # NOTE(review): the handlers below were defined but never registered. The
    # route paths here are reconstructed from the handler names; confirm they
    # match the links used by the UI templates.
    @app.get("/favicon.ico", include_in_schema=False)
    def favicon() -> Response:
        """Serve a small inline SVG icon."""
        svg = (
            "<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 64 64'>"
            "<rect width='64' height='64' rx='12' fill='#1f2937'/>"
            "<circle cx='24' cy='28' r='10' fill='#f59e0b'/>"
            "<circle cx='44' cy='28' r='10' fill='#60a5fa'/>"
            "<rect x='18' y='38' width='28' height='10' rx='5' fill='#e5e7eb'/>"
            "</svg>"
        )
        return Response(content=svg, media_type="image/svg+xml")

    @app.get("/", include_in_schema=False)
    def home() -> HTMLResponse:
        """Serve ``ui/splash.html``, or a minimal fallback page if it is unreadable."""
        splash_path = Path(__file__).resolve().parent / "ui" / "splash.html"
        try:
            html = splash_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            html = f"<h1>Photo Classification</h1><p>Missing {splash_path}</p>"
        return HTMLResponse(content=html)

    @app.get("/readme", include_in_schema=False)
    def readme() -> HTMLResponse:
        """Render the repository's README.md (two directories above this module)."""
        readme_path = Path(__file__).resolve().parents[2] / "README.md"
        return render_page(readme_path, title="README")

    @app.get("/story", include_in_schema=False)
    def story() -> HTMLResponse:
        """Render the repository's STORY.md (two directories above this module)."""
        story_path = Path(__file__).resolve().parents[2] / "STORY.md"
        return render_page(story_path, title="Story")

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(request: Request, exc: Exception):
        """Log any uncaught exception and return a generic JSON 500 body."""
        # request_id is presumably set by RequestIdMiddleware; it may be absent.
        rid = getattr(request.state, "request_id", None)
        log_json(logger, event="error.unhandled", request_id=rid, error=str(exc), path=str(request.url.path))
        return JSONResponse(status_code=500, content={"detail": "internal server error"})

    app.include_router(label_sets_router)
    app.include_router(classify_router)
    return app
def build_app(
    store: ClipStore | None = None,
    classifier: TwoStageClassifier | None = None,
    registry: LabelSetRegistry | None = None,
) -> FastAPI:
    """Build an app, optionally injecting some or all of its resources.

    With no arguments this is exactly ``create_app()``: resources are created
    by the default lifespan at startup. If any resource is supplied, the
    missing ones are built here with the same defaults the lifespan uses, and
    the bundle is passed to ``create_app(resources=...)``.

    Args:
        store: CLIP store to use; a new ``ClipStore()`` if omitted.
        classifier: Classifier to use; built on *store* if omitted.
        registry: Label-set registry; an empty one if omitted.
    """
    if store is None and classifier is None and registry is None:
        return create_app()
    # Use ``is None`` rather than ``or``. A supplied resource that is falsy
    # (e.g. a registry with zero banks that defines ``__len__``) must not be
    # silently replaced.
    if store is None:
        # NOTE(review): when only ``classifier`` is supplied, this new store is
        # not necessarily the one that classifier was built with — confirm.
        store = ClipStore()
    if classifier is None:
        classifier = TwoStageClassifier(store=store)
    if registry is None:
        registry = LabelSetRegistry(banks={})
    resources = Resources(store=store, classifier=classifier, registry=registry)
    return create_app(resources=resources)
# Module-level application instance. It uses default resources, created by
# `lifespan` at startup. Presumably this is the object the ASGI server imports;
# confirm against the deployment config.
app = create_app()