import time

from fastapi import APIRouter, HTTPException, Query, Request

from src.schemas.requests import SampleInput
from src.schemas.responses import SampleOutput
from src.logger import log_to_betterstack


router = APIRouter()


# Inference settings for the marker pipeline.
MODEL_NAME = "marker"
MAX_LEN = 256
BATCH_SIZE = 32


@router.post("/predict", response_model=list[SampleOutput])
def predict_endpoint(
    request: Request,
    samples: list[SampleInput],
    deduplicate: bool = Query(False),
) -> list[SampleOutput]:
    """Run the marker pipeline over a batch of samples and return its predictions."""
    # The pipeline is expected on app.state (typically attached at startup);
    # report 503 if it was never loaded.
    pipeline = getattr(request.app.state, f"pipeline_{MODEL_NAME}", None)
    if pipeline is None:
        raise HTTPException(
            status_code=503,
            detail=f"Model '{MODEL_NAME}' is not available.",
        )

    # Serialize the validated inputs, time the model call, and re-validate
    # the raw results against the response schema.
    samples_raw = [s.model_dump() for s in samples]
    start = time.perf_counter()
    results = pipeline.run(samples_raw, MAX_LEN, BATCH_SIZE, deduplicate=deduplicate)
    elapsed = time.perf_counter() - start
    output = [SampleOutput(**r) for r in results]

    # Ship inputs, outputs, model name, and latency to BetterStack.
    log_to_betterstack(
        endpoint_name="/predict",
        original_text=samples_raw,
        formatted_text=[o.model_dump() for o in output],
        model_name=MODEL_NAME,
        time_elapsed=elapsed,
    )

    return output
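

# ---------------------------------------------------------------------------
# Example wiring (a sketch, not part of this router): predict_endpoint reads
# the pipeline from ``app.state`` under the attribute ``pipeline_{MODEL_NAME}``
# (here ``pipeline_marker``), so the application must attach it before serving
# traffic. One way to do that is a FastAPI lifespan handler; ``load_pipeline``
# below is a hypothetical placeholder for however the model is actually loaded.
#
#     from contextlib import asynccontextmanager
#     from fastapi import FastAPI
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         app.state.pipeline_marker = load_pipeline("marker")  # hypothetical loader
#         yield
#
#     app = FastAPI(lifespan=lifespan)
#     app.include_router(router)
#
# With that wiring, a quick smoke test could use FastAPI's TestClient (the
# ``text`` field on SampleInput is assumed for illustration only):
#
#     from fastapi.testclient import TestClient
#
#     with TestClient(app) as client:  # context manager runs the lifespan
#         resp = client.post("/predict?deduplicate=true", json=[{"text": "hello"}])
#         print(resp.status_code, resp.json())
# ---------------------------------------------------------------------------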