File size: 17,241 Bytes
c03fd30
 
 
 
af331f5
c03fd30
bc60f4f
694ac62
9d23e11
af331f5
c03fd30
9c6a9a7
 
bc60f4f
 
8b650a9
c03fd30
 
694ac62
9d23e11
c03fd30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dfb0f9
c03fd30
 
329c3b8
c03fd30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa0c40c
 
 
 
 
 
af331f5
 
 
 
 
 
9c6a9a7
8fe9283
 
 
 
9c6a9a7
 
8fe9283
 
 
 
 
 
 
 
 
9c6a9a7
 
 
25740c7
8fe9283
 
70f02c8
 
 
a4dfc8b
9c6a9a7
8fe9283
9c6a9a7
70f02c8
9c6a9a7
c07e2e5
 
c03fd30
af331f5
c03fd30
9d23e11
288990e
63cc4ab
c07e2e5
c03fd30
 
af331f5
 
 
 
 
c03fd30
10d217f
af331f5
 
 
 
10d217f
 
 
 
 
 
288990e
63cc4ab
10d217f
9c6a9a7
 
 
 
c03fd30
 
 
 
2db0d34
 
 
 
 
 
 
 
 
 
 
 
af331f5
 
2db0d34
 
 
 
 
 
 
 
 
c03fd30
 
 
 
 
 
fa0c40c
 
 
 
 
 
 
 
72717fc
fa0c40c
 
af331f5
 
 
 
 
 
8b650a9
 
 
c04fa04
8b650a9
c03fd30
 
c04fa04
 
8b650a9
10d217f
 
 
bc60f4f
 
 
 
 
 
c03fd30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63cc4ab
 
 
 
 
329c3b8
63cc4ab
c03fd30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694ac62
d6e5d06
 
 
ecb6675
 
 
 
d6e5d06
 
 
 
ecb6675
 
a813743
ecb6675
 
 
 
d6e5d06
694ac62
 
 
 
d6e5d06
 
 
 
 
 
ecb6675
 
 
 
 
 
 
a813743
ecb6675
 
 
 
 
 
 
 
 
 
 
 
 
ecfab30
 
ecb6675
8297b42
 
ecb6675
 
 
694ac62
 
 
 
 
9d23e11
b36f3d0
9d23e11
dac93ad
14f6420
531fb2e
 
 
 
 
 
 
9d23e11
 
63cc4ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288990e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d23e11
 
 
288990e
63cc4ab
329c3b8
9d23e11
 
 
 
329c3b8
c5a5a68
 
 
329c3b8
 
 
9d23e11
 
 
 
b36f3d0
 
9d23e11
288990e
 
 
 
329c3b8
 
 
 
 
 
 
 
 
 
63cc4ab
9d23e11
 
 
63cc4ab
 
 
 
 
 
 
 
cd9bcc1
 
9d23e11
af331f5
9d23e11
288990e
63cc4ab
cd9bcc1
9d23e11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
"""FastAPI wrapper around the summary uncertainty scoring service."""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
import logging
import os
from pathlib import Path
from typing import Any, Callable, Protocol, Sequence

import aiohttp

logger = logging.getLogger(__name__)

from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel, Field, field_validator

from .dummy_backend import build_dummy_scorer
from .normalization import QuantileNormalizer, load_quantile_normalizer
from .scorer import SummaryScore


class ScoringService(Protocol):
    """Structural interface the API layer expects from a scoring backend.

    Any object exposing a compatible ``score_summary`` method satisfies this
    protocol; explicit inheritance is not required.
    """

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Score an already-displayed summary and return sentence-level uncertainty.

        Args:
            source: Source text the summary was produced from.
            summary: Summary text to score.
            sentences: Optional explicit sentence list (presumably a
                pre-split version of ``summary`` -- confirm with the
                implementing backend).
            sample_count: Number of samples to draw; defaults to 40.
            top_k_tokens: Optional top-k token limit; semantics are
                backend-defined.
            seed: Optional seed forwarded to the backend.

        Returns:
            The ``SummaryScore`` produced by the backend.
        """
        ...


class ScoreRequest(BaseModel):
    """Body of a ``POST /score`` request.

    Attributes:
        source: Non-empty source text.
        summary: Non-empty summary text to score.
        sentences: Optional explicit sentence list; cleaned by
            ``validate_sentences``.
        sample_count: Number of samples, between 1 and 100 (default 7).
        top_k_tokens: Optional top-k limit, at least 1 when given.
        seed: Optional seed forwarded to the scoring service.
        compute_consistency: Whether consistency fields are added to the
            serialized response (default ``True``).
    """

    source: str = Field(min_length=1)
    summary: str = Field(min_length=1)
    sentences: list[str] | None = None
    sample_count: int = Field(default=7, ge=1, le=100)
    top_k_tokens: int | None = Field(default=None, ge=1)
    seed: int | None = None
    compute_consistency: bool = True

    @field_validator("sentences")
    @classmethod
    def validate_sentences(cls, sentences: list[str] | None) -> list[str] | None:
        """Strip every sentence, drop blank ones, and reject an all-blank list.

        ``None`` passes through untouched. A list that contains no non-blank
        entry after stripping raises ``ValueError``.
        """

        if sentences is None:
            return None

        stripped = (sentence.strip() for sentence in sentences)
        kept = [text for text in stripped if text]
        if kept:
            return kept
        raise ValueError("sentences must contain at least one non-empty sentence.")


class HealthResponse(BaseModel):
    """Health check response.

    Response model of ``GET /health``.
    """

    # Liveness indicator; the /health handler in this module always sends "ok".
    status: str


class WakeResponse(BaseModel):
    """Wake-up response.

    Response model of ``GET /wake``.
    """

    # Wake acknowledgement; the /wake handler in this module always sends "awake".
    status: str


class ReadinessResponse(BaseModel):
    """Readiness check response.

    Response model of ``GET /is-ready``.
    """

    # Mirrors ``app.state.ready``: False until the scoring service factory has
    # returned during application startup, True afterwards.
    ready: bool


def _resolve_ping_url(frontend_url: str) -> str:
    """Return the URL that should be pinged for *frontend_url*.

    A ``huggingface.co/spaces/<owner>/<repo>`` URL is rewritten to the direct
    ``https://<owner>-<repo>.hf.space/health`` endpoint. Only the owner and
    repo segments are used, so a trailing slash, a sub-page such as
    ``/tree/main``, a query string, or a fragment cannot leak into the
    generated hostname. Any other URL -- including a Spaces URL without both
    an owner and a repo segment -- is returned unchanged.
    """

    _, marker_found, spaces_path = frontend_url.partition("huggingface.co/spaces/")
    if not marker_found:
        return frontend_url

    # Cut off "?query" and "#fragment" before splitting the path.
    for separator in ("?", "#"):
        spaces_path = spaces_path.partition(separator)[0]

    segments = [segment for segment in spaces_path.split("/") if segment]
    if len(segments) < 2:
        return frontend_url

    owner_repo = "-".join(segments[:2])
    return f"https://{owner_repo}.hf.space/health"


async def _keep_alive_task(frontend_url: str, interval_seconds: int = 60) -> None:
    """Periodically ping the frontend to keep it from going to sleep.

    Transforms huggingface.co/spaces URLs to the corresponding hf.space direct
    URL (see ``_resolve_ping_url``); other URLs are pinged as given.

    The coroutine never returns on its own: it loops until it is cancelled.
    Ping failures are logged and otherwise ignored, because the keep-alive is
    best-effort and must not take the API down.

    Args:
        frontend_url: Frontend URL, typically taken from configuration.
        interval_seconds: Pause between two pings, in seconds.
    """

    await asyncio.sleep(5)  # Wait a bit before starting pings

    ping_url = _resolve_ping_url(frontend_url)
    request_timeout = aiohttp.ClientTimeout(total=10)

    async with aiohttp.ClientSession() as session:
        while True:
            try:
                logger.debug("Sending keep-alive ping to frontend")
                # The response body is irrelevant; issuing the request is the point.
                async with session.get(
                    ping_url,
                    timeout=request_timeout,
                    allow_redirects=False,
                ):
                    pass
            except asyncio.TimeoutError:
                logger.warning("Timeout pinging frontend at %s", ping_url)
            except Exception as error:
                # Deliberately broad: any network hiccup is non-critical here.
                logger.debug("Keep-alive ping error (not critical): %s", error)
            await asyncio.sleep(interval_seconds)


def create_app(
    scoring_service_factory: Callable[[], ScoringService],
    *,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    api_token: str | None = None,
    title: str = "Summary Uncertainty API",
) -> FastAPI:
    """Create the FastAPI application with an injected scoring service factory.

    The factory is called in a background thread during lifespan startup so the
    server can accept /wake and /is-ready requests while the model loads.

    Args:
        scoring_service_factory: Zero-argument callable that builds the
            scoring service; it may be slow.
        normalizer: Normalizer applied to the raw per-sentence uncertainty.
        ambiguity_normalizer: Normalizer applied to the expected entropy.
        consistency_normalizer: Normalizer applied to the negated mean
            log-probability.
        api_token: When set (non-empty), ``POST /score`` requires a matching
            ``x_api_token`` header value; surrounding whitespace in the
            configured token is ignored.
        title: Application title, also reported by ``GET /``.

    Returns:
        The configured ``FastAPI`` application.
    """

    import hmac  # Local import: only needed for the constant-time token check.

    # Normalise the configured token once instead of on every request.
    expected_token = api_token.strip().encode("utf-8") if api_token else None

    async def _load_service(app: FastAPI) -> None:
        """Build the scoring service off the event loop, then mark the app ready."""

        loop = asyncio.get_running_loop()
        try:
            service = await loop.run_in_executor(None, scoring_service_factory)
        except Exception:
            # Without this the failure would only surface as an unretrieved
            # task exception and the service would look "still loading" forever.
            logger.exception(
                "Scoring service failed to load; POST /score will keep returning 503"
            )
            return
        app.state.scoring_service = service
        app.state.ready = True
        logger.info("Scoring service ready")

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> Any:
        """Initialise app state, start background tasks, cancel them on shutdown."""

        app.state.ready = False
        app.state.scoring_service = None
        app.state.normalizer = normalizer
        app.state.ambiguity_normalizer = ambiguity_normalizer
        app.state.consistency_normalizer = consistency_normalizer
        app.state.scoring_lock = asyncio.Lock()

        # Keep strong references: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-flight.
        background_tasks = [asyncio.create_task(_load_service(app))]

        frontend_url = os.environ.get("FRONTEND_URL")
        if frontend_url:
            background_tasks.append(asyncio.create_task(_keep_alive_task(frontend_url)))

        try:
            yield
        finally:
            for task in background_tasks:
                task.cancel()
            await asyncio.gather(*background_tasks, return_exceptions=True)

    app = FastAPI(title=title, lifespan=lifespan)

    @app.get("/")
    async def root() -> dict[str, Any]:
        """Return a brief description of the service and its endpoints."""

        return {
            "service": title,
            "description": (
                "Estimates per-sentence epistemic uncertainty for a given "
                "source text and its summary."
            ),
            "endpoints": {
                "GET /health": "Liveness check.",
                "GET /wake": "Wake-up ping; call on frontend start to trigger cold-start recovery.",
                "GET /is-ready": "Returns {ready: true} once the scoring model is fully loaded.",
                "POST /score": (
                    "Score a summary. Required fields: source (str), summary (str). "
                    "Optional: sample_count (int, 1-100), sentences (list[str]), "
                    "top_k_tokens (int), seed (int), compute_consistency (bool)."
                ),
            },
            "docs": "/docs",
        }

    @app.get("/health", response_model=HealthResponse)
    async def health() -> HealthResponse:
        """Return API liveness information."""

        return HealthResponse(status="ok")

    @app.get("/wake", response_model=WakeResponse)
    async def wake() -> WakeResponse:
        """Wake-up ping for cold-start scenarios.

        Call this endpoint when the frontend starts so that the API server
        is fully warmed up by the time the user submits their first request.
        """

        logger.info("GET /wake — server awake")
        return WakeResponse(status="awake")

    @app.get("/is-ready", response_model=ReadinessResponse)
    async def is_ready() -> ReadinessResponse:
        """Return whether the scoring model has finished loading."""

        return ReadinessResponse(ready=getattr(app.state, "ready", False))

    @app.post("/score")
    async def score_summary(
        request: ScoreRequest,
        x_api_token: str | None = Header(default=None),
    ) -> dict[str, Any]:
        """Score the displayed summary without re-generating it.

        Responds with 401 on a missing/wrong token, 503 while the scoring
        service is loading or unconfigured, 400 when the service rejects the
        input with ``ValueError``, and 500 on any other scoring failure.
        """

        # Constant-time comparison so the token cannot be probed via timing.
        if expected_token is not None and (
            x_api_token is None
            or not hmac.compare_digest(x_api_token.encode("utf-8"), expected_token)
        ):
            raise HTTPException(status_code=401, detail="Invalid or missing API token.")

        if not getattr(app.state, "ready", False):
            raise HTTPException(status_code=503, detail="Scoring service is still loading.")

        logger.info(
            "POST /score — sample_count=%d top_k_tokens=%s seed=%s",
            request.sample_count,
            request.top_k_tokens,
            request.seed,
        )

        service = app.state.scoring_service

        def _run_scoring() -> SummaryScore:
            return service.score_summary(
                source=request.source,
                summary=request.summary,
                sentences=request.sentences,
                sample_count=request.sample_count,
                top_k_tokens=request.top_k_tokens,
                seed=request.seed,
            )

        # The lock is normally created by the lifespan handler; create it
        # lazily when app state was prepared by hand (e.g. without lifespan).
        scoring_lock = getattr(app.state, "scoring_lock", None)
        if scoring_lock is None:
            scoring_lock = app.state.scoring_lock = asyncio.Lock()

        try:
            # Scoring is synchronous and potentially slow: run it in a worker
            # thread so /health, /is-ready and the keep-alive task stay
            # responsive. The lock keeps scoring calls strictly one at a time,
            # as they were when the call blocked the event loop.
            async with scoring_lock:
                result = await asyncio.get_running_loop().run_in_executor(None, _run_scoring)
        except ValueError as error:
            raise HTTPException(status_code=400, detail=str(error)) from error
        except NotImplementedError as error:
            raise HTTPException(status_code=503, detail=str(error)) from error
        except Exception as error:
            # Keep the traceback in the server log; the client only sees str(error).
            logger.exception("Unexpected error while scoring summary")
            raise HTTPException(status_code=500, detail=str(error)) from error

        return _serialize_summary_score(
            result,
            app.state.normalizer,
            app.state.ambiguity_normalizer,
            app.state.consistency_normalizer,
            compute_consistency=request.compute_consistency,
        )

    return app


class UnconfiguredScoringService:
    """Placeholder backend that refuses to score until a real one is wired in."""

    def score_summary(
        self,
        source: str,
        summary: str,
        sentences: Sequence[str] | None = None,
        sample_count: int = 40,
        top_k_tokens: int | None = None,
        seed: int | None = None,
    ) -> SummaryScore:
        """Always raise ``NotImplementedError``; every argument is ignored."""

        # The parameters exist only to match the scoring-service signature.
        del source, summary, sentences, sample_count, top_k_tokens, seed

        message = (
            "No scoring service has been configured. "
            "Inject a SummaryUncertaintyScorer backed by the trained model and Laplace sampler."
        )
        raise NotImplementedError(message)


def _build_default_service() -> ScoringService:
    """Construct the scoring service selected by ``SCORING_BACKEND``.

    The value is stripped and lower-cased before matching. Supported values:

    - ``dummy`` (default): rule-based mock, no model required.
    - ``mc_dropout``: teacher-forced MC Dropout over a HuggingFace seq2seq model.
    - ``lora_laplace``: LoRA + diagonal Laplace approximation.
    - ``unconfigured``: service that raises on every request (forces explicit wiring).

    MC Dropout settings:

    - ``MC_DROPOUT_MODEL``: HuggingFace model identifier
      (default ``facebook/bart-large-cnn``).
    - ``MC_DROPOUT_DEVICE``: torch device string such as ``cpu`` or ``cuda``;
      passed as ``None`` (auto-detect) when unset or empty.

    LoRA-Laplace settings:

    - ``LORA_BASE_MODEL``: HuggingFace base model identifier
      (default ``google/flan-t5-small``).
    - ``LORA_ADAPTER_PATH``: PEFT adapter checkpoint directory (required).
    - ``LORA_SAMPLER_PATH``: pre-fitted ``laplace_sampler.npz`` (required); fit
      offline with ``compute_uncertainty_scores_lora_laplace.py --save-sampler``.
    - ``LORA_DEVICE``: torch device string; passed as ``None`` when unset or empty.

    Raises:
        RuntimeError: For an unsupported backend name, or when a required
            LoRA-Laplace path is missing.
    """

    env = os.environ
    backend_name = env.get("SCORING_BACKEND", "dummy").strip().lower()

    # Cheap backends first: neither needs any heavyweight import.
    if backend_name == "dummy":
        return build_dummy_scorer()

    if backend_name == "unconfigured":
        return UnconfiguredScoringService()

    if backend_name == "mc_dropout":
        # Imported lazily so other backends do not pay for this dependency.
        from .mc_dropout_backend import build_mc_dropout_scorer

        return build_mc_dropout_scorer(
            model_name=env.get("MC_DROPOUT_MODEL", "facebook/bart-large-cnn"),
            device=env.get("MC_DROPOUT_DEVICE") or None,
        )

    if backend_name != "lora_laplace":
        raise RuntimeError(f"Unsupported SCORING_BACKEND value: {backend_name}")

    # --- lora_laplace -----------------------------------------------------
    from peft import PeftModel
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from .lora_laplace_backend import LoraLaplaceBackend, load_laplace_sampler
    from .scorer import SummaryUncertaintyScorer

    adapter_path = env.get("LORA_ADAPTER_PATH", "")
    if not adapter_path:
        raise RuntimeError("LORA_ADAPTER_PATH must be set for the lora_laplace backend.")

    sampler_path = env.get("LORA_SAMPLER_PATH", "")
    if not sampler_path:
        raise RuntimeError(
            "LORA_SAMPLER_PATH must be set for the lora_laplace backend. "
            "Fit the sampler offline with compute_uncertainty_scores_lora_laplace.py "
            "--save-sampler and provide the resulting .npz path here."
        )

    base_model_name = env.get("LORA_BASE_MODEL", "google/flan-t5-small")
    device = env.get("LORA_DEVICE") or None

    # NOTE(review): the name is unused below; kept so a missing torch install
    # still fails at this exact point, before any model is loaded.
    import torch  # noqa: F401

    peft_model = PeftModel.from_pretrained(
        AutoModelForSeq2SeqLM.from_pretrained(base_model_name),
        adapter_path,
        is_trainable=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    backend = LoraLaplaceBackend(peft_model=peft_model, tokenizer=tokenizer, device=device)
    return SummaryUncertaintyScorer(
        backend=backend,
        posterior_sampler=load_laplace_sampler(sampler_path),
    )


def _build_default_normalizer() -> QuantileNormalizer:
    """Load the configured uncertainty normalizer.

    The config path is read from ``QUANTILE_CONFIG_PATH``. When that variable
    is unset, ``config/uncertainty_quantiles_mc_dropout.json`` is used,
    resolved relative to the parent of this module's directory.

    Returns:
        The normalizer produced by ``load_quantile_normalizer``.
    """

    default_path = Path(__file__).resolve().parent.parent / "config" / "uncertainty_quantiles_mc_dropout.json"
    config_path = os.environ.get("QUANTILE_CONFIG_PATH", str(default_path))
    normalizer = load_quantile_normalizer(config_path)
    # The separator below used to be a mis-encoded em dash ("β€”") in the log output.
    logger.info(
        "Quantile normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer


def _build_default_consistency_normalizer() -> QuantileNormalizer:
    """Load the configured consistency normalizer, falling back to the uncertainty normalizer.

    Boundaries are over -mean_logprob (a positive value; higher = less consistent).
    Fit them from a calibration corpus with compute_uncertainty_scores_lora_laplace.py
    and save to the path set in CONSISTENCY_QUANTILE_CONFIG_PATH.

    When CONSISTENCY_QUANTILE_CONFIG_PATH is unset, the default is
    ``config/consistency_quantiles_lora_laplace.json`` relative to the parent
    of this module's directory. If the resolved file does not exist, a warning
    is logged and ``_build_default_normalizer()`` is returned instead.
    """

    default_path = Path(__file__).resolve().parent.parent / "config" / "consistency_quantiles_lora_laplace.json"
    config_path = os.environ.get("CONSISTENCY_QUANTILE_CONFIG_PATH", str(default_path))
    if not Path(config_path).exists():
        # The separator below used to be a mis-encoded em dash ("β€”") in the log output.
        logger.warning(
            "Consistency quantile config not found at %r — falling back to uncertainty normalizer.",
            config_path,
        )
        return _build_default_normalizer()
    normalizer = load_quantile_normalizer(config_path)
    # Log the boundaries too, matching the uncertainty and ambiguity loaders.
    logger.info(
        "Consistency normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer


def _build_default_ambiguity_normalizer() -> QuantileNormalizer:
    """Load the configured ambiguity normalizer, falling back to the uncertainty normalizer.

    The config path is read from ``AMBIGUITY_QUANTILE_CONFIG_PATH``. When that
    variable is unset, ``config/ambiguity_quantiles_mc_dropout.json`` is used,
    resolved relative to the parent of this module's directory. If the
    resolved file does not exist, a warning is logged and
    ``_build_default_normalizer()`` is returned instead.
    """

    default_path = Path(__file__).resolve().parent.parent / "config" / "ambiguity_quantiles_mc_dropout.json"
    config_path = os.environ.get("AMBIGUITY_QUANTILE_CONFIG_PATH", str(default_path))
    if not Path(config_path).exists():
        # The separator below used to be a mis-encoded em dash ("β€”") in the log output.
        logger.warning(
            "Ambiguity quantile config not found at %r — falling back to uncertainty normalizer.",
            config_path,
        )
        return _build_default_normalizer()
    normalizer = load_quantile_normalizer(config_path)
    logger.info(
        "Ambiguity normalizer loaded from %r — boundaries: %s",
        config_path,
        [f"{b:.4f}" for b in normalizer.boundaries],
    )
    return normalizer


def _serialize_summary_score(
    summary_score: SummaryScore,
    normalizer: QuantileNormalizer,
    ambiguity_normalizer: QuantileNormalizer,
    consistency_normalizer: QuantileNormalizer,
    compute_consistency: bool = True,
) -> dict[str, Any]:
    """Turn a summary score into a response payload with display-ready values.

    Starts from ``summary_score.to_dict()``, attaches a ``normalization``
    section holding the normalizer boundaries, and enriches every entry of
    ``sentence_results`` in place with normalized scores and bands.

    Args:
        summary_score: Score object exposing ``to_dict()``.
        normalizer: Applied to each sentence's ``uncertainty``.
        ambiguity_normalizer: Applied to each sentence's ``expected_entropy``.
        consistency_normalizer: Applied to the negated ``mean_logprob``.
        compute_consistency: When false, no consistency boundaries, scores or
            bands are added.

    Returns:
        The enriched payload dictionary.
    """

    payload = summary_score.to_dict()

    boundaries_section: dict[str, Any] = {
        "boundaries": list(normalizer.boundaries),
        "ambiguity_boundaries": list(ambiguity_normalizer.boundaries),
    }
    if compute_consistency:
        boundaries_section["consistency_boundaries"] = list(consistency_normalizer.boundaries)
    payload["normalization"] = boundaries_section

    for entry in payload["sentence_results"]:
        uncertainty = float(entry["uncertainty"])
        entry.update(
            uncertainty_raw=uncertainty,
            uncertainty_score=normalizer.normalize(uncertainty),
            uncertainty_band=normalizer.band(uncertainty),
        )

        ambiguity = float(entry["expected_entropy"])
        entry.update(
            ambiguity_score=ambiguity_normalizer.normalize(ambiguity),
            ambiguity_band=ambiguity_normalizer.band(ambiguity),
        )

        if not compute_consistency:
            continue

        # mean_logprob is negative, so its negation grows as the sentence gets
        # less consistent with the source. The normalized value is flipped on
        # the 0-100 scale (and the band inverted) so that a higher
        # consistency_score means "more consistent".
        inconsistency = -float(entry["mean_logprob"])
        entry["consistency_score"] = round(
            100.0 - consistency_normalizer.normalize(inconsistency), 4
        )
        entry["consistency_band"] = _invert_band(consistency_normalizer.band(inconsistency))

    return payload


def _invert_band(band: str) -> str:
    """Swap the ``"low"`` and ``"high"`` band labels; return any other label unchanged."""

    opposites = {"low": "high", "high": "low"}
    return opposites.get(band, band)


# Token that POST /score callers must supply via the ``x_api_token`` header
# parameter. An unset or empty API_TOKEN becomes None, in which case
# create_app() performs no token check.
_api_token = os.environ.get("API_TOKEN") or None

# Module-level application instance. The three normalizers are loaded eagerly,
# at import time; the scoring service factory is only passed along here and is
# invoked later, during application startup.
app = create_app(
    _build_default_service,
    normalizer=_build_default_normalizer(),
    ambiguity_normalizer=_build_default_ambiguity_normalizer(),
    consistency_normalizer=_build_default_consistency_normalizer(),
    api_token=_api_token,
)