Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from io import StringIO | |
| from pathlib import Path | |
| from dataclasses import asdict | |
| import os | |
| import torch | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import FileResponse | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from huggingface_hub import attach_huggingface_oauth | |
| from app.api.auth import get_optional_user, require_user | |
| from app.core.config import get_settings | |
| from app.core.runtime import load_model_bundle | |
| from app.core.runtime_pipeline import compute_attribution_analysis | |
| from app.core.schemas import ( | |
| AnalysisRequest, | |
| AnalysisResult, | |
| CurrentUserResponse, | |
| HealthResponse, | |
| SessionCreateRequest, | |
| SessionResponse, | |
| SessionResultResponse, | |
| WarmupResponse, | |
| ) | |
| from app.services.sessions import SessionAccessError, SessionLimitError, SessionService | |
| from app.storage.repository import SessionRepository | |
| from app.workers.jobs import build_job_runner | |
logger = logging.getLogger(__name__)

# Frontend assets live in a ``frontend`` directory next to the directory that
# contains this module (i.e. two levels up from this file, then ``frontend``).
FRONTEND_DIR = Path(__file__).resolve().parent.parent / "frontend"
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Create shared services for the app's lifetime and clean them up on exit.

    On startup this builds the ``SessionRepository``, the background job
    runner and the ``SessionService`` from the current settings, and stores
    them on ``_app.state``. When ``settings.preload_model`` is enabled, the
    configured model bundle is loaded before the app starts serving.

    ``jobs.shutdown()`` runs in a ``finally`` clause. The job runner is
    therefore torn down even when model preloading fails or the server exits
    because of an exception, not only on a clean shutdown.

    The function is decorated with ``asynccontextmanager`` because FastAPI's
    ``lifespan=`` argument expects an async context manager factory. Passing
    a bare async generator function relies on a deprecated Starlette
    compatibility path.
    """
    settings = get_settings()
    repository = SessionRepository(settings.database_path)
    jobs = build_job_runner(settings.job_workers)
    try:
        _app.state.repository = repository
        _app.state.jobs = jobs
        _app.state.session_service = SessionService(
            settings=settings, repository=repository, jobs=jobs
        )
        if settings.preload_model:
            logger.info(
                "Preloading model '%s' on device preference '%s'.",
                settings.model_name,
                settings.device_preference,
            )
            # Called without ``await``: startup waits here until the bundle
            # has been loaded.
            load_model_bundle(
                settings.model_name,
                device_preference=settings.device_preference,
                dtype_preference=settings.dtype_preference,
                attn_implementation=settings.attn_implementation,
                trust_remote_code=settings.trust_remote_code,
                low_cpu_mem_usage=settings.low_cpu_mem_usage,
            )
        yield
    finally:
        jobs.shutdown()
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="CoT Attribution Analysis API",
)

# Hugging Face OAuth support is attached only when running inside a Space,
# detected by either of the Space ID environment variables being non-empty.
if any(os.getenv(env_var) for env_var in ("SPACE_ID", "HF_SPACE_ID")):
    attach_huggingface_oauth(app)

# Serve the files in FRONTEND_DIR under the /static prefix.
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
def get_session_service() -> SessionService:
    """Return the ``SessionService`` that the lifespan stored on ``app.state``."""
    state = app.state
    return state.session_service
def _to_session_response(payload: dict) -> SessionResponse:
    """Build a ``SessionResponse`` from a session payload dict.

    Required keys are looked up with ``[]`` and raise ``KeyError`` if absent.
    Optional keys default to ``None`` when missing.
    """
    required_keys = ("id", "status", "question", "model_name", "created_at", "updated_at")
    optional_keys = (
        "error",
        "answer",
        "raw_trace_text",
        "normalized_trace_text",
        "sentences",
        "generation_metadata",
    )
    response_kwargs = {key: payload[key] for key in required_keys}
    response_kwargs.update({key: payload.get(key) for key in optional_keys})
    return SessionResponse(**response_kwargs)
def index() -> FileResponse:
    """Serve ``index.html`` from the frontend directory."""
    index_page = FRONTEND_DIR.joinpath("index.html")
    return FileResponse(index_page)
def healthz() -> HealthResponse:
    """Report liveness together with the effective runtime configuration.

    The response also says whether torch can see a CUDA device or the Apple
    MPS backend in this process.
    """
    cfg = get_settings()
    has_cuda = torch.cuda.is_available()
    has_mps = torch.backends.mps.is_available()
    return HealthResponse(
        status="ok",
        model_name=cfg.model_name,
        device_preference=cfg.device_preference,
        dtype_preference=cfg.dtype_preference,
        preload_model=cfg.preload_model,
        cuda_available=has_cuda,
        mps_available=has_mps,
        require_auth=cfg.require_auth,
        public_api_enabled=cfg.public_api_enabled,
        max_queued_jobs=cfg.max_queued_jobs,
        max_active_jobs_per_user=cfg.max_active_jobs_per_user,
    )
def me(request: Request) -> CurrentUserResponse:
    """Describe the caller's authentication state.

    This also works for anonymous callers. When ``get_optional_user`` yields
    no (or a falsy) user, every profile field in the response is ``None``.
    """
    cfg = get_settings()
    current = get_optional_user(request)
    if current:
        username = current.username
        full_name = current.display_name
        avatar_url = current.avatar_url
    else:
        username = full_name = avatar_url = None
    return CurrentUserResponse(
        authenticated=current is not None,
        auth_required=cfg.require_auth,
        username=username,
        full_name=full_name,
        avatar_url=avatar_url,
    )
def warmup(model_name: str | None = None, device_preference: str | None = None) -> WarmupResponse:
    """Load a model bundle via ``load_model_bundle`` and report what was loaded.

    A falsy ``model_name`` or ``device_preference`` falls back to the
    configured default. The dtype, attention implementation,
    ``trust_remote_code`` and ``low_cpu_mem_usage`` always come from settings.

    NOTE(review): unlike ``analyze`` and the session handlers in this module,
    this handler does not call ``require_user``. The caller still chooses
    ``model_name``, and loading honours ``settings.trust_remote_code``.
    Confirm that unauthenticated warmup of arbitrary models is intended.
    """
    cfg = get_settings()
    target_model = model_name or cfg.model_name
    target_device = device_preference or cfg.device_preference
    bundle = load_model_bundle(
        target_model,
        device_preference=target_device,
        dtype_preference=cfg.dtype_preference,
        attn_implementation=cfg.attn_implementation,
        trust_remote_code=cfg.trust_remote_code,
        low_cpu_mem_usage=cfg.low_cpu_mem_usage,
    )
    capability = asdict(bundle.capability)
    return WarmupResponse(
        status="ready",
        model_name=bundle.model_name,
        device=str(bundle.device),
        dtype=str(bundle.dtype),
        capability=capability,
    )
def analyze(request: AnalysisRequest, http_request: Request) -> AnalysisResult:
    """Run an attribution analysis inline for an authenticated caller.

    ``require_user`` is called before the ``try``. Whatever it raises for a
    rejected caller therefore propagates unchanged instead of being converted
    into a 500. Any failure inside the analysis itself is logged with its
    traceback and reported as HTTP 500, with the exception text as the detail.
    """
    require_user(http_request, get_settings())
    try:
        analysis_options = {
            "question": request.question,
            "model_name": request.model_name,
            "take_log": request.take_log,
            "max_sentences": request.max_sentences,
            "max_trace_tokens": request.max_trace_tokens,
            "validate_top_k": request.validate_top_k,
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "device_preference": request.device_preference,
            "dtype_preference": request.dtype_preference,
            "attn_implementation": request.attn_implementation,
            "trust_remote_code": request.trust_remote_code,
            "low_cpu_mem_usage": request.low_cpu_mem_usage,
        }
        return compute_attribution_analysis(**analysis_options)
    except Exception as exc:  # pragma: no cover - runtime path
        logger.exception("Analysis request failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def list_sessions(request: Request, limit: int = 20) -> list[SessionResponse]:
    """List the authenticated caller's sessions as ``SessionResponse`` objects.

    NOTE(review): ``limit`` is caller-controlled and forwarded to the service
    unchecked. Confirm that the service or repository bounds it, including
    zero or negative values.
    """
    owner = require_user(request, get_settings())
    session_payloads = get_session_service().list_sessions(owner.id, limit=limit)
    return list(map(_to_session_response, session_payloads))
def create_session(request: SessionCreateRequest, http_request: Request) -> SessionResponse:
    """Create a new analysis session owned by the authenticated caller.

    The creation fields are copied one-for-one into an ``AnalysisRequest``.
    A ``SessionLimitError`` from the service is reported as HTTP 429. On
    success the freshly stored session payload is returned.
    """
    settings = get_settings()
    owner = require_user(http_request, settings)
    service = get_session_service()
    analysis_request = AnalysisRequest(
        question=request.question,
        model_name=request.model_name,
        take_log=request.take_log,
        max_sentences=request.max_sentences,
        max_trace_tokens=request.max_trace_tokens,
        validate_top_k=request.validate_top_k,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
        device_preference=request.device_preference,
        dtype_preference=request.dtype_preference,
        attn_implementation=request.attn_implementation,
        trust_remote_code=request.trust_remote_code,
        low_cpu_mem_usage=request.low_cpu_mem_usage,
    )
    try:
        created = service.create_session(
            analysis_request,
            owner_id=owner.id,
            owner_name=owner.display_name,
        )
    except SessionLimitError as exc:
        raise HTTPException(status_code=429, detail=str(exc)) from exc
    stored_payload = service.get_session_payload(created.id, owner_id=owner.id)
    return _to_session_response(stored_payload)
def get_session(session_id: str, request: Request) -> SessionResponse:
    """Fetch one of the caller's sessions.

    An unknown session (``KeyError``) maps to 404. An ownership violation
    (``SessionAccessError``) maps to 403.
    """
    settings = get_settings()
    owner = require_user(request, settings)
    service = get_session_service()
    try:
        session_payload = service.get_session_payload(session_id, owner_id=owner.id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Session not found") from exc
    except SessionAccessError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    else:
        return _to_session_response(session_payload)
def analyze_session(session_id: str, request: Request) -> SessionResponse:
    """Start analysis for one of the caller's sessions and return its updated state.

    Error mapping:
      * unknown session (``KeyError``) -> 404
      * ownership violation (``SessionAccessError``) -> 403
      * ``SessionLimitError`` -> 429, matching ``create_session``

    The 429 mapping exists so that a limit hit while starting the job is not
    an unhandled 500. It is presumably raised by ``start_analysis`` when job
    limits are exceeded; confirm against ``SessionService``. It is listed
    after the existing handlers so they keep precedence for any exception
    the original code already mapped.
    """
    settings = get_settings()
    user = require_user(request, settings)
    service = get_session_service()
    try:
        session = service.start_analysis(session_id, owner_id=user.id)
        payload = service.get_session_payload(session.id, owner_id=user.id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Session not found") from exc
    except SessionAccessError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    except SessionLimitError as exc:
        raise HTTPException(status_code=429, detail=str(exc)) from exc
    return _to_session_response(payload)
def get_session_result(session_id: str, request: Request) -> SessionResultResponse:
    """Return a session together with its analysis, if one has been stored.

    ``analysis`` is ``None`` when the payload has no ``"analysis"`` entry or
    the entry is empty. An unknown session maps to 404 and an ownership
    violation maps to 403.
    """
    settings = get_settings()
    owner = require_user(request, settings)
    service = get_session_service()
    try:
        session_payload = service.get_session_payload(session_id, owner_id=owner.id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Session not found") from exc
    except SessionAccessError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    session_view = _to_session_response(session_payload)
    stored_analysis = session_payload.get("analysis")
    analysis = None
    if stored_analysis:
        analysis = AnalysisResult.model_validate(stored_analysis)
    return SessionResultResponse(session=session_view, analysis=analysis)
def export_session_json(session_id: str, request: Request) -> StreamingResponse:
    """Download a session and its analysis as a pretty-printed JSON attachment.

    The lookup, authorization, error mapping (404 / 403) and result assembly
    are delegated to ``get_session_result``. Previously that logic was
    duplicated here verbatim; sharing it keeps the JSON export and the result
    endpoint from drifting apart.
    """
    result = get_session_result(session_id, request)
    return StreamingResponse(
        iter([result.model_dump_json(indent=2)]),
        media_type="application/json",
        headers={"content-disposition": f'attachment; filename="{session_id}.json"'},
    )
def export_session_csv(session_id: str, request: Request) -> StreamingResponse:
    """Download the analysis' top edges as a CSV attachment.

    The columns are ``source_sentence_idx,target_sentence_idx,score``, and
    scores are written with six decimal places. A missing analysis
    (``KeyError``) maps to 404 and an ownership violation maps to 403.
    """
    settings = get_settings()
    owner = require_user(request, settings)
    service = get_session_service()
    try:
        analysis = service.get_analysis_result(session_id, owner_id=owner.id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Analysis result not found") from exc
    except SessionAccessError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    csv_lines = ["source_sentence_idx,target_sentence_idx,score\n"]
    csv_lines.extend(
        f"{edge.source_sentence_idx},{edge.target_sentence_idx},{edge.score:.6f}\n"
        for edge in analysis.top_edges
    )
    return StreamingResponse(
        iter(["".join(csv_lines)]),
        media_type="text/csv",
        headers={"content-disposition": f'attachment; filename="{session_id}.csv"'},
    )