from __future__ import annotations import logging from contextlib import asynccontextmanager from io import StringIO from pathlib import Path from dataclasses import asdict import os import torch from fastapi import FastAPI, HTTPException, Request from fastapi.responses import FileResponse from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles from huggingface_hub import attach_huggingface_oauth from app.api.auth import get_optional_user, require_user from app.core.config import get_settings from app.core.runtime import load_model_bundle from app.core.runtime_pipeline import compute_attribution_analysis from app.core.schemas import ( AnalysisRequest, AnalysisResult, CurrentUserResponse, HealthResponse, SessionCreateRequest, SessionResponse, SessionResultResponse, WarmupResponse, ) from app.services.sessions import SessionAccessError, SessionLimitError, SessionService from app.storage.repository import SessionRepository from app.workers.jobs import build_job_runner logger = logging.getLogger(__name__) FRONTEND_DIR = Path(__file__).resolve().parents[1] / "frontend" @asynccontextmanager async def lifespan(_app: FastAPI): settings = get_settings() repository = SessionRepository(settings.database_path) jobs = build_job_runner(settings.job_workers) _app.state.repository = repository _app.state.jobs = jobs _app.state.session_service = SessionService(settings=settings, repository=repository, jobs=jobs) if settings.preload_model: logger.info( "Preloading model '%s' on device preference '%s'.", settings.model_name, settings.device_preference, ) load_model_bundle( settings.model_name, device_preference=settings.device_preference, dtype_preference=settings.dtype_preference, attn_implementation=settings.attn_implementation, trust_remote_code=settings.trust_remote_code, low_cpu_mem_usage=settings.low_cpu_mem_usage, ) yield jobs.shutdown() app = FastAPI( title="CoT Attribution Analysis API", version="0.1.0", lifespan=lifespan, ) if os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"): attach_huggingface_oauth(app) app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static") def get_session_service() -> SessionService: return app.state.session_service def _to_session_response(payload: dict) -> SessionResponse: return SessionResponse( id=payload["id"], status=payload["status"], question=payload["question"], model_name=payload["model_name"], error=payload.get("error"), created_at=payload["created_at"], updated_at=payload["updated_at"], answer=payload.get("answer"), raw_trace_text=payload.get("raw_trace_text"), normalized_trace_text=payload.get("normalized_trace_text"), sentences=payload.get("sentences"), generation_metadata=payload.get("generation_metadata"), ) @app.get("/", include_in_schema=False) def index() -> FileResponse: return FileResponse(FRONTEND_DIR / "index.html") @app.get("/healthz", response_model=HealthResponse) def healthz() -> HealthResponse: settings = get_settings() return HealthResponse( status="ok", model_name=settings.model_name, device_preference=settings.device_preference, dtype_preference=settings.dtype_preference, preload_model=settings.preload_model, cuda_available=torch.cuda.is_available(), mps_available=torch.backends.mps.is_available(), require_auth=settings.require_auth, public_api_enabled=settings.public_api_enabled, max_queued_jobs=settings.max_queued_jobs, max_active_jobs_per_user=settings.max_active_jobs_per_user, ) @app.get("/api/me", response_model=CurrentUserResponse) def me(request: Request) -> CurrentUserResponse: settings = get_settings() user = get_optional_user(request) return CurrentUserResponse( authenticated=user is not None, auth_required=settings.require_auth, username=user.username if user else None, full_name=user.display_name if user else None, avatar_url=user.avatar_url if user else None, ) @app.post("/api/warmup", response_model=WarmupResponse) def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse: settings = get_settings() bundle = load_model_bundle( model_name or settings.model_name, device_preference=device_preference or settings.device_preference, dtype_preference=settings.dtype_preference, attn_implementation=settings.attn_implementation, trust_remote_code=settings.trust_remote_code, low_cpu_mem_usage=settings.low_cpu_mem_usage, ) return WarmupResponse( status="ready", model_name=bundle.model_name, device=str(bundle.device), dtype=str(bundle.dtype), capability=asdict(bundle.capability), ) @app.post("/api/analyze", response_model=AnalysisResult) def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult: settings = get_settings() require_user(http_request, settings) try: return compute_attribution_analysis( question=request.question, model_name=request.model_name, take_log=request.take_log, max_sentences=request.max_sentences, max_trace_tokens=request.max_trace_tokens, validate_top_k=request.validate_top_k, max_new_tokens=request.max_new_tokens, temperature=request.temperature, top_p=request.top_p, device_preference=request.device_preference, dtype_preference=request.dtype_preference, attn_implementation=request.attn_implementation, trust_remote_code=request.trust_remote_code, low_cpu_mem_usage=request.low_cpu_mem_usage, ) except Exception as exc: # pragma: no cover - runtime path logger.exception("Analysis request failed") raise HTTPException(status_code=500, detail=str(exc)) from exc @app.get("/api/sessions", response_model=list[SessionResponse]) def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]: settings = get_settings() user = require_user(request, settings) service = get_session_service() payloads = service.list_sessions(user.id, limit=limit) return [_to_session_response(payload) for payload in payloads] @app.post("/api/sessions", response_model=SessionResponse) def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse: settings = get_settings() user = require_user(http_request, settings) service = get_session_service() try: session = service.create_session( AnalysisRequest( question=request.question, model_name=request.model_name, take_log=request.take_log, max_sentences=request.max_sentences, max_trace_tokens=request.max_trace_tokens, validate_top_k=request.validate_top_k, max_new_tokens=request.max_new_tokens, temperature=request.temperature, top_p=request.top_p, device_preference=request.device_preference, dtype_preference=request.dtype_preference, attn_implementation=request.attn_implementation, trust_remote_code=request.trust_remote_code, low_cpu_mem_usage=request.low_cpu_mem_usage, ), owner_id=user.id, owner_name=user.display_name, ) except SessionLimitError as exc: raise HTTPException(status_code=429, detail=str(exc)) from exc payload = service.get_session_payload(session.id, owner_id=user.id) return _to_session_response(payload) @app.get("/api/sessions/{session_id}", response_model=SessionResponse) def get_session(session_id: str, request: Request) -> SessionResponse: settings = get_settings() user = require_user(request, settings) service = get_session_service() try: payload = service.get_session_payload(session_id, owner_id=user.id) except KeyError as exc: raise HTTPException(status_code=404, detail="Session not found") from exc except SessionAccessError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc return _to_session_response(payload) @app.post("/api/sessions/{session_id}/analyze", response_model=SessionResponse) def analyze_session(session_id: str, request: Request) -> SessionResponse: settings = get_settings() user = require_user(request, settings) service = get_session_service() try: session = service.start_analysis(session_id, owner_id=user.id) payload = service.get_session_payload(session.id, owner_id=user.id) except KeyError as exc: raise HTTPException(status_code=404, detail="Session not found") from exc except SessionAccessError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc return _to_session_response(payload) @app.get("/api/sessions/{session_id}/result", response_model=SessionResultResponse) def get_session_result(session_id: str, request: Request) -> SessionResultResponse: settings = get_settings() user = require_user(request, settings) service = get_session_service() try: payload = service.get_session_payload(session_id, owner_id=user.id) except KeyError as exc: raise HTTPException(status_code=404, detail="Session not found") from exc except SessionAccessError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc session_response = _to_session_response(payload) analysis_payload = payload.get("analysis") return SessionResultResponse( session=session_response, analysis=AnalysisResult.model_validate(analysis_payload) if analysis_payload else None, ) @app.get("/api/sessions/{session_id}/export.json") def export_session_json(session_id: str, request: Request) -> StreamingResponse: settings = get_settings() user = require_user(request, settings) service = get_session_service() try: payload = service.get_session_payload(session_id, owner_id=user.id) except KeyError as exc: raise HTTPException(status_code=404, detail="Session not found") from exc except SessionAccessError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc result = SessionResultResponse( session=_to_session_response(payload), analysis=AnalysisResult.model_validate(payload["analysis"]) if payload.get("analysis") else None, ) return StreamingResponse( iter([result.model_dump_json(indent=2)]), media_type="application/json", headers={"content-disposition": f'attachment; filename="{session_id}.json"'}, ) @app.get("/api/sessions/{session_id}/export.csv") def export_session_csv(session_id: str, request: Request) -> StreamingResponse: settings = get_settings() user = require_user(request, settings) service = get_session_service() try: result = service.get_analysis_result(session_id, owner_id=user.id) except KeyError as exc: raise HTTPException(status_code=404, detail="Analysis result not found") from exc except SessionAccessError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc buffer = StringIO() buffer.write("source_sentence_idx,target_sentence_idx,score\n") for edge in result.top_edges: buffer.write(f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n") return StreamingResponse( iter([buffer.getvalue()]), media_type="text/csv", headers={"content-disposition": f'attachment; filename="{session_id}.csv"'}, )