NMT / scripts /nmt_http_api.py
Aiana Mederbek
perf: translate multi-line input line-by-line; env beam/max-decode; robust env ints
70f84e7
Raw
History Blame Contribute Delete
10.1 kB
from __future__ import annotations

import asyncio
import hmac
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path

# Optionally load a local .env file; kept ahead of the application imports so
# any environment it provides is in place before they are imported.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass

from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel, Field, ValidationError

from scripts.nmt_tcp_server import (
    DEFAULT_SRC_LANG,
    DEFAULT_TGT_LANG,
    SUPPORTED_PAIRS,
    load_spm,
    load_translator,
    translate_one,
    validate_lang_pair,
)
log = logging.getLogger("nmt_http_api")
# logging only accepts upper-case level names, so normalise LOG_LEVEL; a value
# such as "debug" or an empty string would otherwise raise ValueError here.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").strip().upper() or "INFO",
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
_TRANSLATE_HINT = (
'Send English in JSON: {"text":"Your sentence here"} with '
"Content-Type: application/json. Alternatives: form field `text`, "
"raw text/plain body, or GET /translate?text=..."
)
@dataclass
class AppState:
    """Runtime state built once in ``lifespan`` and stored on ``app.state.nmt``."""

    # Translator returned by load_translator(); its concrete type is opaque here.
    translator: object
    # SentencePiece model returned by load_spm().
    sentencepiece: object
    # Expected X-API-Key value, or None when NMT_API_KEY is unset/blank.
    api_key: str | None
    # Maximum accepted length of the stripped input text (MAX_INPUT_CHARS).
    max_chars: int
    # Per-request translation timeout in seconds (from TRANSLATION_TIMEOUT_MS).
    timeout_seconds: float
    # Whether /translate requests must present a matching X-API-Key header.
    require_api_key: bool
    # Name of the model variant being served (MODEL_VARIANT).
    model_variant: str
class TranslateRequest(BaseModel):
    """Parsed input for one translation request.

    ``text`` deliberately has no model-level maximum length. The limit is the
    configurable ``MAX_INPUT_CHARS`` (``AppState.max_chars``), which
    ``_translate_core`` checks against the stripped text. A hard-coded cap here
    would silently override a larger configured limit and contradict the limit
    advertised to the web UI.
    """

    text: str = Field(min_length=1)
    source_lang: str = Field(default=DEFAULT_SRC_LANG, min_length=1, max_length=32)
    target_lang: str = Field(default=DEFAULT_TGT_LANG, min_length=1, max_length=32)
class TranslateResponse(BaseModel):
    """JSON body returned by the GET and POST ``/translate`` endpoints."""

    # Output of translate_one for the stripped input text.
    translation: str
    # Time measured around the threaded translate_one call, rounded to 0.1 ms.
    latency_ms: float
    # X-Request-Id supplied by the caller, otherwise a generated UUID4 string.
    request_id: str
    # Stripped language codes that passed validate_lang_pair.
    source_lang: str
    target_lang: str
    # Model variant name taken from AppState.
    model_variant: str
def _read_env_int(name: str, default: int) -> int:
raw = os.getenv(name, str(default))
if raw is None or not str(raw).strip():
return default
first = str(raw).strip().splitlines()[0].strip()
try:
return int(first)
except ValueError as exc:
raise RuntimeError(f"Environment variable {name} must be an integer") from exc
def _read_env_bool(name: str, default: bool) -> bool:
raw = os.getenv(name, "1" if default else "0").strip().lower()
return raw in {"1", "true", "yes", "on"}
def _resolve_model_paths() -> tuple[str, str, str]:
model_variant = os.getenv("MODEL_VARIANT", "base").strip() or "base"
variants_json = os.getenv("MODEL_VARIANTS_JSON", "").strip()
if not variants_json:
model_dir = os.getenv("MODEL_DIR", "artifacts/ct2/en_it_v4_casual_weighted/model")
spm_path = os.getenv("SPM_PATH", os.path.join(model_dir, "sentencepiece.bpe.model"))
return model_variant, model_dir, spm_path
try:
variant_map = json.loads(variants_json)
except json.JSONDecodeError as exc:
raise RuntimeError("MODEL_VARIANTS_JSON must be valid JSON") from exc
if model_variant not in variant_map:
raise RuntimeError(f"MODEL_VARIANT {model_variant!r} not found in MODEL_VARIANTS_JSON")
selected = variant_map[model_variant]
model_dir = selected.get("model_dir")
spm_path = selected.get("spm_path") or os.path.join(model_dir, "sentencepiece.bpe.model")
if not model_dir:
raise RuntimeError(f"Variant {model_variant!r} is missing model_dir")
return model_variant, model_dir, spm_path
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Configure the service and load the translation model for the app's lifetime.

    Every environment setting is read and validated *before* the model is
    loaded, so a misconfigured deployment fails immediately rather than after a
    slow model load and warm-up. The resulting ``AppState`` is published on
    ``app.state.nmt`` for the request handlers.

    Raises:
        RuntimeError: on invalid configuration. This covers bad model-variant
            settings, non-integer limits and ``MAX_INPUT_CHARS`` < 1. It also
            covers ``REQUIRE_API_KEY`` being enabled without ``NMT_API_KEY``.
    """
    model_variant, model_dir, spm_path = _resolve_model_paths()
    api_key = os.getenv("NMT_API_KEY", "").strip() or None
    require_api_key = _read_env_bool("REQUIRE_API_KEY", default=api_key is not None)
    max_chars = _read_env_int("MAX_INPUT_CHARS", 2000)
    timeout_ms = _read_env_int("TRANSLATION_TIMEOUT_MS", 10000)
    # Floor of 0.1 s so a zero/negative setting still leaves a usable timeout.
    timeout_seconds = max(timeout_ms / 1000.0, 0.1)

    if max_chars < 1:
        raise RuntimeError("MAX_INPUT_CHARS must be a positive integer")
    if require_api_key and not api_key:
        raise RuntimeError("REQUIRE_API_KEY is enabled but NMT_API_KEY is not set")
    if not require_api_key:
        log.warning("API key protection is disabled")

    log.info(
        "Loading model_variant=%s model_dir=%s spm_path=%s",
        model_variant,
        model_dir,
        spm_path,
    )
    translator = load_translator(os.path.abspath(model_dir))
    sentencepiece = load_spm(os.path.abspath(spm_path))
    # Throwaway translation before serving (presumably to absorb one-off
    # initialisation cost so the first real request is not slowed down).
    translate_one(translator, sentencepiece, "Hi")
    log.info("Warmup complete")

    app.state.nmt = AppState(
        translator=translator,
        sentencepiece=sentencepiece,
        api_key=api_key,
        max_chars=max_chars,
        timeout_seconds=timeout_seconds,
        require_api_key=require_api_key,
        model_variant=model_variant,
    )
    yield
# ASGI application; `lifespan` runs the model loading before requests are served.
app = FastAPI(lifespan=lifespan, title="NMT MenKan HTTP API")
@app.exception_handler(RequestValidationError)
async def request_validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Render request-validation failures as a 422 carrying the usage hint.

    The error list is passed through ``jsonable_encoder`` because individual
    error entries may carry arbitrary objects (e.g. in ``ctx`` or ``input``)
    that plain ``json.dumps`` cannot serialise, which would turn a 422 into a 500.
    """
    return JSONResponse(
        status_code=422,
        content={
            "error": "validation_error",
            "detail": jsonable_encoder(exc.errors()),
            "hint": _TRANSLATE_HINT,
        },
    )
def _require_api_key(header_key: str | None, expected_key: str | None) -> None:
if not expected_key or not header_key or header_key != expected_key:
raise HTTPException(status_code=401, detail="Unauthorized")
def _build_translate_request(fields: dict[str, object]) -> TranslateRequest:
    """Validate *fields* as a TranslateRequest, surfacing failures as a 422.

    A pydantic ``ValidationError`` raised inside an endpoint is not a
    ``RequestValidationError``. Without this conversion it would escape as an
    unhandled 500 instead of reaching ``request_validation_handler``.
    """
    try:
        return TranslateRequest(**fields)
    except ValidationError as exc:
        raise RequestValidationError(exc.errors()) from exc


async def _extract_payload(request: Request) -> TranslateRequest:
    """Build a TranslateRequest from the POST body, whatever its encoding.

    Accepted bodies:
      * a JSON object (``application/json``, any ``+json`` type, or no Content-Type);
      * form fields (``application/x-www-form-urlencoded`` / ``multipart/form-data``);
      * ``text/plain``: the whole body, UTF-8 decoded and stripped, is the text.

    Raises:
        HTTPException: 400 for undecodable/non-object JSON or non-text form
            fields; 415 for any other Content-Type.
        RequestValidationError: fields fail TranslateRequest validation.
    """
    raw_ct = request.headers.get("content-type") or ""
    ct = raw_ct.split(";")[0].strip().lower()

    if ct in ("application/json", "") or ct.endswith("+json") or "application/json" in raw_ct:
        try:
            body = await request.json()
        except ValueError as exc:  # JSONDecodeError, or UnicodeDecodeError for non-UTF-8 bytes
            raise HTTPException(
                status_code=400,
                detail={"error": "invalid_json", "message": str(exc), "hint": _TRANSLATE_HINT},
            ) from exc
        # `**body` below needs a mapping; a JSON list/string/number would raise TypeError.
        if not isinstance(body, dict):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_json",
                    "message": "JSON body must be an object",
                    "hint": _TRANSLATE_HINT,
                },
            )
        return _build_translate_request(body)

    if ct in ("application/x-www-form-urlencoded", "multipart/form-data"):
        form = await request.form()
        fields = {
            "text": form.get("text", ""),
            "source_lang": form.get("source_lang", DEFAULT_SRC_LANG),
            "target_lang": form.get("target_lang", DEFAULT_TGT_LANG),
        }
        # Multipart file parts come back as upload objects rather than str.
        not_text = sorted(key for key, value in fields.items() if not isinstance(value, str))
        if not_text:
            raise HTTPException(
                status_code=400,
                detail={"error": "invalid_form_field", "fields": not_text, "hint": _TRANSLATE_HINT},
            )
        return _build_translate_request(fields)

    if ct == "text/plain":
        raw = (await request.body()).decode("utf-8", errors="replace").strip()
        return _build_translate_request({"text": raw})

    raise HTTPException(
        status_code=415,
        detail={"error": "unsupported_media_type", "content_type": raw_ct, "hint": _TRANSLATE_HINT},
    )
async def _translate_core(state: AppState, payload: TranslateRequest, request_id: str) -> TranslateResponse:
    """Check a parsed request, run the translation off the event loop, and wrap the result.

    Raises:
        HTTPException:
          * 400 when the text is blank after stripping, or when
            ``validate_lang_pair`` rejects the pair (the detail lists
            ``SUPPORTED_PAIRS``);
          * 413 when the stripped text is longer than ``state.max_chars``;
          * 504 when translation exceeds ``state.timeout_seconds``;
          * 500 when translation raises anything else (logged with traceback).
    """
    cleaned = payload.text.strip()
    if not cleaned:
        raise HTTPException(status_code=400, detail={"error": "empty_text", "hint": _TRANSLATE_HINT})
    limit = state.max_chars
    if len(cleaned) > limit:
        raise HTTPException(status_code=413, detail=f"Input exceeds MAX_INPUT_CHARS={limit}")

    src = payload.source_lang.strip()
    tgt = payload.target_lang.strip()
    try:
        validate_lang_pair(src, tgt)
    except ValueError as err:
        pair_list = ", ".join(f"{a}->{b}" for a, b in sorted(SUPPORTED_PAIRS))
        raise HTTPException(status_code=400, detail=f"{err}. Supported pairs: {pair_list}") from err

    t0 = time.perf_counter()
    try:
        # wait_for stops awaiting once the timeout elapses, but the worker
        # thread running translate_one cannot be interrupted and runs to completion.
        result = await asyncio.wait_for(
            asyncio.to_thread(translate_one, state.translator, state.sentencepiece, cleaned, src, tgt),
            timeout=state.timeout_seconds,
        )
    except asyncio.TimeoutError as err:
        log.warning("request_id=%s status=timeout", request_id)
        raise HTTPException(status_code=504, detail="Translation timed out") from err
    except Exception:
        log.exception("request_id=%s status=error", request_id)
        raise HTTPException(status_code=500, detail="Internal server error")
    elapsed_ms = 1000.0 * (time.perf_counter() - t0)

    log.info("request_id=%s status=ok chars=%d latency_ms=%.1f", request_id, len(cleaned), elapsed_ms)
    return TranslateResponse(
        translation=result,
        latency_ms=round(elapsed_ms, 1),
        request_id=request_id,
        source_lang=src,
        target_lang=tgt,
        model_variant=state.model_variant,
    )
@app.get("/", response_class=HTMLResponse)
def root(request: Request) -> HTMLResponse:
    """Serve the bundled web UI with server settings substituted into its placeholders.

    Falls back to a minimal page when ``hf_space_ui.html`` cannot be read.
    """
    settings: AppState = request.app.state.nmt
    ui_file = Path(__file__).resolve().parent / "hf_space_ui.html"
    try:
        page = ui_file.read_text(encoding="utf-8")
    except OSError:
        page = "<!DOCTYPE html><html><body><p>UI file missing.</p></body></html>"
    # Applied in insertion order: API-key flag first, then the character limit.
    substitutions = {
        "__INJECT_REQUIRE_API_KEY__": "true" if settings.require_api_key else "false",
        "__INJECT_MAX_CHARS__": str(settings.max_chars),
    }
    for marker, value in substitutions.items():
        page = page.replace(marker, value)
    return HTMLResponse(content=page)
@app.get("/healthz")
def healthz() -> dict[str, str]:
    """Liveness probe; always answers ``{"status": "ok"}``."""
    return dict(status="ok")
@app.get("/translate", response_model=TranslateResponse)
async def translate_get(
    request: Request,
    text: str,
    source_lang: str = DEFAULT_SRC_LANG,
    target_lang: str = DEFAULT_TGT_LANG,
    x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> TranslateResponse:
    """Translate text supplied as query parameters (``GET /translate?text=...``).

    The API key is checked first when required. Invalid parameters, such as an
    empty ``text``, are reported as a 422 validation error rather than a 500.
    """
    state: AppState = request.app.state.nmt
    # The request id is interpolated into log lines; reuse the caller's value
    # only when it is short and printable, otherwise generate a fresh one.
    client_id = request.headers.get("X-Request-Id", "")
    if 0 < len(client_id) <= 128 and client_id.isprintable():
        request_id = client_id
    else:
        request_id = str(uuid.uuid4())
    if state.require_api_key:
        _require_api_key(x_api_key, state.api_key)
    try:
        payload = TranslateRequest(text=text, source_lang=source_lang, target_lang=target_lang)
    except ValidationError as exc:
        # pydantic's ValidationError is not handled by FastAPI here; convert it
        # so request_validation_handler produces the 422 response.
        raise RequestValidationError(exc.errors()) from exc
    return await _translate_core(state, payload, request_id)
@app.post("/translate", response_model=TranslateResponse)
async def translate_post(
    request: Request,
    x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> TranslateResponse:
    """Translate text sent in the request body (JSON, form fields, or text/plain).

    The API key is checked, when required, before the body is read.
    """
    state: AppState = request.app.state.nmt
    # The request id is interpolated into log lines; reuse the caller's value
    # only when it is short and printable, otherwise generate a fresh one.
    client_id = request.headers.get("X-Request-Id", "")
    if 0 < len(client_id) <= 128 and client_id.isprintable():
        request_id = client_id
    else:
        request_id = str(uuid.uuid4())
    if state.require_api_key:
        _require_api_key(x_api_key, state.api_key)
    payload = await _extract_payload(request)
    return await _translate_core(state, payload, request_id)