cot-anc / app /api /main.py
BART-ender's picture
Deploy Thought Anchors
fda8fb3 verified
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from io import StringIO
from pathlib import Path
from dataclasses import asdict
import os
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import attach_huggingface_oauth
from app.api.auth import get_optional_user, require_user
from app.core.config import get_settings
from app.core.runtime import load_model_bundle
from app.core.runtime_pipeline import compute_attribution_analysis
from app.core.schemas import (
AnalysisRequest,
AnalysisResult,
CurrentUserResponse,
HealthResponse,
SessionCreateRequest,
SessionResponse,
SessionResultResponse,
WarmupResponse,
)
from app.services.sessions import SessionAccessError, SessionLimitError, SessionService
from app.storage.repository import SessionRepository
from app.workers.jobs import build_job_runner
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Static frontend assets: the "frontend" directory that sits next to this
# file's parent package (app/api/main.py -> app/frontend).
FRONTEND_DIR = Path(__file__).resolve().parent.parent / "frontend"
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Set up shared services for the app's lifetime and release them on exit.

    On startup this stores a ``SessionRepository``, the job runner and a
    ``SessionService`` on ``_app.state``. When ``settings.preload_model`` is
    set, it also loads the configured model bundle before yielding control to
    the server.

    The job runner is shut down in a ``finally`` clause so that it is released
    even when model preloading raises or the lifespan is exited with an
    exception; otherwise those paths would skip ``jobs.shutdown()``.
    """
    settings = get_settings()
    repository = SessionRepository(settings.database_path)
    jobs = build_job_runner(settings.job_workers)
    try:
        _app.state.repository = repository
        _app.state.jobs = jobs
        _app.state.session_service = SessionService(
            settings=settings, repository=repository, jobs=jobs
        )
        if settings.preload_model:
            logger.info(
                "Preloading model '%s' on device preference '%s'.",
                settings.model_name,
                settings.device_preference,
            )
            # Eager load at startup, presumably so the first request does not
            # pay the model load cost.
            load_model_bundle(
                settings.model_name,
                device_preference=settings.device_preference,
                dtype_preference=settings.dtype_preference,
                attn_implementation=settings.attn_implementation,
                trust_remote_code=settings.trust_remote_code,
                low_cpu_mem_usage=settings.low_cpu_mem_usage,
            )
        yield
    finally:
        jobs.shutdown()
# Application object; ``lifespan`` wires shared services onto ``app.state``.
app = FastAPI(
    lifespan=lifespan,
    title="CoT Attribution Analysis API",
    version="0.1.0",
)
# Attach Hugging Face OAuth only when a Space id is present in the environment
# (presumably meaning the app is running on Hugging Face Spaces).
if any(os.getenv(env_var) for env_var in ("SPACE_ID", "HF_SPACE_ID")):
    attach_huggingface_oauth(app)
# Frontend assets are served under /static.
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
def get_session_service() -> SessionService:
    """Return the ``SessionService`` that ``lifespan`` stored on ``app.state``."""
    service: SessionService = app.state.session_service
    return service
def _to_session_response(payload: dict) -> SessionResponse:
    """Convert a stored session payload into a ``SessionResponse``.

    The keys ``id``, ``status``, ``question``, ``model_name``, ``created_at``
    and ``updated_at`` must be present; a missing one raises ``KeyError``.
    Every other field defaults to ``None`` when absent from ``payload``.
    """
    required_keys = ("id", "status", "question", "model_name", "created_at", "updated_at")
    optional_keys = (
        "error",
        "answer",
        "raw_trace_text",
        "normalized_trace_text",
        "sentences",
        "generation_metadata",
    )
    fields = {key: payload[key] for key in required_keys}
    fields.update({key: payload.get(key) for key in optional_keys})
    return SessionResponse(**fields)
@app.get("/", include_in_schema=False)
def index() -> FileResponse:
    """Serve the frontend's ``index.html`` (hidden from the OpenAPI schema)."""
    index_path = FRONTEND_DIR / "index.html"
    return FileResponse(path=index_path)
@app.get("/healthz", response_model=HealthResponse)
def healthz() -> HealthResponse:
    """Report liveness plus the active model, device, auth and queue settings."""
    cfg = get_settings()
    cuda_ok = torch.cuda.is_available()
    mps_ok = torch.backends.mps.is_available()
    return HealthResponse(
        status="ok",
        model_name=cfg.model_name,
        device_preference=cfg.device_preference,
        dtype_preference=cfg.dtype_preference,
        preload_model=cfg.preload_model,
        cuda_available=cuda_ok,
        mps_available=mps_ok,
        require_auth=cfg.require_auth,
        public_api_enabled=cfg.public_api_enabled,
        max_queued_jobs=cfg.max_queued_jobs,
        max_active_jobs_per_user=cfg.max_active_jobs_per_user,
    )
@app.get("/api/me", response_model=CurrentUserResponse)
def me(request: Request) -> CurrentUserResponse:
    """Describe the caller's login state and whether auth is required.

    Profile fields are populated only when ``get_optional_user`` yields a
    truthy user; otherwise they are ``None``.
    """
    settings = get_settings()
    user = get_optional_user(request)
    if user:
        username = user.username
        full_name = user.display_name
        avatar_url = user.avatar_url
    else:
        username = full_name = avatar_url = None
    return CurrentUserResponse(
        authenticated=user is not None,
        auth_required=settings.require_auth,
        username=username,
        full_name=full_name,
        avatar_url=avatar_url,
    )
@app.post("/api/warmup", response_model=WarmupResponse)
def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse:
    """Load a model bundle and report the device/dtype it ended up on.

    ``model_name`` and ``device_preference`` override the configured defaults
    when truthy. dtype, attention implementation, ``trust_remote_code`` and
    ``low_cpu_mem_usage`` always come from settings.
    """
    # NOTE(review): this route never calls require_user (unlike analyze and the
    # session routes), yet model_name is caller-supplied. Confirm that
    # unauthenticated callers should be able to trigger model loads.
    settings = get_settings()
    chosen_model = model_name or settings.model_name
    chosen_device = device_preference or settings.device_preference
    bundle = load_model_bundle(
        chosen_model,
        device_preference=chosen_device,
        dtype_preference=settings.dtype_preference,
        attn_implementation=settings.attn_implementation,
        trust_remote_code=settings.trust_remote_code,
        low_cpu_mem_usage=settings.low_cpu_mem_usage,
    )
    capability = asdict(bundle.capability)
    return WarmupResponse(
        status="ready",
        model_name=bundle.model_name,
        device=str(bundle.device),
        dtype=str(bundle.dtype),
        capability=capability,
    )
@app.post("/api/analyze", response_model=AnalysisResult)
def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult:
    """Run attribution analysis inline and return its result.

    ``require_user`` is checked first. Each field named in ``forwarded`` is
    then passed by keyword to ``compute_attribution_analysis``. Any exception
    raised while doing so is logged and re-raised as HTTP 500 with the
    exception text as detail.
    """
    require_user(http_request, get_settings())
    # NOTE(review): model_name and trust_remote_code come straight from the
    # client here; confirm AnalysisRequest constrains them, otherwise a caller
    # could load arbitrary remote code.
    forwarded = (
        "question",
        "model_name",
        "take_log",
        "max_sentences",
        "max_trace_tokens",
        "validate_top_k",
        "max_new_tokens",
        "temperature",
        "top_p",
        "device_preference",
        "dtype_preference",
        "attn_implementation",
        "trust_remote_code",
        "low_cpu_mem_usage",
    )
    try:
        options = {name: getattr(request, name) for name in forwarded}
        return compute_attribution_analysis(**options)
    except Exception as exc:  # pragma: no cover - runtime path
        logger.exception("Analysis request failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.get("/api/sessions", response_model=list[SessionResponse])
def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]:
    """Return the caller's sessions as reported by ``SessionService.list_sessions``.

    ``limit`` is forwarded unchanged to the service.
    """
    owner = require_user(request, get_settings())
    rows = get_session_service().list_sessions(owner.id, limit=limit)
    return list(map(_to_session_response, rows))
@app.post("/api/sessions", response_model=SessionResponse)
def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse:
    """Create a session owned by the caller and return its stored view.

    The creation request is copied field by field into an ``AnalysisRequest``.
    A ``SessionLimitError`` from the service becomes HTTP 429.
    """
    user = require_user(http_request, get_settings())
    service = get_session_service()
    copied_fields = (
        "question",
        "model_name",
        "take_log",
        "max_sentences",
        "max_trace_tokens",
        "validate_top_k",
        "max_new_tokens",
        "temperature",
        "top_p",
        "device_preference",
        "dtype_preference",
        "attn_implementation",
        "trust_remote_code",
        "low_cpu_mem_usage",
    )
    analysis_request = AnalysisRequest(
        **{name: getattr(request, name) for name in copied_fields}
    )
    try:
        session = service.create_session(
            analysis_request,
            owner_id=user.id,
            owner_name=user.display_name,
        )
    except SessionLimitError as exc:
        raise HTTPException(status_code=429, detail=str(exc)) from exc
    stored = service.get_session_payload(session.id, owner_id=user.id)
    return _to_session_response(stored)
@app.get("/api/sessions/{session_id}", response_model=SessionResponse)
def get_session(session_id: str, request: Request) -> SessionResponse:
    """Fetch a single session on behalf of the caller.

    ``KeyError`` from the service maps to 404 and ``SessionAccessError`` to 403.
    """
    owner = require_user(request, get_settings())
    sessions = get_session_service()
    try:
        payload = sessions.get_session_payload(session_id, owner_id=owner.id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail="Session not found") from missing
    except SessionAccessError as denied:
        raise HTTPException(status_code=403, detail=str(denied)) from denied
    else:
        return _to_session_response(payload)
@app.post("/api/sessions/{session_id}/analyze", response_model=SessionResponse)
def analyze_session(session_id: str, request: Request) -> SessionResponse:
    """Start analysis for an existing session and return the refreshed session.

    Service errors map to HTTP statuses:
    ``KeyError`` -> 404, ``SessionAccessError`` -> 403, and
    ``SessionLimitError`` -> 429. The 429 mapping mirrors ``create_session``,
    so a job-limit rejection surfaces as "too many requests" instead of an
    unhandled 500.
    """
    settings = get_settings()
    user = require_user(request, settings)
    service = get_session_service()
    try:
        session = service.start_analysis(session_id, owner_id=user.id)
        payload = service.get_session_payload(session.id, owner_id=user.id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Session not found") from exc
    except SessionAccessError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    except SessionLimitError as exc:
        # NOTE(review): assumes start_analysis may enforce the same job limits as
        # create_session; this clause is harmless if it never raises.
        raise HTTPException(status_code=429, detail=str(exc)) from exc
    return _to_session_response(payload)
@app.get("/api/sessions/{session_id}/result", response_model=SessionResultResponse)
def get_session_result(session_id: str, request: Request) -> SessionResultResponse:
    """Return a session together with its analysis result, if one is stored.

    ``analysis`` is ``None`` when the payload's ``"analysis"`` entry is missing
    or falsy. ``KeyError`` maps to 404 and ``SessionAccessError`` to 403.
    """
    owner = require_user(request, get_settings())
    service = get_session_service()
    try:
        payload = service.get_session_payload(session_id, owner_id=owner.id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail="Session not found") from missing
    except SessionAccessError as denied:
        raise HTTPException(status_code=403, detail=str(denied)) from denied
    session_view = _to_session_response(payload)
    raw_analysis = payload.get("analysis")
    if raw_analysis:
        analysis = AnalysisResult.model_validate(raw_analysis)
    else:
        analysis = None
    return SessionResultResponse(session=session_view, analysis=analysis)
@app.get("/api/sessions/{session_id}/export.json")
def export_session_json(session_id: str, request: Request) -> StreamingResponse:
    """Download a session and its analysis as an indented JSON attachment.

    ``KeyError`` maps to 404 and ``SessionAccessError`` to 403.
    """
    owner = require_user(request, get_settings())
    service = get_session_service()
    try:
        payload = service.get_session_payload(session_id, owner_id=owner.id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail="Session not found") from missing
    except SessionAccessError as denied:
        raise HTTPException(status_code=403, detail=str(denied)) from denied
    session_view = _to_session_response(payload)
    analysis = None
    if payload.get("analysis"):
        analysis = AnalysisResult.model_validate(payload["analysis"])
    document = SessionResultResponse(session=session_view, analysis=analysis).model_dump_json(indent=2)
    disposition = f'attachment; filename="{session_id}.json"'
    return StreamingResponse(
        iter((document,)),
        media_type="application/json",
        headers={"content-disposition": disposition},
    )
@app.get("/api/sessions/{session_id}/export.csv")
def export_session_csv(session_id: str, request: Request) -> StreamingResponse:
    """Download the analysis's top edges as a CSV attachment.

    Columns are source_sentence_idx, target_sentence_idx and score, with score
    formatted to six decimals and rows terminated by ``\\n``. ``KeyError``
    maps to 404 and ``SessionAccessError`` to 403.
    """
    owner = require_user(request, get_settings())
    service = get_session_service()
    try:
        result = service.get_analysis_result(session_id, owner_id=owner.id)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail="Analysis result not found") from missing
    except SessionAccessError as denied:
        raise HTTPException(status_code=403, detail=str(denied)) from denied
    lines = ["source_sentence_idx,target_sentence_idx,score\n"]
    lines.extend(
        f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n"
        for edge in result.top_edges
    )
    csv_text = "".join(lines)
    disposition = f'attachment; filename="{session_id}.csv"'
    return StreamingResponse(
        iter([csv_text]),
        media_type="text/csv",
        headers={"content-disposition": disposition},
    )