sentence-uncertainty / src /api_server.py
rdisipio's picture
model is flan-t5-small
2dfb0f9 unverified
"""FastAPI wrapper around the summary uncertainty scoring service."""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import logging
import os
from pathlib import Path
from typing import Any, Callable, Protocol, Sequence
import aiohttp
logger = logging.getLogger(__name__)
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel, Field, field_validator
from .dummy_backend import build_dummy_scorer
from .normalization import QuantileNormalizer, load_quantile_normalizer
from .scorer import SummaryScore
class ScoringService(Protocol):
    """Structural interface the API layer needs from a scoring backend.

    Any object exposing a compatible ``score_summary`` method satisfies this
    protocol; no inheritance is required.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Return sentence-level uncertainty for an already-displayed *summary*.

        Args:
            source: The source text the summary was produced from.
            summary: The summary shown to the user (scored, not regenerated).
            sentences: Optional pre-split sentences of *summary*.
            sample_count: Number of posterior samples to draw.
            top_k_tokens: Optional top-k token restriction.
            seed: Optional random seed.
        """
        ...
class ScoreRequest(BaseModel):
    """Request body for ``POST /score``.

    Only ``source`` and ``summary`` are required; every other field has a default.
    """

    source: str = Field(min_length=1)
    summary: str = Field(min_length=1)
    # Optional pre-split sentences; cleaned by ``validate_sentences`` below.
    sentences: list[str] | None = None
    sample_count: int = Field(default=7, ge=1, le=100)
    top_k_tokens: int | None = Field(default=None, ge=1)
    seed: int | None = None
    # When false, consistency fields are omitted from the response.
    compute_consistency: bool = True

    @field_validator("sentences")
    @classmethod
    def validate_sentences(cls, sentences: list[str] | None) -> list[str] | None:
        """Strip whitespace from each sentence and drop entries that end up blank.

        ``None`` passes through untouched.

        Raises:
            ValueError: if a list is supplied but no entry is non-blank
                (including an empty list); pydantic surfaces this as a
                validation error.
        """
        if sentences is None:
            return sentences
        kept = [stripped for raw in sentences if (stripped := raw.strip())]
        if kept:
            return kept
        raise ValueError("sentences must contain at least one non-empty sentence.")
class HealthResponse(BaseModel):
    """Response body for ``GET /health`` (liveness check)."""

    # Liveness status; the /health endpoint always returns "ok".
    status: str
class WakeResponse(BaseModel):
    """Response body for ``GET /wake`` (cold-start wake-up ping)."""

    # Wake status; the /wake endpoint always returns "awake".
    status: str
class ReadinessResponse(BaseModel):
    """Response body for ``GET /is-ready``."""

    # True once the background scoring-service load has completed.
    ready: bool
def _resolve_keep_alive_url(frontend_url: str) -> str:
    """Map *frontend_url* to the URL the keep-alive loop should ping.

    A Hugging Face Space page URL such as
    ``https://huggingface.co/spaces/Owner/My_Repo/`` is rewritten to the Space's
    direct health endpoint, ``https://owner-my-repo.hf.space/health``. Any other
    URL is returned unchanged. A Space URL that lacks an owner or a repo segment
    is also returned unchanged.
    """
    marker = "huggingface.co/spaces/"
    _, found, spaces_path = frontend_url.partition(marker)
    if not found:
        return frontend_url
    # Only the owner and repo segments identify the Space. Ignore any fragment,
    # query string, trailing slash or extra sub-path (e.g. ".../tree/main").
    spaces_path = spaces_path.split("#", 1)[0].split("?", 1)[0]
    segments = [segment for segment in spaces_path.split("/") if segment]
    if len(segments) < 2:
        return frontend_url
    owner_repo = f"{segments[0]}-{segments[1]}".lower()
    # Hostname labels may only contain letters, digits and hyphens. Map anything
    # else (e.g. "_" or ".") to "-".
    subdomain = "".join(
        char if char.isascii() and char.isalnum() else "-" for char in owner_repo
    )
    return f"https://{subdomain}.hf.space/health"


async def _keep_alive_task(frontend_url: str, interval_seconds: int = 60) -> None:
    """Periodically ping the frontend to keep it from going to sleep.

    Transforms huggingface.co/spaces URLs to the corresponding hf.space direct
    health URL (see ``_resolve_keep_alive_url``); other URLs are pinged as given.
    Runs until cancelled. Ping failures are logged and never stop the loop.

    Args:
        frontend_url: Frontend URL taken from the ``FRONTEND_URL`` setting.
        interval_seconds: Delay between consecutive pings.
    """
    await asyncio.sleep(5)  # Wait a bit before starting pings
    ping_url = _resolve_keep_alive_url(frontend_url)
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                logger.debug("Sending keep-alive ping to frontend")
                async with session.get(
                    ping_url,
                    timeout=aiohttp.ClientTimeout(total=10),
                    allow_redirects=False,
                ) as response:
                    if response.status >= 400:
                        logger.debug(
                            "Keep-alive ping to %s returned HTTP %d",
                            ping_url,
                            response.status,
                        )
            except asyncio.TimeoutError:
                logger.warning("Timeout pinging frontend at %s", ping_url)
            except Exception as error:
                # Best effort by design: a failed ping must not kill the loop.
                logger.debug("Keep-alive ping error (not critical): %s", error)
            await asyncio.sleep(interval_seconds)
def create_app(
    scoring_service_factory: Callable[[], ScoringService],
    *,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    api_token: str | None = None,
    title: str = "Summary Uncertainty API",
) -> FastAPI:
    """Create the FastAPI application with an injected scoring service factory.

    The factory is called in a background thread during lifespan startup, so the
    server can accept /wake and /is-ready requests while the model loads. If the
    factory raises, the error is logged and /score answers 503 with the failure
    reason instead of reporting "still loading" forever.

    Args:
        scoring_service_factory: Zero-argument callable that builds the scoring
            service. It runs in a worker thread, off the event loop.
        normalizer: Quantile normalizer for raw ``uncertainty`` values.
        ambiguity_normalizer: Quantile normalizer for raw ``expected_entropy``.
        consistency_normalizer: Quantile normalizer for ``-mean_logprob``.
        api_token: When truthy, ``POST /score`` requires an ``x-api-token``
            header equal to the stripped token.
        title: Application title, also echoed by ``GET /``.

    Returns:
        The configured FastAPI application.
    """
    import secrets
    import threading

    # Strip once up front rather than on every request. Any truthy api_token,
    # even a whitespace-only one, keeps authentication enabled, as before.
    expected_token: bytes | None = api_token.strip().encode("utf-8") if api_token else None
    # The original code ran scoring directly on the event loop, which
    # serialized requests as a side effect. Scoring now runs in worker threads,
    # so this lock keeps calls into the backend one at a time; the backend is
    # not known to be thread-safe.
    score_lock = threading.Lock()

    def _token_is_valid(candidate: str | None) -> bool:
        """Check the request token against the configured one in constant time."""
        if expected_token is None:
            return True
        if candidate is None:
            return False
        return secrets.compare_digest(candidate.encode("utf-8"), expected_token)

    def _score_blocking(service: ScoringService, request: ScoreRequest) -> SummaryScore:
        """Run one synchronous scoring call, serialized by ``score_lock``."""
        with score_lock:
            return service.score_summary(
                source=request.source,
                summary=request.summary,
                sentences=request.sentences,
                sample_count=request.sample_count,
                top_k_tokens=request.top_k_tokens,
                seed=request.seed,
            )

    async def _load_service(app: FastAPI) -> None:
        """Build the scoring service off the event loop and mark the app ready."""
        try:
            service = await asyncio.to_thread(scoring_service_factory)
        except Exception as error:
            # Otherwise the exception dies unobserved inside the background task
            # and the service looks like it is loading forever.
            logger.exception("Failed to load scoring service")
            app.state.load_error = f"{type(error).__name__}: {error}"
            return
        app.state.scoring_service = service
        app.state.ready = True
        logger.info("Scoring service ready")

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> Any:
        app.state.ready = False
        app.state.load_error = None
        app.state.scoring_service = None
        app.state.normalizer = normalizer
        app.state.ambiguity_normalizer = ambiguity_normalizer
        app.state.consistency_normalizer = consistency_normalizer
        # Keep strong references: the event loop holds only weak references to
        # tasks, so an unreferenced task can be garbage-collected mid-flight.
        background_tasks = [asyncio.create_task(_load_service(app))]
        frontend_url = os.environ.get("FRONTEND_URL")
        if frontend_url:
            background_tasks.append(asyncio.create_task(_keep_alive_task(frontend_url)))
        app.state.background_tasks = background_tasks
        try:
            yield
        finally:
            # Stop the infinite keep-alive loop (and a still-pending load) on shutdown.
            for task in background_tasks:
                task.cancel()
            await asyncio.gather(*background_tasks, return_exceptions=True)

    app = FastAPI(title=title, lifespan=lifespan)

    @app.get("/")
    async def root() -> dict[str, Any]:
        """Return a brief description of the service and its endpoints."""
        return {
            "service": title,
            "description": (
                "Estimates per-sentence epistemic uncertainty for a given "
                "source text and its summary."
            ),
            "endpoints": {
                "GET /health": "Liveness check.",
                "GET /wake": "Wake-up ping; call on frontend start to trigger cold-start recovery.",
                "GET /is-ready": "Returns {ready: true} once the scoring model is fully loaded.",
                "POST /score": (
                    "Score a summary. Required fields: source (str), summary (str). "
                    "Optional: sample_count (int, 1-100), sentences (list[str]), "
                    "top_k_tokens (int), seed (int), compute_consistency (bool, default true)."
                ),
            },
            "docs": "/docs",
        }

    @app.get("/health", response_model=HealthResponse)
    async def health() -> HealthResponse:
        """Return API liveness information."""
        return HealthResponse(status="ok")

    @app.get("/wake", response_model=WakeResponse)
    async def wake() -> WakeResponse:
        """Wake-up ping for cold-start scenarios.

        Call this endpoint when the frontend starts so that the API server
        is fully warmed up by the time the user submits their first request.
        """
        logger.info("GET /wake — server awake")
        return WakeResponse(status="awake")

    @app.get("/is-ready", response_model=ReadinessResponse)
    async def is_ready() -> ReadinessResponse:
        """Return whether the scoring model has finished loading."""
        return ReadinessResponse(ready=getattr(app.state, "ready", False))

    @app.post("/score")
    async def score_summary(
        request: ScoreRequest,
        x_api_token: str | None = Header(default=None),
    ) -> dict[str, Any]:
        """Score the displayed summary without re-generating it.

        Error mapping: 401 for a bad or missing token; 503 while loading, after
        a failed load, or for NotImplementedError; 400 for ValueError; 500 otherwise.
        """
        if not _token_is_valid(x_api_token):
            raise HTTPException(status_code=401, detail="Invalid or missing API token.")
        if not getattr(app.state, "ready", False):
            load_error = getattr(app.state, "load_error", None)
            if load_error:
                raise HTTPException(
                    status_code=503,
                    detail=f"Scoring service failed to load: {load_error}",
                )
            raise HTTPException(status_code=503, detail="Scoring service is still loading.")
        logger.info(
            "POST /score — sample_count=%d top_k_tokens=%s seed=%s",
            request.sample_count,
            request.top_k_tokens,
            request.seed,
        )
        try:
            # Model inference is synchronous; run it in a worker thread so the
            # event loop keeps serving /health, /wake and /is-ready meanwhile.
            result = await asyncio.to_thread(
                _score_blocking, app.state.scoring_service, request
            )
        except ValueError as error:
            raise HTTPException(status_code=400, detail=str(error)) from error
        except NotImplementedError as error:
            raise HTTPException(status_code=503, detail=str(error)) from error
        except Exception as error:
            # Keep the traceback in server logs; HTTPException alone would drop it.
            logger.exception("Unexpected error while scoring summary")
            raise HTTPException(status_code=500, detail=str(error)) from error
        return _serialize_summary_score(
            result,
            app.state.normalizer,
            app.state.ambiguity_normalizer,
            app.state.consistency_normalizer,
            compute_consistency=request.compute_consistency,
        )

    return app
class UnconfiguredScoringService:
    """Placeholder backend that refuses to score until a real one is wired in.

    Selected by ``SCORING_BACKEND=unconfigured``; every call fails loudly rather
    than returning fabricated scores.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Always raise ``NotImplementedError``; every argument is ignored.

        The /score endpoint translates this exception into an HTTP 503 response.
        """
        del source, summary, sentences, sample_count, top_k_tokens, seed  # unused by design
        message = (
            "No scoring service has been configured. "
            "Inject a SummaryUncertaintyScorer backed by the trained model and Laplace sampler."
        )
        raise NotImplementedError(message)
def _build_default_service() -> ScoringService:
"""Build the default scoring service from environment configuration.
Recognised SCORING_BACKEND values:
- ``dummy`` – rule-based mock, no model required (default)
- ``mc_dropout`` – teacher-forced MC Dropout over a HuggingFace seq2seq model
- ``lora_laplace`` – LoRA + diagonal Laplace approximation
- ``unconfigured`` – raises on every request (forces explicit wiring)
MC Dropout environment variables:
- ``MC_DROPOUT_MODEL`` – HuggingFace model identifier (default: facebook/bart-large-cnn)
- ``MC_DROPOUT_DEVICE`` – torch device string, e.g. ``cpu`` or ``cuda`` (auto-detected when unset)
LoRA-Laplace environment variables:
- ``LORA_BASE_MODEL`` – HuggingFace base model identifier (default: google/flan-t5-small)
- ``LORA_ADAPTER_PATH`` – path to the PEFT adapter checkpoint directory (required)
- ``LORA_SAMPLER_PATH`` – path to a pre-fitted laplace_sampler.npz (required);
fit offline with compute_uncertainty_scores_lora_laplace.py --save-sampler
- ``LORA_DEVICE`` – torch device string (auto-detected when unset)
"""
backend_name = os.environ.get("SCORING_BACKEND", "dummy").strip().lower()
if backend_name == "dummy":
return build_dummy_scorer()
if backend_name == "mc_dropout":
from .mc_dropout_backend import build_mc_dropout_scorer
model_name = os.environ.get("MC_DROPOUT_MODEL", "facebook/bart-large-cnn")
device = os.environ.get("MC_DROPOUT_DEVICE") or None
return build_mc_dropout_scorer(model_name=model_name, device=device)
if backend_name == "lora_laplace":
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .lora_laplace_backend import LoraLaplaceBackend, load_laplace_sampler
from .scorer import SummaryUncertaintyScorer
base_model_name = os.environ.get("LORA_BASE_MODEL", "google/flan-t5-small")
adapter_path = os.environ.get("LORA_ADAPTER_PATH", "")
sampler_path = os.environ.get("LORA_SAMPLER_PATH", "")
device = os.environ.get("LORA_DEVICE") or None
if not adapter_path:
raise RuntimeError("LORA_ADAPTER_PATH must be set for the lora_laplace backend.")
if not sampler_path:
raise RuntimeError(
"LORA_SAMPLER_PATH must be set for the lora_laplace backend. "
"Fit the sampler offline with compute_uncertainty_scores_lora_laplace.py "
"--save-sampler and provide the resulting .npz path here."
)
import torch
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
peft_model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
backend = LoraLaplaceBackend(peft_model=peft_model, tokenizer=tokenizer, device=device)
sampler = load_laplace_sampler(sampler_path)
return SummaryUncertaintyScorer(backend=backend, posterior_sampler=sampler)
if backend_name == "unconfigured":
return UnconfiguredScoringService()
raise RuntimeError(f"Unsupported SCORING_BACKEND value: {backend_name}")
def _build_default_normalizer() -> QuantileNormalizer:
    """Load the configured uncertainty normalizer.

    The path comes from ``QUANTILE_CONFIG_PATH``. When that variable is unset or
    empty, the bundled ``config/uncertainty_quantiles_mc_dropout.json`` is used;
    its ``config`` directory sits in the parent of this file's directory.

    Returns:
        The normalizer produced by ``load_quantile_normalizer``.
    """
    default_path = (
        Path(__file__).resolve().parent.parent / "config" / "uncertainty_quantiles_mc_dropout.json"
    )
    # ``or`` instead of a ``get`` default: an empty variable would otherwise be
    # handed to the loader as an empty path.
    config_path = os.environ.get("QUANTILE_CONFIG_PATH") or str(default_path)
    normalizer = load_quantile_normalizer(config_path)
    logger.info(
        "Quantile normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer
def _load_normalizer_or_fallback(*, env_var: str, default_filename: str, label: str) -> QuantileNormalizer:
    """Load a quantile normalizer, falling back to the uncertainty normalizer.

    Args:
        env_var: Environment variable holding the config path. When it is unset
            or empty, ``config/<default_filename>`` is used; the ``config``
            directory sits in the parent of this file's directory.
        default_filename: File name of the bundled default config.
        label: Capitalised name used in log messages (e.g. "Ambiguity").

    Returns:
        The loaded normalizer. If the config file does not exist, the result of
        ``_build_default_normalizer()`` is returned instead.
    """
    default_path = Path(__file__).resolve().parent.parent / "config" / default_filename
    # An empty variable must not become Path(""), which is "." and always exists.
    # That would skip the fallback and hand "" to the loader.
    config_path = os.environ.get(env_var) or str(default_path)
    if not Path(config_path).exists():
        logger.warning(
            "%s quantile config not found at %r — falling back to uncertainty normalizer.",
            label,
            config_path,
        )
        return _build_default_normalizer()
    normalizer = load_quantile_normalizer(config_path)
    logger.info(
        "%s normalizer loaded from %r — boundaries: %s",
        label,
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer


def _build_default_consistency_normalizer() -> QuantileNormalizer:
    """Load the configured consistency normalizer, falling back to the uncertainty normalizer.

    Boundaries are over -mean_logprob (a positive value; higher = less consistent).
    Fit them from a calibration corpus with compute_uncertainty_scores_lora_laplace.py
    and save to the path set in CONSISTENCY_QUANTILE_CONFIG_PATH.
    """
    return _load_normalizer_or_fallback(
        env_var="CONSISTENCY_QUANTILE_CONFIG_PATH",
        default_filename="consistency_quantiles_lora_laplace.json",
        label="Consistency",
    )


def _build_default_ambiguity_normalizer() -> QuantileNormalizer:
    """Load the configured ambiguity normalizer, falling back to the uncertainty normalizer.

    The path comes from AMBIGUITY_QUANTILE_CONFIG_PATH, defaulting to the bundled
    ambiguity_quantiles_mc_dropout.json.
    """
    return _load_normalizer_or_fallback(
        env_var="AMBIGUITY_QUANTILE_CONFIG_PATH",
        default_filename="ambiguity_quantiles_mc_dropout.json",
        label="Ambiguity",
    )
def _serialize_summary_score(
summary_score: SummaryScore,
normalizer: QuantileNormalizer,
ambiguity_normalizer: QuantileNormalizer,
consistency_normalizer: QuantileNormalizer,
compute_consistency: bool = True,
) -> dict[str, Any]:
"""Serialize a summary score and attach display-oriented uncertainty values."""
payload = summary_score.to_dict()
normalization: dict[str, Any] = {
"boundaries": list(normalizer.boundaries),
"ambiguity_boundaries": list(ambiguity_normalizer.boundaries),
}
if compute_consistency:
normalization["consistency_boundaries"] = list(consistency_normalizer.boundaries)
payload["normalization"] = normalization
for sentence_result in payload["sentence_results"]:
raw_uncertainty = float(sentence_result["uncertainty"])
sentence_result["uncertainty_raw"] = raw_uncertainty
sentence_result["uncertainty_score"] = normalizer.normalize(raw_uncertainty)
sentence_result["uncertainty_band"] = normalizer.band(raw_uncertainty)
raw_ambiguity = float(sentence_result["expected_entropy"])
sentence_result["ambiguity_score"] = ambiguity_normalizer.normalize(raw_ambiguity)
sentence_result["ambiguity_band"] = ambiguity_normalizer.band(raw_ambiguity)
if compute_consistency:
# consistency_score: higher = more consistent with the source.
# mean_logprob is negative; negate it so higher raw = less consistent,
# then invert the 0-100 scale so the final score reads naturally.
raw_inconsistency = -float(sentence_result["mean_logprob"])
inconsistency_normalized = consistency_normalizer.normalize(raw_inconsistency)
sentence_result["consistency_score"] = round(100.0 - inconsistency_normalized, 4)
sentence_result["consistency_band"] = _invert_band(
consistency_normalizer.band(raw_inconsistency)
)
return payload
def _invert_band(band: str) -> str:
if band == "low":
return "high"
if band == "high":
return "low"
return band
# Module-level wiring: the ASGI ``app`` is built from environment configuration
# at import time. Keyword arguments are evaluated left to right, so the
# uncertainty, ambiguity and consistency normalizers load (and log) in that order.
# An unset or empty API_TOKEN becomes None, which leaves POST /score unauthenticated.
_api_token = os.environ.get("API_TOKEN") or None
app = create_app(
    # Passed uncalled: create_app runs the factory in a background thread at startup.
    _build_default_service,
    normalizer=_build_default_normalizer(),
    ambiguity_normalizer=_build_default_ambiguity_normalizer(),
    consistency_normalizer=_build_default_consistency_normalizer(),
    api_token=_api_token,
)