Spaces:
Sleeping
Sleeping
File size: 6,033 Bytes
68f48a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from __future__ import annotations

import html
import inspect
import re
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from pathlib import Path

import markdown
from fastapi import FastAPI, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse

from api.classify.router import router as classify_router
from api.classify.service import TwoStageClassifier
from api.common.logging import log_json, setup_logging
from api.common.middleware import RequestIdMiddleware
from api.label_sets.registry import LabelSetRegistry
from api.label_sets.router import router as label_sets_router
from api.model.clip_store import ClipStore
# Module-wide logger obtained from the project's setup_logging() helper; it is
# passed to log_json() when reporting errors.
logger = setup_logging()
# ASCII-art banner. create_app() embeds BANNER.strip() inside a fenced code
# block in the OpenAPI description. Raw string so the backslashes stay literal.
BANNER = r"""
____ _ _ ____ _
| _ \| |__ ___ | |_ ___ / ___| | __ _ ___ ___
| |_) | '_ \ / _ \| __/ _ \ | | | |/ _` / __/ __|
| __/| | | | (_) | || (_) | | |___| | (_| \__ \__ \
|_| |_| |_|\___/ \__\___/ \____|_|\__,_|___/___/
"""
@dataclass
class Resources:
    """Long-lived service objects published on ``app.state.resources``.

    Built either by ``lifespan`` at startup, or up front by ``build_app`` and
    injected through ``create_app(resources=...)``.
    """

    # ClipStore instance; when this module builds the classifier itself it
    # passes this object as ``store=``.
    store: ClipStore
    # TwoStageClassifier; built from ``store`` above unless the caller supplies one.
    classifier: TwoStageClassifier
    # Label-set registry; this module's defaults create it with ``banks={}``.
    registry: LabelSetRegistry
async def _maybe_aclose(obj) -> None:
aclose = getattr(obj, "aclose", None)
if callable(aclose):
await aclose()
return
close = getattr(obj, "close", None)
if callable(close):
close()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared resources at startup and release them at shutdown.

    Resources are built in dependency order (store, then classifier, then
    registry) and published on ``app.state.resources``. On shutdown they are
    closed in reverse order (registry, classifier, store) via ``_maybe_aclose``.

    An ``AsyncExitStack`` is used so that:

    * if constructing a later resource fails, the ones already built are
      still closed;
    * an exception raised while closing one resource does not prevent the
      remaining ones from being closed (the exception still propagates once
      every callback has run).
    """
    async with AsyncExitStack() as stack:
        store = ClipStore()
        stack.push_async_callback(_maybe_aclose, store)

        classifier = TwoStageClassifier(store=store)
        stack.push_async_callback(_maybe_aclose, classifier)

        registry = LabelSetRegistry(banks={})
        stack.push_async_callback(_maybe_aclose, registry)

        app.state.resources = Resources(store=store, classifier=classifier, registry=registry)
        yield
def render_page(md_path: Path, *, title: str) -> HTMLResponse:
header_path = Path(__file__).resolve().parent / "ui" / "page-banner.html"
page_template_path = Path(__file__).resolve().parent / "ui" / "page.html"
try:
header_html = header_path.read_text(encoding="utf-8")
template_html = page_template_path.read_text(encoding="utf-8")
except Exception:
return HTMLResponse(content="internal server error", status_code=500)
try:
md_text = md_path.read_text(encoding="utf-8")
except Exception:
return HTMLResponse(content="internal server error", status_code=500)
if md_text.lstrip().startswith("---"):
parts = md_text.split("---", 2)
if len(parts) == 3:
md_text = parts[2].lstrip()
content_html = markdown.markdown(
md_text,
extensions=["fenced_code", "tables"],
output_format="html5",
)
html = (
template_html.replace("{{HEADER}}", header_html)
.replace("{{CONTENT}}", content_html)
.replace("{{TITLE}}", title)
)
return HTMLResponse(content=html)
def create_app(*, resources: Resources | None = None) -> FastAPI:
    """Build and configure the Photo Classification FastAPI application.

    Args:
        resources: Optional pre-built store/classifier/registry bundle. When
            given, it is attached to ``app.state.resources`` immediately, and the
            default ``lifespan`` (which constructs its own resources) is replaced
            by one that only republishes and then closes these on shutdown.

    Returns:
        The configured app with request-id middleware, the UI pages
        (``/favicon.ico``, ``/``, ``/readme``, ``/story``), a catch-all
        exception handler and the label-set and classify routers.
    """
    app_lifespan = lifespan
    if resources is not None:

        @asynccontextmanager
        async def _injected_lifespan(_app: FastAPI):
            _app.state.resources = resources
            async with AsyncExitStack() as stack:
                # Callbacks run LIFO: registry, classifier, then store (the same
                # order lifespan() uses), and a failure closing one does not
                # skip the others.
                stack.push_async_callback(_maybe_aclose, resources.store)
                stack.push_async_callback(_maybe_aclose, resources.classifier)
                stack.push_async_callback(_maybe_aclose, resources.registry)
                yield

        # Chosen before construction instead of patching app.router afterwards.
        app_lifespan = _injected_lifespan

    app = FastAPI(
        lifespan=app_lifespan,
        title="Photo Classification API",
        version="1.0.0",
        description=f"```\n{BANNER.strip()}\n```",
        docs_url="/docs",
        redoc_url=None,
        openapi_url="/openapi.json",
    )
    app.add_middleware(RequestIdMiddleware)
    if resources is not None:
        # Published eagerly so it is available even if the lifespan has not run.
        app.state.resources = resources

    @app.get("/favicon.ico", include_in_schema=False)
    def favicon() -> Response:
        svg = (
            "<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 64 64'>"
            "<rect width='64' height='64' rx='12' fill='#1f2937'/>"
            "<circle cx='24' cy='28' r='10' fill='#f59e0b'/>"
            "<circle cx='44' cy='28' r='10' fill='#60a5fa'/>"
            "<rect x='18' y='38' width='28' height='10' rx='5' fill='#e5e7eb'/>"
            "</svg>"
        )
        return Response(content=svg, media_type="image/svg+xml")

    @app.get("/", include_in_schema=False)
    def home() -> HTMLResponse:
        splash_path = Path(__file__).resolve().parent / "ui" / "splash.html"
        try:
            page = splash_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # Only file-access/decoding problems fall back to the inline page;
            # anything else reaches the app-wide handler below.
            page = f"<h1>Photo Classification</h1><p>Missing {splash_path}</p>"
        return HTMLResponse(content=page)

    @app.get("/readme", include_in_schema=False)
    def readme() -> HTMLResponse:
        readme_path = Path(__file__).resolve().parents[2] / "README.md"
        return render_page(readme_path, title="README")

    @app.get("/story", include_in_schema=False)
    def story() -> HTMLResponse:
        story_path = Path(__file__).resolve().parents[2] / "STORY.md"
        return render_page(story_path, title="Story")

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(request: Request, exc: Exception):
        rid = getattr(request.state, "request_id", None)
        log_json(
            logger,
            event="error.unhandled",
            request_id=rid,
            # str(exc) alone is empty for e.g. RuntimeError() and just "'x'"
            # for KeyError('x'); include the exception type.
            error=f"{type(exc).__name__}: {exc}",
            path=str(request.url.path),
        )
        return JSONResponse(status_code=500, content={"detail": "internal server error"})

    app.include_router(label_sets_router)
    app.include_router(classify_router)
    return app
def build_app(
    store: ClipStore | None = None,
    classifier: TwoStageClassifier | None = None,
    registry: LabelSetRegistry | None = None,
) -> FastAPI:
    """Build an app, optionally injecting some or all of its resources.

    With no arguments this is ``create_app()``: resources are created by the
    default lifespan at startup. If any argument is given, the missing ones are
    constructed here (the classifier from ``store``, the registry with no
    banks), and the full bundle is injected via ``create_app(resources=...)``.

    Missing arguments are detected with ``is None`` rather than truthiness, so
    a caller-supplied object that happens to be falsy (for example an empty
    store or registry defining ``__len__``) is used instead of silently replaced.
    """
    if store is None and classifier is None and registry is None:
        return create_app()
    if store is None:
        # NOTE(review): if only ``classifier`` was supplied, this is a fresh
        # ClipStore that the supplied classifier was not built with.
        store = ClipStore()
    if classifier is None:
        classifier = TwoStageClassifier(store=store)
    if registry is None:
        registry = LabelSetRegistry(banks={})
    resources = Resources(store=store, classifier=classifier, registry=registry)
    return create_app(resources=resources)
# Module-level application instance built with the default lifespan, which
# creates its own resources at startup (presumably the object an ASGI server
# is pointed at — confirm against the deployment config).
app = create_app()
|