# Page-header residue from the file viewer ("Spaces: Running"); not part of the module.
| """FastAPI wrapper around the summary uncertainty scoring service.""" | |
| from __future__ import annotations | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Callable, Protocol, Sequence | |
| import aiohttp | |
| logger = logging.getLogger(__name__) | |
| from fastapi import FastAPI, Header, HTTPException | |
| from pydantic import BaseModel, Field, field_validator | |
| from .dummy_backend import build_dummy_scorer | |
| from .normalization import QuantileNormalizer, load_quantile_normalizer | |
| from .scorer import SummaryScore | |
class ScoringService(Protocol):
    """Structural interface the API layer expects from a scoring backend.

    Any object exposing a compatible ``score_summary`` method satisfies this
    protocol; no inheritance is required.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Return the sentence-level uncertainty score for a displayed summary."""
        ...
class ScoreRequest(BaseModel):
    """Request payload for uncertainty scoring."""

    # Source text the summary was produced from; must be non-empty.
    source: str = Field(min_length=1)
    # Displayed summary to score; must be non-empty.
    summary: str = Field(min_length=1)
    # Optional sentence list forwarded to the scoring service.
    sentences: list[str] | None = None
    # Sample count forwarded to the scoring service (1-100, default 7).
    sample_count: int = Field(default=7, ge=1, le=100)
    # Optional top-k token limit forwarded to the scoring service (>= 1).
    top_k_tokens: int | None = Field(default=None, ge=1)
    # Optional seed forwarded to the scoring service.
    seed: int | None = None
    # When true, the response also carries consistency scores and bands.
    compute_consistency: bool = True

    # Without this registration pydantic never calls the method, so blank
    # sentences would reach the scoring service unfiltered.
    @field_validator("sentences")
    @classmethod
    def validate_sentences(cls, sentences: list[str] | None) -> list[str] | None:
        """Strip sentences and drop blank ones at the API boundary.

        Args:
            sentences: The submitted sentence list, or ``None`` when omitted.

        Returns:
            ``None`` when no list was given, otherwise the stripped,
            non-empty sentences in their original order.

        Raises:
            ValueError: If a list was given but nothing remains after
                stripping.
        """
        if sentences is None:
            return None
        stripped = (sentence.strip() for sentence in sentences)
        normalized_sentences = [sentence for sentence in stripped if sentence]
        if not normalized_sentences:
            raise ValueError("sentences must contain at least one non-empty sentence.")
        return normalized_sentences
class HealthResponse(BaseModel):
    """Body of the liveness-check reply."""

    # Liveness status string reported to the caller.
    status: str
class WakeResponse(BaseModel):
    """Body of the wake-up ping reply."""

    # Wake status string reported to the caller.
    status: str
class ReadinessResponse(BaseModel):
    """Body of the readiness-check reply."""

    # True once the scoring service has been loaded.
    ready: bool
def _resolve_ping_url(frontend_url: str) -> str:
    """Return the URL the keep-alive loop should request for *frontend_url*.

    ``https://huggingface.co/spaces/<owner>/<repo>`` style URLs are rewritten
    to ``https://<owner>-<repo>.hf.space/health``; any other URL is returned
    unchanged apart from surrounding whitespace.
    """
    marker = "huggingface.co/spaces/"
    url = frontend_url.strip()
    if marker not in url:
        return url

    # Drop any query string or fragment, then keep only the owner and repo
    # segments so a trailing slash or a deeper path (e.g. ``/tree/main``)
    # cannot leak into the host name.
    spaces_path = url.split(marker, 1)[1].split("?", 1)[0].split("#", 1)[0]
    segments = [segment for segment in spaces_path.split("/") if segment]
    if len(segments) < 2:
        logger.warning(
            "Cannot derive an hf.space URL from %r; pinging it unchanged", url
        )
        return url

    # NOTE(review): owner/repo names containing characters other than
    # letters, digits and hyphens may need further normalisation to form a
    # valid hf.space host name; confirm against the Spaces documentation.
    owner_repo = "-".join(segments[:2])
    return f"https://{owner_repo}.hf.space/health"


async def _keep_alive_task(frontend_url: str, interval_seconds: int = 60) -> None:
    """Periodically ping the frontend to keep it from going to sleep.

    Transforms huggingface.co/spaces URLs to the corresponding hf.space
    direct URL. Runs until the task is cancelled; ping failures are logged
    and never raised, because the keep-alive is best-effort.

    Args:
        frontend_url: Frontend address, either a huggingface.co/spaces page
            URL or a URL to request as-is.
        interval_seconds: Pause between consecutive pings.
    """
    await asyncio.sleep(5)  # Wait a bit before starting pings

    ping_url = _resolve_ping_url(frontend_url)
    request_timeout = aiohttp.ClientTimeout(total=10)  # loop-invariant

    async with aiohttp.ClientSession() as session:
        while True:
            try:
                logger.debug("Sending keep-alive ping to frontend")
                async with session.get(
                    ping_url,
                    timeout=request_timeout,
                    allow_redirects=False,
                ):
                    # Only the request itself matters; the body is ignored.
                    pass
            except asyncio.TimeoutError:
                logger.warning("Timeout pinging frontend at %s", ping_url)
            except Exception as error:
                # Cancellation is not an Exception subclass, so shutdown
                # still propagates through this handler.
                logger.debug("Keep-alive ping error (not critical): %s", error)
            await asyncio.sleep(interval_seconds)
def create_app(
    scoring_service_factory: Callable[[], ScoringService],
    *,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    api_token: str | None = None,
    title: str = "Summary Uncertainty API",
) -> FastAPI:
    """Create the FastAPI application with an injected scoring service factory.

    The factory is called in a background thread during lifespan startup so the
    server can accept /wake and /is-ready requests while the model loads.

    Args:
        scoring_service_factory: Zero-argument callable that builds the
            scoring service; it may block while a model loads.
        normalizer: Normalizer applied to raw sentence uncertainty.
        ambiguity_normalizer: Normalizer applied to ``expected_entropy``.
        consistency_normalizer: Normalizer applied to negated
            ``mean_logprob``.
        api_token: When truthy, POST /score requires a matching
            ``X-API-Token`` header; otherwise the endpoint is open.
        title: Title of the application, also reported by the root endpoint.

    Returns:
        The configured application with its routes registered.
    """
    # Local imports keep this block self-contained.
    import hmac
    import threading

    # Pre-encode the expected token once; bytes let compare_digest accept
    # non-ASCII header values as well.
    expected_token = api_token.strip().encode("utf-8") if api_token else None

    # The scorer used to run on the event loop, which serialised scoring
    # calls as a side effect. Scoring now runs in worker threads, so keep it
    # serialised explicitly; the backend is not known to be thread-safe.
    scoring_lock = threading.Lock()

    async def _load_service(app: FastAPI) -> None:
        """Build the scoring service off the event loop and mark the app ready."""
        loop = asyncio.get_running_loop()
        try:
            app.state.scoring_service = await loop.run_in_executor(
                None, scoring_service_factory
            )
        except Exception:
            # Nothing awaits this task, so the failure would otherwise vanish
            # and /is-ready would report false forever with no explanation.
            logger.exception("Scoring service failed to load")
            return
        app.state.ready = True
        logger.info("Scoring service ready")

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> Any:
        """Initialise app state, start background tasks, stop them on shutdown."""
        app.state.ready = False
        app.state.scoring_service = None
        app.state.normalizer = normalizer
        app.state.ambiguity_normalizer = ambiguity_normalizer
        app.state.consistency_normalizer = consistency_normalizer

        # Hold strong references: the event loop only keeps weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        background_tasks = [asyncio.create_task(_load_service(app))]
        frontend_url = os.environ.get("FRONTEND_URL")
        if frontend_url:
            background_tasks.append(
                asyncio.create_task(_keep_alive_task(frontend_url))
            )
        try:
            yield
        finally:
            for task in background_tasks:
                task.cancel()
            await asyncio.gather(*background_tasks, return_exceptions=True)

    app = FastAPI(title=title, lifespan=lifespan)

    @app.get("/")
    async def root() -> dict[str, Any]:
        """Return a brief description of the service and its endpoints."""
        return {
            "service": title,
            "description": (
                "Estimates per-sentence epistemic uncertainty for a given "
                "source text and its summary."
            ),
            "endpoints": {
                "GET /health": "Liveness check.",
                "GET /wake": "Wake-up ping; call on frontend start to trigger cold-start recovery.",
                "GET /is-ready": "Returns {ready: true} once the scoring model is fully loaded.",
                "POST /score": (
                    "Score a summary. Required fields: source (str), summary (str). "
                    "Optional: sample_count (int, 1-100), sentences (list[str]), "
                    "top_k_tokens (int), seed (int), compute_consistency (bool)."
                ),
            },
            "docs": "/docs",
        }

    @app.get("/health")
    async def health() -> HealthResponse:
        """Return API liveness information."""
        return HealthResponse(status="ok")

    @app.get("/wake")
    async def wake() -> WakeResponse:
        """Wake-up ping for cold-start scenarios.

        Call this endpoint when the frontend starts so that the API server
        is fully warmed up by the time the user submits their first request.
        """
        logger.info("GET /wake — server awake")
        return WakeResponse(status="awake")

    @app.get("/is-ready")
    async def is_ready() -> ReadinessResponse:
        """Return whether the scoring model has finished loading."""
        return ReadinessResponse(ready=getattr(app.state, "ready", False))

    # Declared with plain ``def`` so FastAPI runs it in its thread pool; the
    # blocking scorer no longer stalls /health, /wake and /is-ready.
    @app.post("/score")
    def score_summary(
        request: ScoreRequest,
        x_api_token: str | None = Header(default=None),
    ) -> dict[str, Any]:
        """Score the displayed summary without re-generating it.

        Raises:
            HTTPException: 401 for a missing or wrong token, 503 while the
                service is loading or unconfigured, 400 for invalid scoring
                input, 500 for any other scoring failure.
        """
        if expected_token is not None:
            supplied_token = (x_api_token or "").encode("utf-8")
            # Constant-time comparison avoids leaking the token via timing.
            if x_api_token is None or not hmac.compare_digest(
                supplied_token, expected_token
            ):
                raise HTTPException(status_code=401, detail="Invalid or missing API token.")
        if not getattr(app.state, "ready", False):
            raise HTTPException(status_code=503, detail="Scoring service is still loading.")
        logger.info(
            "POST /score — sample_count=%d top_k_tokens=%s seed=%s",
            request.sample_count,
            request.top_k_tokens,
            request.seed,
        )
        try:
            with scoring_lock:
                result = app.state.scoring_service.score_summary(
                    source=request.source,
                    summary=request.summary,
                    sentences=request.sentences,
                    sample_count=request.sample_count,
                    top_k_tokens=request.top_k_tokens,
                    seed=request.seed,
                )
        except ValueError as error:
            raise HTTPException(status_code=400, detail=str(error)) from error
        except NotImplementedError as error:
            raise HTTPException(status_code=503, detail=str(error)) from error
        except Exception as error:
            # Record the traceback server-side; the client only sees the message.
            logger.exception("POST /score failed")
            raise HTTPException(status_code=500, detail=str(error)) from error
        return _serialize_summary_score(
            result,
            app.state.normalizer,
            app.state.ambiguity_normalizer,
            app.state.consistency_normalizer,
            compute_consistency=request.compute_consistency,
        )

    return app
class UnconfiguredScoringService:
    """Placeholder scoring service that refuses every request.

    Selecting it makes a missing model configuration fail loudly instead of
    returning made-up scores.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Always raise ``NotImplementedError``; every argument is ignored."""
        # The parameters exist only to match the service interface.
        del source, summary, sentences, sample_count, top_k_tokens, seed
        message = (
            "No scoring service has been configured. "
            "Inject a SummaryUncertaintyScorer backed by the trained model and Laplace sampler."
        )
        raise NotImplementedError(message)
def _build_default_service() -> ScoringService:
    """Build the default scoring service from environment configuration.

    Recognised SCORING_BACKEND values:
    - ``dummy`` – rule-based mock, no model required (default)
    - ``mc_dropout`` – teacher-forced MC Dropout over a HuggingFace seq2seq model
    - ``lora_laplace`` – LoRA + diagonal Laplace approximation
    - ``unconfigured`` – raises on every request (forces explicit wiring)

    MC Dropout environment variables:
    - ``MC_DROPOUT_MODEL`` – HuggingFace model identifier (default: facebook/bart-large-cnn)
    - ``MC_DROPOUT_DEVICE`` – torch device string, e.g. ``cpu`` or ``cuda`` (auto-detected when unset)

    LoRA-Laplace environment variables:
    - ``LORA_BASE_MODEL`` – HuggingFace base model identifier (default: google/flan-t5-small)
    - ``LORA_ADAPTER_PATH`` – path to the PEFT adapter checkpoint directory (required)
    - ``LORA_SAMPLER_PATH`` – path to a pre-fitted laplace_sampler.npz (required);
      fit offline with compute_uncertainty_scores_lora_laplace.py --save-sampler
    - ``LORA_DEVICE`` – torch device string (auto-detected when unset)

    Returns:
        The scoring service selected by ``SCORING_BACKEND``.

    Raises:
        RuntimeError: If ``SCORING_BACKEND`` is not recognised, or a required
            LoRA-Laplace variable is unset or empty.
    """
    backend_name = os.environ.get("SCORING_BACKEND", "dummy").strip().lower()

    if backend_name == "dummy":
        return build_dummy_scorer()

    if backend_name == "mc_dropout":
        # Imported lazily so other backends do not need its dependencies.
        from .mc_dropout_backend import build_mc_dropout_scorer

        model_name = os.environ.get("MC_DROPOUT_MODEL", "facebook/bart-large-cnn")
        device = os.environ.get("MC_DROPOUT_DEVICE") or None
        return build_mc_dropout_scorer(model_name=model_name, device=device)

    if backend_name == "lora_laplace":
        base_model_name = os.environ.get("LORA_BASE_MODEL", "google/flan-t5-small")
        adapter_path = os.environ.get("LORA_ADAPTER_PATH", "")
        sampler_path = os.environ.get("LORA_SAMPLER_PATH", "")
        device = os.environ.get("LORA_DEVICE") or None

        # Validate configuration before the heavy imports so a missing
        # variable is reported immediately and as a configuration error.
        if not adapter_path:
            raise RuntimeError("LORA_ADAPTER_PATH must be set for the lora_laplace backend.")
        if not sampler_path:
            raise RuntimeError(
                "LORA_SAMPLER_PATH must be set for the lora_laplace backend. "
                "Fit the sampler offline with compute_uncertainty_scores_lora_laplace.py "
                "--save-sampler and provide the resulting .npz path here."
            )

        from peft import PeftModel
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        from .lora_laplace_backend import LoraLaplaceBackend, load_laplace_sampler
        from .scorer import SummaryUncertaintyScorer

        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        peft_model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        backend = LoraLaplaceBackend(peft_model=peft_model, tokenizer=tokenizer, device=device)
        sampler = load_laplace_sampler(sampler_path)
        return SummaryUncertaintyScorer(backend=backend, posterior_sampler=sampler)

    if backend_name == "unconfigured":
        return UnconfiguredScoringService()

    raise RuntimeError(f"Unsupported SCORING_BACKEND value: {backend_name}")
def _build_default_normalizer() -> QuantileNormalizer:
    """Load the uncertainty normalizer named by ``QUANTILE_CONFIG_PATH``.

    When the variable is unset, the bundled
    ``config/uncertainty_quantiles_mc_dropout.json`` two directories above
    this file is used.
    """
    project_root = Path(__file__).resolve().parent.parent
    bundled_config = project_root / "config" / "uncertainty_quantiles_mc_dropout.json"
    config_path = os.environ.get("QUANTILE_CONFIG_PATH", str(bundled_config))

    loaded = load_quantile_normalizer(config_path)
    formatted_boundaries = [f"{boundary:.4f}" for boundary in loaded.boundaries]
    logger.info(
        "Quantile normalizer loaded from %r — boundaries: %s",
        config_path,
        formatted_boundaries,
    )
    return loaded
def _build_default_consistency_normalizer() -> QuantileNormalizer:
    """Load the consistency normalizer, or fall back to the uncertainty one.

    The config path comes from ``CONSISTENCY_QUANTILE_CONFIG_PATH`` and
    defaults to the bundled ``config/consistency_quantiles_lora_laplace.json``.
    If that file does not exist, a warning is logged and the uncertainty
    normalizer is returned instead.

    Boundaries are over -mean_logprob (a positive value; higher = less
    consistent). Fit them from a calibration corpus with
    compute_uncertainty_scores_lora_laplace.py.
    """
    project_root = Path(__file__).resolve().parent.parent
    bundled_config = project_root / "config" / "consistency_quantiles_lora_laplace.json"
    config_path = os.environ.get("CONSISTENCY_QUANTILE_CONFIG_PATH", str(bundled_config))

    if Path(config_path).exists():
        loaded = load_quantile_normalizer(config_path)
        logger.info("Consistency normalizer loaded from %r", config_path)
        return loaded

    logger.warning(
        "Consistency quantile config not found at %r — falling back to uncertainty normalizer.",
        config_path,
    )
    return _build_default_normalizer()
def _build_default_ambiguity_normalizer() -> QuantileNormalizer:
    """Load the ambiguity normalizer, or fall back to the uncertainty one.

    The config path comes from ``AMBIGUITY_QUANTILE_CONFIG_PATH`` and
    defaults to the bundled ``config/ambiguity_quantiles_mc_dropout.json``.
    If that file does not exist, a warning is logged and the uncertainty
    normalizer is returned instead.
    """
    project_root = Path(__file__).resolve().parent.parent
    bundled_config = project_root / "config" / "ambiguity_quantiles_mc_dropout.json"
    config_path = os.environ.get("AMBIGUITY_QUANTILE_CONFIG_PATH", str(bundled_config))

    if Path(config_path).exists():
        loaded = load_quantile_normalizer(config_path)
        formatted_boundaries = [f"{boundary:.4f}" for boundary in loaded.boundaries]
        logger.info(
            "Ambiguity normalizer loaded from %r — boundaries: %s",
            config_path,
            formatted_boundaries,
        )
        return loaded

    logger.warning(
        "Ambiguity quantile config not found at %r — falling back to uncertainty normalizer.",
        config_path,
    )
    return _build_default_normalizer()
def _serialize_summary_score(
    summary_score: SummaryScore,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    compute_consistency: bool = True,
) -> dict[str, Any]:
    """Turn a summary score into a response dict with display-ready values.

    Starts from ``summary_score.to_dict()``, adds a ``normalization`` entry
    listing the boundaries in use, and extends each item of
    ``sentence_results`` with normalized scores and bands. Consistency
    fields and boundaries are added only when ``compute_consistency`` is
    true.
    """
    payload = summary_score.to_dict()

    boundaries_info: dict[str, Any] = {}
    boundaries_info["boundaries"] = list(normalizer.boundaries)
    boundaries_info["ambiguity_boundaries"] = list(ambiguity_normalizer.boundaries)
    if compute_consistency:
        boundaries_info["consistency_boundaries"] = list(consistency_normalizer.boundaries)
    payload["normalization"] = boundaries_info

    for entry in payload["sentence_results"]:
        uncertainty_value = float(entry["uncertainty"])
        entry["uncertainty_raw"] = uncertainty_value
        entry["uncertainty_score"] = normalizer.normalize(uncertainty_value)
        entry["uncertainty_band"] = normalizer.band(uncertainty_value)

        ambiguity_value = float(entry["expected_entropy"])
        entry["ambiguity_score"] = ambiguity_normalizer.normalize(ambiguity_value)
        entry["ambiguity_band"] = ambiguity_normalizer.band(ambiguity_value)

        if not compute_consistency:
            continue

        # mean_logprob is negative, so its negation grows as the sentence
        # becomes less consistent with the source. The normalized value is
        # flipped onto a "higher = more consistent" scale, and the band is
        # inverted to match.
        inconsistency_value = -float(entry["mean_logprob"])
        inconsistency_score = consistency_normalizer.normalize(inconsistency_value)
        entry["consistency_score"] = round(100.0 - inconsistency_score, 4)
        inconsistency_band = consistency_normalizer.band(inconsistency_value)
        entry["consistency_band"] = _invert_band(inconsistency_band)

    return payload
def _invert_band(band: str) -> str:
    """Swap ``"low"`` and ``"high"``; return any other band unchanged."""
    swapped = {"low": "high", "high": "low"}
    return swapped.get(band, band)
# An unset or empty API_TOKEN becomes None, which leaves POST /score open.
_api_token = os.getenv("API_TOKEN") or None

# Application instance built at import time. The normalizers are loaded
# here, while the scoring service factory is only passed along for
# create_app to call.
app = create_app(
    _build_default_service,
    normalizer=_build_default_normalizer(),
    ambiguity_normalizer=_build_default_ambiguity_normalizer(),
    consistency_normalizer=_build_default_consistency_normalizer(),
    api_token=_api_token,
)