Spaces:
Sleeping
Sleeping
| """ | |
main.py — FastAPI application for the Attention Visualization Tool.

Routes
------
GET  /api/models  → list of available models + metadata
POST /api/attend  → run inference, return tokens + attention weights
GET  /            → serve React frontend (after build)
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel, Field | |
| from models import MODEL_REGISTRY | |
| from attention import get_attention | |
# ── Logging ───────────────────────────────────────────────────────────────────
# Configure the root logger once, at import time. Any logger that propagates to
# the root (the default) uses this format. Note that basicConfig() does nothing
# if the root logger already has handlers attached.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
)
logger = logging.getLogger(__name__)
# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Attention Visualizer API",
    description="Extracts and serves transformer attention weights for visualization.",
    version="1.0.0",
)

# CORS: by default every origin is allowed ("*"). That covers the Vite dev
# server on whatever localhost port it picks (5173, 5174, 5175, …), but also
# any other site. To tighten this for production, set CORS_ALLOW_ORIGINS to a
# comma-separated list of origins, e.g.
#     CORS_ALLOW_ORIGINS="https://example.com,http://localhost:5173"
# Blank entries are ignored; an empty value falls back to "*".
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Schemas ───────────────────────────────────────────────────────────────────
class AttendRequest(BaseModel):
    """Request body accepted by the ``attend`` endpoint."""

    # Input text to analyse. Pydantic enforces 1..512 *characters* (not model
    # tokens); longer or empty strings are rejected with a 422 before the
    # handler runs.
    text: str = Field(
        ...,
        min_length=1,
        max_length=512,
        examples=["The cat sat on the mat."],
    )
    # Identifier of the model to run. Presumably must match an entry in
    # MODEL_REGISTRY; verify against get_attention().
    model_id: str = Field(..., examples=["bert-base-uncased"])
# ── Routes ────────────────────────────────────────────────────────────────────
# NOTE(review): without a route decorator this function was never exposed. The
# path is not listed in the module docstring; /api/health is assumed to match
# the other /api/* routes. Confirm against the frontend/deployment health check.
@app.get("/api/health")
def health():
    """Report that the API process is up and answering requests."""
    return {"status": "ok"}
@app.get("/api/models")
def list_models():
    """Return the list of available models with metadata.

    MODEL_REGISTRY is returned as-is; FastAPI serialises it to JSON.
    """
    return MODEL_REGISTRY
@app.post("/api/attend")
def attend(req: AttendRequest):
    """
    Run a forward pass through the requested model and return
    tokenized text plus all attention weight matrices.

    Kept as a plain ``def``: FastAPI runs sync handlers in a worker thread
    pool, so blocking model inference does not stall the event loop.

    Raises:
        HTTPException(400): ``get_attention`` raised ``ValueError``; its
            message is returned as ``detail``.
        HTTPException(500): any other exception during inference. The
            traceback is logged and the exception text is echoed in ``detail``.
    """
    # Log only the first 80 characters so long inputs don't flood the logs.
    logger.info("attend — model=%s text=%r", req.model_id, req.text[:80])
    try:
        result = get_attention(req.text, req.model_id)
    except ValueError as exc:
        # Treated as a client error, presumably e.g. an unknown model_id;
        # verify which inputs get_attention rejects this way.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Inference error")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
    return result
# ── Serve built React frontend ────────────────────────────────────────────────
FRONTEND_DIST = Path(__file__).parent.parent / "frontend" / "dist"

# Vite's default build puts the bundled JS/CSS under dist/assets. StaticFiles
# raises at construction time if its directory is missing, so only mount it
# when the folder actually exists (a partial build must not crash startup).
_FRONTEND_ASSETS = FRONTEND_DIST / "assets"
if _FRONTEND_ASSETS.is_dir():
    app.mount(
        "/assets",
        StaticFiles(directory=str(_FRONTEND_ASSETS)),
        name="assets",
    )


def _static_file_in_dist(rel_path: str) -> Path | None:
    """Return the existing file at *rel_path* inside FRONTEND_DIST, else None.

    Returns None for an empty path, a missing file, and any path that
    resolves outside FRONTEND_DIST (``..`` segments, absolute paths, symlinks
    pointing elsewhere), so request paths can never read arbitrary files.
    """
    if not rel_path:
        return None
    root = FRONTEND_DIST.resolve()
    try:
        candidate = (root / rel_path).resolve()
        candidate.relative_to(root)  # ValueError when the path escapes dist/
        return candidate if candidate.is_file() else None
    except (OSError, ValueError):
        # Malformed or out-of-tree paths are treated as "no such file".
        return None


# Catch-all GET route. It is registered after the API routes above, so those
# take precedence, and it is hidden from the OpenAPI schema.
@app.get("/{full_path:path}", include_in_schema=False)
def serve_spa(full_path: str = ""):
    """Serve the built single-page app for any non-API GET path.

    * ``api`` / ``api/...`` paths that reached this route are unknown API
      endpoints: respond 404 rather than masking them with the HTML shell.
    * A path naming an existing file inside FRONTEND_DIST (e.g. a favicon in
      the dist root) is served as that file.
    * Anything else gets ``index.html`` so client-side routing can handle it.
    * If ``index.html`` does not exist, a JSON hint on how to build it is
      returned instead.
    """
    if full_path.split("/", 1)[0] == "api":
        raise HTTPException(status_code=404, detail="Not Found")

    static_file = _static_file_in_dist(full_path)
    if static_file is not None:
        return FileResponse(str(static_file))

    index = FRONTEND_DIST / "index.html"
    if index.exists():
        return FileResponse(str(index))
    return {"detail": "Frontend not built. Run: cd frontend && npm run build"}