NoobNovel's picture
HF Space deployment: Attention Visualizer (FastAPI + React)
4e8e113
"""
main.py — FastAPI application for the Attention Visualization Tool.

Routes
------
GET  /api/health  → health check, returns {"status": "ok"}
GET  /api/models  → list of available models + metadata
POST /api/attend  → run inference, return tokens + attention weights
GET  /            → serve React frontend (after build)
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from models import MODEL_REGISTRY
from attention import get_attention
# ── Logging ───────────────────────────────────────────────────────────────────
# Configure root logging once at import time: INFO and above, rendered as
# "<timestamp> [<LEVEL>] <logger name> — <message>".
# The separator is a real em dash; the previous literal was a mis-decoded
# UTF-8 sequence that would have printed as garbage in every log line.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ── App ───────────────────────────────────────────────────────────────────────
# The ASGI application object; all routes below are registered on it.
app = FastAPI(
    version="1.0.0",
    title="Attention Visualizer API",
    description="Extracts and serves transformer attention weights for visualization.",
)

# Wildcard CORS: every origin, method and header is accepted. This covers
# whichever localhost port the Vite dev server ends up on (5173, 5174,
# 5175, …) but is NOT limited to localhost — tighten `allow_origins`
# before running this in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Schemas ───────────────────────────────────────────────────────────────────
# Request body accepted by POST /api/attend.
# (Documented with comments rather than a class docstring so the generated
# JSON schema for this model stays exactly as it was.)
class AttendRequest(BaseModel):
    # Input text; required, between 1 and 512 characters long.
    text: str = Field(
        default=...,
        min_length=1,
        max_length=512,
        example="The cat sat on the mat.",
    )
    # Required model identifier; forwarded unchanged to get_attention().
    model_id: str = Field(
        default=...,
        example="bert-base-uncased",
    )
# ── Routes ────────────────────────────────────────────────────────────────────
# Health-check endpoint: always answers with the fixed payload {"status": "ok"}.
@app.get("/api/health")
def health():
    return dict(status="ok")
@app.get("/api/models")
def list_models():
    """Return the list of available models with metadata."""
    # The registry imported from `models` is returned as-is; FastAPI takes
    # care of serialising it into the JSON response.
    registry = MODEL_REGISTRY
    return registry
@app.post("/api/attend")
def attend(req: AttendRequest):
    """
    Run a forward pass through the requested model and return
    tokenized text plus all attention weight matrices.
    """
    # Only the first 80 characters of the text are logged, to keep log
    # lines short.
    logger.info("attend → model=%s text=%r", req.model_id, req.text[:80])
    try:
        result = get_attention(req.text, req.model_id)
    except ValueError as exc:
        # A ValueError from get_attention is reported as a client error.
        # NOTE(review): presumably raised for bad input such as an unknown
        # model_id — confirm against attention.get_attention.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        # Anything else is an unexpected failure: log the traceback and
        # surface a 500. `from exc` keeps the original error attached as
        # the explicit cause of the HTTPException.
        logger.exception("Inference error")
        raise HTTPException(
            status_code=500, detail=f"Inference error: {exc}"
        ) from exc
    return result
# ── Serve built React frontend ────────────────────────────────────────────────
# Location of the production build of the React frontend:
# <parent of this file's directory>/frontend/dist
FRONTEND_DIST = Path(__file__).parent.parent / "frontend" / "dist"

# The frontend is only served when a build output directory is present.
if FRONTEND_DIST.exists():
    _assets_dir = FRONTEND_DIST / "assets"
    # StaticFiles raises at construction time when its directory is missing,
    # which would crash the whole app on import for a build that has no
    # assets/ folder — so only mount when the folder is really there.
    if _assets_dir.is_dir():
        app.mount(
            "/assets",
            StaticFiles(directory=str(_assets_dir)),
            name="assets",
        )

    @app.get("/", include_in_schema=False)
    @app.get("/{full_path:path}", include_in_schema=False)
    def serve_spa(full_path: str = ""):
        """Serve the SPA entry point for any non-API GET path.

        Every path that is not handled by an earlier route falls through to
        this catch-all and receives ``index.html``, leaving routing to the
        frontend.

        Args:
            full_path: The requested path without its leading slash
                (empty for ``/``).

        Returns:
            A ``FileResponse`` for ``index.html`` when it exists, otherwise
            a JSON body explaining that the frontend has not been built.

        Raises:
            HTTPException: 404 when the path lies under ``/api``; unknown
                API paths must not be answered with the HTML shell.
        """
        # Without this guard an unknown /api/... path (or a GET to a
        # POST-only API route) would get index.html with status 200.
        if full_path == "api" or full_path.startswith("api/"):
            raise HTTPException(status_code=404, detail="Not Found")

        index = FRONTEND_DIST / "index.html"
        if index.exists():
            return FileResponse(str(index))
        return {"detail": "Frontend not built. Run: cd frontend && npm run build"}