Spaces:
Running
Running
File size: 17,241 Bytes
c03fd30 af331f5 c03fd30 bc60f4f 694ac62 9d23e11 af331f5 c03fd30 9c6a9a7 bc60f4f 8b650a9 c03fd30 694ac62 9d23e11 c03fd30 2dfb0f9 c03fd30 329c3b8 c03fd30 fa0c40c af331f5 9c6a9a7 8fe9283 9c6a9a7 8fe9283 9c6a9a7 25740c7 8fe9283 70f02c8 a4dfc8b 9c6a9a7 8fe9283 9c6a9a7 70f02c8 9c6a9a7 c07e2e5 c03fd30 af331f5 c03fd30 9d23e11 288990e 63cc4ab c07e2e5 c03fd30 af331f5 c03fd30 10d217f af331f5 10d217f 288990e 63cc4ab 10d217f 9c6a9a7 c03fd30 2db0d34 af331f5 2db0d34 c03fd30 fa0c40c 72717fc fa0c40c af331f5 8b650a9 c04fa04 8b650a9 c03fd30 c04fa04 8b650a9 10d217f bc60f4f c03fd30 63cc4ab 329c3b8 63cc4ab c03fd30 694ac62 d6e5d06 ecb6675 d6e5d06 ecb6675 a813743 ecb6675 d6e5d06 694ac62 d6e5d06 ecb6675 a813743 ecb6675 ecfab30 ecb6675 8297b42 ecb6675 694ac62 9d23e11 b36f3d0 9d23e11 dac93ad 14f6420 531fb2e 9d23e11 63cc4ab 288990e 9d23e11 288990e 63cc4ab 329c3b8 9d23e11 329c3b8 c5a5a68 329c3b8 9d23e11 b36f3d0 9d23e11 288990e 329c3b8 63cc4ab 9d23e11 63cc4ab cd9bcc1 9d23e11 af331f5 9d23e11 288990e 63cc4ab cd9bcc1 9d23e11 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 
275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 | """FastAPI wrapper around the summary uncertainty scoring service."""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import logging
import os
from pathlib import Path
from typing import Any, Callable, Protocol, Sequence
import aiohttp
logger = logging.getLogger(__name__)
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel, Field, field_validator
from .dummy_backend import build_dummy_scorer
from .normalization import QuantileNormalizer, load_quantile_normalizer
from .scorer import SummaryScore
class ScoringService(Protocol):
    """Structural interface the API layer needs from a scoring backend.

    Any object exposing a compatible ``score_summary`` method satisfies this
    protocol; subclassing it is not required.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Return sentence-level uncertainty for an already-displayed *summary* of *source*."""
        ...
class ScoreRequest(BaseModel):
    """Body of ``POST /score``: source text, displayed summary, and sampling options."""

    source: str = Field(min_length=1)
    summary: str = Field(min_length=1)
    # Optional caller-provided sentence split; cleaned by ``validate_sentences``.
    sentences: list[str] | None = None
    sample_count: int = Field(default=7, ge=1, le=100)
    top_k_tokens: int | None = Field(default=None, ge=1)
    seed: int | None = None
    # When False, consistency scores/bands are left out of the serialized response.
    compute_consistency: bool = True

    @field_validator("sentences")
    @classmethod
    def validate_sentences(cls, sentences: list[str] | None) -> list[str] | None:
        """Strip each sentence and drop blank ones.

        Validation fails only when every supplied sentence is blank. Note that
        blank entries are silently removed, not rejected individually.
        """
        if sentences is None:
            return None
        cleaned: list[str] = []
        for raw_sentence in sentences:
            stripped = raw_sentence.strip()
            if stripped:
                cleaned.append(stripped)
        if cleaned:
            return cleaned
        raise ValueError("sentences must contain at least one non-empty sentence.")
class HealthResponse(BaseModel):
    """Body returned by ``GET /health``."""

    status: str  # the endpoint always reports "ok"


class WakeResponse(BaseModel):
    """Body returned by ``GET /wake``."""

    status: str  # the endpoint always reports "awake"


class ReadinessResponse(BaseModel):
    """Body returned by ``GET /is-ready``."""

    ready: bool  # flips to True once the scoring service has been built
def _resolve_ping_url(frontend_url: str) -> str:
    """Map a ``huggingface.co/spaces/<owner>/<repo>`` page URL to its direct health URL.

    ``https://huggingface.co/spaces/Owner/My_Repo/tree/main`` becomes
    ``https://owner-my-repo.hf.space/health``. Any URL that is not a Space page
    URL is returned unchanged.
    """
    marker = "huggingface.co/spaces/"
    if marker not in frontend_url:
        return frontend_url
    spaces_path = frontend_url.split(marker, 1)[1]
    # Drop query/fragment and keep only the owner/repo segments so trailing
    # slashes or extra path parts (e.g. "/tree/main") never leak into the host.
    spaces_path = spaces_path.split("?", 1)[0].split("#", 1)[0]
    segments = [segment for segment in spaces_path.split("/") if segment]
    if len(segments) < 2:
        logger.warning("Cannot derive hf.space URL from %r; pinging it as-is.", frontend_url)
        return frontend_url
    # "_" is not valid in hostnames and "." would add subdomain levels; hf.space
    # subdomains use "-" for both (NOTE(review): matches HF's Space host scheme
    # as far as known — confirm if pings start failing).
    subdomain = "-".join(segments[:2]).replace("_", "-").replace(".", "-").lower()
    return f"https://{subdomain}.hf.space/health"


async def _keep_alive_task(frontend_url: str, interval_seconds: int = 60) -> None:
    """Periodically ping the frontend to keep it from going to sleep.

    Space page URLs (``huggingface.co/spaces/<owner>/<repo>``) are rewritten to
    the Space's direct ``hf.space`` ``/health`` URL; other URLs are pinged as-is.
    Ping failures are logged and never stop the loop; the coroutine runs until
    it is cancelled.

    Args:
        frontend_url: Frontend URL (a Space page URL or a direct URL).
        interval_seconds: Delay between pings, in seconds.
    """
    await asyncio.sleep(5)  # Wait a bit before starting pings
    ping_url = _resolve_ping_url(frontend_url)
    request_timeout = aiohttp.ClientTimeout(total=10)
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                logger.debug("Sending keep-alive ping to frontend at %s", ping_url)
                async with session.get(
                    ping_url,
                    timeout=request_timeout,
                    allow_redirects=False,
                ) as response:
                    logger.debug("Keep-alive ping to %s returned HTTP %d", ping_url, response.status)
            except asyncio.TimeoutError:
                logger.warning("Timeout pinging frontend at %s", ping_url)
            except Exception as error:
                # Deliberately best-effort: a failed ping must never kill the loop.
                logger.debug("Keep-alive ping error (not critical): %s", error)
            await asyncio.sleep(interval_seconds)
def create_app(
    scoring_service_factory: Callable[[], ScoringService],
    *,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    api_token: str | None = None,
    title: str = "Summary Uncertainty API",
) -> FastAPI:
    """Create the FastAPI application with an injected scoring service factory.

    The factory is called in a background thread during lifespan startup so the
    server can accept /wake and /is-ready requests while the model loads.

    Args:
        scoring_service_factory: Zero-argument callable that builds the scoring
            service; it runs in the default executor, not on the event loop.
        normalizer: Quantile normalizer for the raw ``uncertainty`` value.
        ambiguity_normalizer: Quantile normalizer for ``expected_entropy``.
        consistency_normalizer: Quantile normalizer for ``-mean_logprob``.
        api_token: When truthy, ``POST /score`` requires an ``X-API-Token``
            header equal to this value with surrounding whitespace stripped.
        title: Service title for the OpenAPI docs and the ``/`` description.

    Returns:
        The configured FastAPI application.
    """
    import hmac

    # Strip once up front; None means token authentication is disabled.
    expected_token = api_token.strip().encode("utf-8") if api_token else None

    def _token_is_valid(candidate: str | None) -> bool:
        """Return True when auth is disabled or *candidate* matches the configured token."""
        if expected_token is None:
            return True
        if candidate is None:
            return False
        # Constant-time comparison so response timing cannot leak the token.
        return hmac.compare_digest(candidate.encode("utf-8"), expected_token)

    async def _load_service(app: FastAPI) -> None:
        """Build the scoring service in the default executor, then mark the app ready."""
        loop = asyncio.get_running_loop()
        try:
            service = await loop.run_in_executor(None, scoring_service_factory)
        except Exception as error:
            # Without this, the failure would only surface as "Task exception was
            # never retrieved" and /score would claim "still loading" forever.
            app.state.load_error = f"{type(error).__name__}: {error}"
            logger.exception("Scoring service failed to load")
            return
        app.state.scoring_service = service
        app.state.ready = True
        logger.info("Scoring service ready")

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> Any:
        app.state.ready = False
        app.state.load_error = None
        app.state.scoring_service = None
        app.state.normalizer = normalizer
        app.state.ambiguity_normalizer = ambiguity_normalizer
        app.state.consistency_normalizer = consistency_normalizer
        # Scoring used to run directly on the event loop, which serialised it.
        # Keep one-at-a-time scoring (backends are not assumed thread-safe) while
        # moving the work off the loop.
        app.state.scoring_lock = asyncio.Lock()
        # The event loop only keeps weak references to tasks: hold strong ones so
        # they cannot be garbage-collected mid-flight, and cancel them on shutdown.
        background_tasks = [asyncio.create_task(_load_service(app))]
        frontend_url = os.environ.get("FRONTEND_URL")
        if frontend_url:
            background_tasks.append(asyncio.create_task(_keep_alive_task(frontend_url)))
        app.state.background_tasks = background_tasks
        try:
            yield
        finally:
            for task in background_tasks:
                task.cancel()
            await asyncio.gather(*background_tasks, return_exceptions=True)

    app = FastAPI(title=title, lifespan=lifespan)

    @app.get("/")
    async def root() -> dict[str, Any]:
        """Return a brief description of the service and its endpoints."""
        return {
            "service": title,
            "description": (
                "Estimates per-sentence epistemic uncertainty for a given "
                "source text and its summary."
            ),
            "endpoints": {
                "GET /health": "Liveness check.",
                "GET /wake": "Wake-up ping; call on frontend start to trigger cold-start recovery.",
                "GET /is-ready": "Returns {ready: true} once the scoring model is fully loaded.",
                "POST /score": (
                    "Score a summary. Required fields: source (str), summary (str). "
                    "Optional: sample_count (int, 1-100), sentences (list[str]), "
                    "top_k_tokens (int), seed (int), compute_consistency (bool)."
                ),
            },
            "docs": "/docs",
        }

    @app.get("/health", response_model=HealthResponse)
    async def health() -> HealthResponse:
        """Return API liveness information."""
        return HealthResponse(status="ok")

    @app.get("/wake", response_model=WakeResponse)
    async def wake() -> WakeResponse:
        """Wake-up ping for cold-start scenarios.

        Call this endpoint when the frontend starts so that the API server
        is fully warmed up by the time the user submits their first request.
        """
        logger.info("GET /wake — server awake")
        return WakeResponse(status="awake")

    @app.get("/is-ready", response_model=ReadinessResponse)
    async def is_ready() -> ReadinessResponse:
        """Return whether the scoring model has finished loading."""
        return ReadinessResponse(ready=getattr(app.state, "ready", False))

    @app.post("/score")
    async def score_summary(
        request: ScoreRequest,
        x_api_token: str | None = Header(default=None),
    ) -> dict[str, Any]:
        """Score the displayed summary without re-generating it.

        Raises:
            HTTPException: 401 on a bad/missing token; 503 while loading, after a
                failed load, or on NotImplementedError; 400 on ValueError from
                the backend; 500 on any other backend error.
        """
        if not _token_is_valid(x_api_token):
            raise HTTPException(status_code=401, detail="Invalid or missing API token.")
        # getattr: app.state has no "ready" attribute until lifespan startup runs.
        if not getattr(app.state, "ready", False):
            if getattr(app.state, "load_error", None) is not None:
                raise HTTPException(
                    status_code=503,
                    detail="Scoring service failed to load; see server logs.",
                )
            raise HTTPException(status_code=503, detail="Scoring service is still loading.")
        logger.info(
            "POST /score — sample_count=%d top_k_tokens=%s seed=%s",
            request.sample_count,
            request.top_k_tokens,
            request.seed,
        )
        service = app.state.scoring_service
        async with app.state.scoring_lock:
            try:
                # Model inference is blocking; run it in a worker thread so
                # /health, /wake and /is-ready stay responsive meanwhile.
                result = await asyncio.to_thread(
                    service.score_summary,
                    source=request.source,
                    summary=request.summary,
                    sentences=request.sentences,
                    sample_count=request.sample_count,
                    top_k_tokens=request.top_k_tokens,
                    seed=request.seed,
                )
            except ValueError as error:
                raise HTTPException(status_code=400, detail=str(error)) from error
            except NotImplementedError as error:
                raise HTTPException(status_code=503, detail=str(error)) from error
            except Exception as error:
                logger.exception("Scoring request failed")
                raise HTTPException(status_code=500, detail=str(error)) from error
        return _serialize_summary_score(
            result,
            app.state.normalizer,
            app.state.ambiguity_normalizer,
            app.state.consistency_normalizer,
            compute_consistency=request.compute_consistency,
        )

    return app
class UnconfiguredScoringService:
    """Placeholder service that refuses every request until a real backend is wired in.

    Selected with ``SCORING_BACKEND=unconfigured``; the ``/score`` endpoint maps
    its ``NotImplementedError`` to HTTP 503.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Always raise ``NotImplementedError``; every argument is ignored."""
        del source, summary, sentences, sample_count, top_k_tokens, seed
        message = (
            "No scoring service has been configured. "
            "Inject a SummaryUncertaintyScorer backed by the trained model and Laplace sampler."
        )
        raise NotImplementedError(message)
def _build_default_service() -> ScoringService:
    """Build the default scoring service from environment configuration.

    Unset, empty, or whitespace-only environment variables are treated as unset.

    Recognised SCORING_BACKEND values (case-insensitive):
    - ``dummy`` — rule-based mock, no model required (default)
    - ``mc_dropout`` — teacher-forced MC Dropout over a HuggingFace seq2seq model
    - ``lora_laplace`` — LoRA + diagonal Laplace approximation
    - ``unconfigured`` — raises on every request (forces explicit wiring)
    MC Dropout environment variables:
    - ``MC_DROPOUT_MODEL`` — HuggingFace model identifier (default: facebook/bart-large-cnn)
    - ``MC_DROPOUT_DEVICE`` — torch device string, e.g. ``cpu`` or ``cuda`` (auto-detected when unset)
    LoRA-Laplace environment variables:
    - ``LORA_BASE_MODEL`` — HuggingFace base model identifier (default: google/flan-t5-small)
    - ``LORA_ADAPTER_PATH`` — path to the PEFT adapter checkpoint directory (required)
    - ``LORA_SAMPLER_PATH`` — path to a pre-fitted laplace_sampler.npz (required);
      fit offline with compute_uncertainty_scores_lora_laplace.py --save-sampler
    - ``LORA_DEVICE`` — torch device string (auto-detected when unset)

    Raises:
        RuntimeError: For an unsupported SCORING_BACKEND, or when a required
            LoRA-Laplace path is not set.
    """

    def _env(name: str) -> str | None:
        """Return the stripped value of *name*, or None when it is unset or blank."""
        value = os.environ.get(name, "").strip()
        return value or None

    backend_name = (_env("SCORING_BACKEND") or "dummy").lower()
    if backend_name == "dummy":
        return build_dummy_scorer()
    if backend_name == "mc_dropout":
        from .mc_dropout_backend import build_mc_dropout_scorer

        model_name = _env("MC_DROPOUT_MODEL") or "facebook/bart-large-cnn"
        return build_mc_dropout_scorer(model_name=model_name, device=_env("MC_DROPOUT_DEVICE"))
    if backend_name == "lora_laplace":
        base_model_name = _env("LORA_BASE_MODEL") or "google/flan-t5-small"
        adapter_path = _env("LORA_ADAPTER_PATH")
        sampler_path = _env("LORA_SAMPLER_PATH")
        device = _env("LORA_DEVICE")
        # Validate configuration before the heavy ML imports, so a missing
        # setting is reported clearly instead of being masked by an ImportError.
        if not adapter_path:
            raise RuntimeError("LORA_ADAPTER_PATH must be set for the lora_laplace backend.")
        if not sampler_path:
            raise RuntimeError(
                "LORA_SAMPLER_PATH must be set for the lora_laplace backend. "
                "Fit the sampler offline with compute_uncertainty_scores_lora_laplace.py "
                "--save-sampler and provide the resulting .npz path here."
            )
        from peft import PeftModel
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        from .lora_laplace_backend import LoraLaplaceBackend, load_laplace_sampler
        from .scorer import SummaryUncertaintyScorer

        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        # NOTE(review): the adapter is loaded as trainable; presumably the Laplace
        # backend needs mutable / grad-enabled LoRA weights — confirm.
        peft_model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        backend = LoraLaplaceBackend(peft_model=peft_model, tokenizer=tokenizer, device=device)
        sampler = load_laplace_sampler(sampler_path)
        return SummaryUncertaintyScorer(backend=backend, posterior_sampler=sampler)
    if backend_name == "unconfigured":
        return UnconfiguredScoringService()
    raise RuntimeError(f"Unsupported SCORING_BACKEND value: {backend_name}")
def _build_default_normalizer() -> QuantileNormalizer:
    """Load the configured uncertainty normalizer.

    The path comes from ``QUANTILE_CONFIG_PATH``. When that variable is unset or
    empty, the default is ``config/uncertainty_quantiles_mc_dropout.json`` in the
    directory above this module's package. There is no fallback: anything
    ``load_quantile_normalizer`` raises propagates to the caller.
    """
    default_path = Path(__file__).resolve().parent.parent / "config" / "uncertainty_quantiles_mc_dropout.json"
    # ``or`` rather than a get() default: an empty variable means "use the default".
    config_path = os.environ.get("QUANTILE_CONFIG_PATH") or str(default_path)
    normalizer = load_quantile_normalizer(config_path)
    logger.info(
        "Quantile normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer
def _build_default_consistency_normalizer() -> QuantileNormalizer:
    """Load the configured consistency normalizer, falling back to the uncertainty normalizer.

    Boundaries are over -mean_logprob (a positive value; higher = less consistent).
    Fit them from a calibration corpus with compute_uncertainty_scores_lora_laplace.py
    and save to the path set in CONSISTENCY_QUANTILE_CONFIG_PATH.

    When that variable is unset or empty, the default is
    ``config/consistency_quantiles_lora_laplace.json`` in the directory above this
    module's package. If the file does not exist, a warning is logged and the
    uncertainty normalizer is returned instead.

    NOTE(review): the fallback's boundaries are presumably fitted to a different raw
    quantity than -mean_logprob, so consistency scores produced with it are likely on
    the wrong scale — confirm whether the fallback is intended.
    """
    default_path = Path(__file__).resolve().parent.parent / "config" / "consistency_quantiles_lora_laplace.json"
    # An empty variable would give Path("") == "." (which exists) and then a failed
    # load of "", so treat empty the same as unset.
    config_path = os.environ.get("CONSISTENCY_QUANTILE_CONFIG_PATH") or str(default_path)
    if not Path(config_path).exists():
        logger.warning(
            "Consistency quantile config not found at %r — falling back to uncertainty normalizer.",
            config_path,
        )
        return _build_default_normalizer()
    normalizer = load_quantile_normalizer(config_path)
    logger.info(
        "Consistency normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer
def _build_default_ambiguity_normalizer() -> QuantileNormalizer:
    """Load the configured ambiguity normalizer, falling back to the uncertainty normalizer.

    The path comes from ``AMBIGUITY_QUANTILE_CONFIG_PATH``. When that variable is
    unset or empty, the default is ``config/ambiguity_quantiles_mc_dropout.json`` in
    the directory above this module's package. If the file does not exist, a warning
    is logged and the uncertainty normalizer is returned instead.
    """
    default_path = Path(__file__).resolve().parent.parent / "config" / "ambiguity_quantiles_mc_dropout.json"
    # An empty variable would give Path("") == "." (which exists) and then a failed
    # load of "", so treat empty the same as unset.
    config_path = os.environ.get("AMBIGUITY_QUANTILE_CONFIG_PATH") or str(default_path)
    if not Path(config_path).exists():
        logger.warning(
            "Ambiguity quantile config not found at %r — falling back to uncertainty normalizer.",
            config_path,
        )
        return _build_default_normalizer()
    normalizer = load_quantile_normalizer(config_path)
    logger.info(
        "Ambiguity normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer
def _serialize_summary_score(
    summary_score: SummaryScore,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    compute_consistency: bool = True,
) -> dict[str, Any]:
    """Turn a SummaryScore into a JSON-ready dict enriched with display scores.

    Each entry of ``payload["sentence_results"]`` gains:
    - ``uncertainty_raw`` / ``uncertainty_score`` / ``uncertainty_band``, from ``uncertainty``;
    - ``ambiguity_score`` / ``ambiguity_band``, from ``expected_entropy``;
    - when *compute_consistency* is true, ``consistency_score`` / ``consistency_band``,
      from ``-mean_logprob``. The score is flipped (100 minus normalized) and the band
      inverted, so that higher means more consistent.

    The normalizer boundaries used are reported under ``payload["normalization"]``.
    """
    payload = summary_score.to_dict()

    boundary_info: dict[str, Any] = {
        "boundaries": list(normalizer.boundaries),
        "ambiguity_boundaries": list(ambiguity_normalizer.boundaries),
    }
    if compute_consistency:
        boundary_info["consistency_boundaries"] = list(consistency_normalizer.boundaries)
    payload["normalization"] = boundary_info

    for row in payload["sentence_results"]:
        uncertainty = float(row["uncertainty"])
        row["uncertainty_raw"] = uncertainty
        row["uncertainty_score"] = normalizer.normalize(uncertainty)
        row["uncertainty_band"] = normalizer.band(uncertainty)

        ambiguity = float(row["expected_entropy"])
        row["ambiguity_score"] = ambiguity_normalizer.normalize(ambiguity)
        row["ambiguity_band"] = ambiguity_normalizer.band(ambiguity)

        if not compute_consistency:
            continue
        # mean_logprob is a log-probability; its negation grows as the sentence
        # becomes less consistent, which is the scale the normalizer was fitted on.
        inconsistency = -float(row["mean_logprob"])
        flipped = 100.0 - consistency_normalizer.normalize(inconsistency)
        row["consistency_score"] = round(flipped, 4)
        row["consistency_band"] = _invert_band(consistency_normalizer.band(inconsistency))

    return payload


def _invert_band(band: str) -> str:
    """Swap the "low" and "high" labels; any other band label passes through unchanged."""
    if band in ("low", "high"):
        return "high" if band == "low" else "low"
    return band
# Module-level application instance built from environment configuration.
# NOTE(review): presumably this is the object the ASGI server imports
# (e.g. ``<module>:app``) — confirm against the deployment command.
# The three normalizers are loaded eagerly here, at import time; the scoring
# service itself is built later, in a background task during app startup.
_api_token = os.environ.get("API_TOKEN") or None  # an empty API_TOKEN is treated as unset (no auth)
app = create_app(
    _build_default_service,
    normalizer=_build_default_normalizer(),
    ambiguity_normalizer=_build_default_ambiguity_normalizer(),
    consistency_normalizer=_build_default_consistency_normalizer(),
    api_token=_api_token,
)
|