import time

from fastapi import APIRouter, HTTPException, Query, Request

from src.schemas.requests import SampleInput
from src.schemas.responses import SampleOutput
from src.logger import log_to_betterstack

router = APIRouter()

MAX_LEN = 256
BATCH_SIZE = 32


@router.post("/predict-all-models", response_model=dict[str, list[SampleOutput]])
def predict_all_models_endpoint(
    request: Request,
    samples: list[SampleInput],
    deduplicate: bool = Query(False),
) -> dict[str, list[SampleOutput]]:
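    """Run every loaded model over the submitted samples.

    Returns a mapping from model name to a list of SampleOutput. The
    ``deduplicate`` flag is forwarded to the pipeline unchanged. Responds
    with 503 when the all-models pipeline is unavailable or empty.
    """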
    # The shared all-models pipeline is expected on app.state; without it
    # (or with no loaded models) the endpoint cannot serve predictions.
    pipeline = getattr(request.app.state, "all_models_pipeline", None)
    if pipeline is None or not pipeline.models:
        raise HTTPException(
            status_code=503,
            detail="No models are available.",
        )

    samples_raw = [s.model_dump() for s in samples]

    # Run every model over the raw samples; time the full pass for logging.
    start = time.perf_counter()
    results = pipeline.run(samples_raw, MAX_LEN, BATCH_SIZE, deduplicate=deduplicate)
    elapsed = time.perf_counter() - start

    # Validate each model's raw predictions against the response schema.
    output = {
        model: [SampleOutput(**r) for r in preds]
        for model, preds in results.items()
    }

    # Ship request/response payloads and latency to Better Stack.
    log_to_betterstack(
        endpoint_name="/predict-all-models",
        original_text=samples_raw,
        formatted_text={m: [o.model_dump() for o in preds] for m, preds in output.items()},
        model_name="all",
        time_elapsed=elapsed,
    )

    return output