Spaces:
Sleeping
Sleeping
Aiana Mederbek
perf: translate multi-line input line-by-line; env beam/max-decode; robust env ints
70f84e7 | from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
| from fastapi import FastAPI, Header, HTTPException, Request | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from pydantic import BaseModel, Field | |
| from scripts.nmt_tcp_server import ( | |
| DEFAULT_SRC_LANG, | |
| DEFAULT_TGT_LANG, | |
| SUPPORTED_PAIRS, | |
| load_spm, | |
| load_translator, | |
| translate_one, | |
| validate_lang_pair, | |
| ) | |
log = logging.getLogger("nmt_http_api")
# logging only recognises upper-case level names, so a value such as "info" (or an
# empty LOG_LEVEL) would make basicConfig raise ValueError at import time.
logging.basicConfig(
    level=(os.getenv("LOG_LEVEL") or "INFO").strip().upper() or "INFO",
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
| _TRANSLATE_HINT = ( | |
| 'Send English in JSON: {"text":"Your sentence here"} with ' | |
| "Content-Type: application/json. Alternatives: form field `text`, " | |
| "raw text/plain body, or GET /translate?text=..." | |
| ) | |
@dataclass
class AppState:
    """Objects and settings built once in ``lifespan`` and stored on ``app.state.nmt``.

    Must be a dataclass: ``lifespan`` constructs it with keyword arguments.
    """

    translator: object  # object returned by load_translator()
    sentencepiece: object  # object returned by load_spm()
    api_key: str | None  # stripped NMT_API_KEY, or None when unset/blank
    max_chars: int  # MAX_INPUT_CHARS limit, applied to the stripped input text
    timeout_seconds: float  # per-request translation timeout
    require_api_key: bool  # when True, the X-API-Key header must match api_key
    model_variant: str  # echoed back in every TranslateResponse
class TranslateRequest(BaseModel):
    # Input for one translation request (comments rather than a docstring, so the
    # generated OpenAPI schema description is unchanged).
    #
    # `text` deliberately has no model-level max_length: the length limit is the
    # configurable MAX_INPUT_CHARS, which _translate_core enforces on the stripped text
    # with HTTP 413. A hard-coded cap here (previously 2000) would silently override
    # any larger configured limit.
    text: str = Field(min_length=1)
    source_lang: str = Field(default=DEFAULT_SRC_LANG, min_length=1, max_length=32)
    target_lang: str = Field(default=DEFAULT_TGT_LANG, min_length=1, max_length=32)
class TranslateResponse(BaseModel):
    # Body returned by translate_get / translate_post on success. Every field is
    # required (explicit `Field(...)`).
    translation: str = Field(...)  # translated text
    latency_ms: float = Field(...)  # translation wall time, rounded to 0.1 ms
    request_id: str = Field(...)  # client X-Request-Id or a generated UUID4
    source_lang: str = Field(...)  # stripped source language code actually used
    target_lang: str = Field(...)  # stripped target language code actually used
    model_variant: str = Field(...)  # variant selected at startup
| def _read_env_int(name: str, default: int) -> int: | |
| raw = os.getenv(name, str(default)) | |
| if raw is None or not str(raw).strip(): | |
| return default | |
| first = str(raw).strip().splitlines()[0].strip() | |
| try: | |
| return int(first) | |
| except ValueError as exc: | |
| raise RuntimeError(f"Environment variable {name} must be an integer") from exc | |
| def _read_env_bool(name: str, default: bool) -> bool: | |
| raw = os.getenv(name, "1" if default else "0").strip().lower() | |
| return raw in {"1", "true", "yes", "on"} | |
| def _resolve_model_paths() -> tuple[str, str, str]: | |
| model_variant = os.getenv("MODEL_VARIANT", "base").strip() or "base" | |
| variants_json = os.getenv("MODEL_VARIANTS_JSON", "").strip() | |
| if not variants_json: | |
| model_dir = os.getenv("MODEL_DIR", "artifacts/ct2/en_it_v4_casual_weighted/model") | |
| spm_path = os.getenv("SPM_PATH", os.path.join(model_dir, "sentencepiece.bpe.model")) | |
| return model_variant, model_dir, spm_path | |
| try: | |
| variant_map = json.loads(variants_json) | |
| except json.JSONDecodeError as exc: | |
| raise RuntimeError("MODEL_VARIANTS_JSON must be valid JSON") from exc | |
| if model_variant not in variant_map: | |
| raise RuntimeError(f"MODEL_VARIANT {model_variant!r} not found in MODEL_VARIANTS_JSON") | |
| selected = variant_map[model_variant] | |
| model_dir = selected.get("model_dir") | |
| spm_path = selected.get("spm_path") or os.path.join(model_dir, "sentencepiece.bpe.model") | |
| if not model_dir: | |
| raise RuntimeError(f"Variant {model_variant!r} is missing model_dir") | |
| return model_variant, model_dir, spm_path | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared ``AppState`` before the application serves requests.

    Steps:
      * read configuration from the environment;
      * validate the API-key settings;
      * load the translator and the SentencePiece model;
      * run one warm-up translation;
      * store the state on ``app.state.nmt`` for the request handlers.

    Raises:
        RuntimeError: on invalid configuration (from ``_resolve_model_paths`` /
            ``_read_env_int``), or when REQUIRE_API_KEY is enabled but NMT_API_KEY
            is not set.
    """
    model_variant, model_dir, spm_path = _resolve_model_paths()
    api_key = os.getenv("NMT_API_KEY", "").strip() or None
    # Key protection defaults to on exactly when a key is configured.
    require_api_key = _read_env_bool("REQUIRE_API_KEY", default=api_key is not None)
    max_chars = _read_env_int("MAX_INPUT_CHARS", 2000)
    timeout_ms = _read_env_int("TRANSLATION_TIMEOUT_MS", 10000)
    # Floor at 100 ms so a zero or negative setting cannot time out every request instantly.
    timeout_seconds = max(timeout_ms / 1000.0, 0.1)

    # Check auth configuration before loading the model, so a misconfiguration fails
    # immediately instead of after the (potentially slow) load and warm-up.
    if require_api_key and not api_key:
        raise RuntimeError("REQUIRE_API_KEY is enabled but NMT_API_KEY is not set")
    if not require_api_key:
        log.warning("API key protection is disabled")

    log.info(
        "Loading model_variant=%s model_dir=%s spm_path=%s",
        model_variant,
        model_dir,
        spm_path,
    )
    translator = load_translator(os.path.abspath(model_dir))
    sentencepiece = load_spm(os.path.abspath(spm_path))
    # One throwaway translation before serving traffic; the result is discarded.
    translate_one(translator, sentencepiece, "Hi")
    log.info("Warmup complete")

    app.state.nmt = AppState(
        translator=translator,
        sentencepiece=sentencepiece,
        api_key=api_key,
        max_chars=max_chars,
        timeout_seconds=timeout_seconds,
        require_api_key=require_api_key,
        model_variant=model_variant,
    )
    yield
app = FastAPI(title="NMT MenKan HTTP API", lifespan=lifespan)


@app.exception_handler(RequestValidationError)
async def request_validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Return request-validation failures as 422 with an extra usage ``hint``.

    ``exc.errors()`` is passed through ``jsonable_encoder`` because its entries may
    hold values that plain ``json.dumps`` cannot serialize, such as exception
    objects inside an error's ``ctx``.
    """
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content={
            "error": "validation_error",
            "detail": jsonable_encoder(exc.errors()),
            "hint": _TRANSLATE_HINT,
        },
    )
| def _require_api_key(header_key: str | None, expected_key: str | None) -> None: | |
| if not expected_key or not header_key or header_key != expected_key: | |
| raise HTTPException(status_code=401, detail="Unauthorized") | |
def _build_translate_request(**fields: object) -> TranslateRequest:
    """Build a TranslateRequest, turning pydantic validation failures into HTTP 422.

    A model constructed by hand inside a handler is not covered by FastAPI's request
    validation, so an escaping ValidationError would otherwise surface as a 500.
    """
    from pydantic import ValidationError

    try:
        return TranslateRequest(**fields)
    except ValidationError as exc:
        # Keep only JSON-safe keys (present in pydantic v1 and v2 error dicts).
        errors = [
            {"loc": list(err.get("loc", ())), "msg": err.get("msg", ""), "type": err.get("type", "")}
            for err in exc.errors()
        ]
        raise HTTPException(
            status_code=422,
            detail={"error": "validation_error", "detail": errors, "hint": _TRANSLATE_HINT},
        ) from exc


async def _extract_payload(request: Request) -> TranslateRequest:
    """Parse a POST body into a TranslateRequest based on its Content-Type.

    Supported bodies:
      * ``application/json`` (also assumed when no Content-Type is sent): a JSON
        object with ``text`` and optional ``source_lang`` / ``target_lang``.
      * ``application/x-www-form-urlencoded`` / ``multipart/form-data``: the same
        field names.
      * ``text/plain``: the whole (stripped) body is the text; default languages apply.

    Raises:
        HTTPException: 400 when the body is not decodable JSON or is JSON but not an
            object; 415 for any other content type; 422 when the fields fail
            TranslateRequest validation.
    """
    raw_ct = request.headers.get("content-type") or ""
    ct = raw_ct.split(";")[0].strip().lower()

    if ct in ("application/json", "") or "application/json" in raw_ct:
        try:
            body = await request.json()
        except ValueError as exc:
            # Covers json.JSONDecodeError and UnicodeDecodeError (body that is not
            # valid UTF-8); both subclass ValueError.
            raise HTTPException(
                status_code=400,
                detail={"error": "invalid_json", "message": str(exc), "hint": _TRANSLATE_HINT},
            ) from exc
        if not isinstance(body, dict):
            # e.g. a bare string, number or list: unpacking it would raise TypeError (500).
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "invalid_json",
                    "message": "JSON body must be an object",
                    "hint": _TRANSLATE_HINT,
                },
            )
        return _build_translate_request(**body)

    if ct in ("application/x-www-form-urlencoded", "multipart/form-data"):
        form = await request.form()
        return _build_translate_request(
            text=form.get("text", ""),
            source_lang=form.get("source_lang", DEFAULT_SRC_LANG),
            target_lang=form.get("target_lang", DEFAULT_TGT_LANG),
        )

    if ct == "text/plain":
        raw = (await request.body()).decode("utf-8", errors="replace").strip()
        return _build_translate_request(text=raw)

    raise HTTPException(
        status_code=415,
        detail={"error": "unsupported_media_type", "content_type": raw_ct, "hint": _TRANSLATE_HINT},
    )
async def _translate_core(state: AppState, payload: TranslateRequest, request_id: str) -> TranslateResponse:
    """Check *payload* against server limits, translate it, and build the response.

    Raises:
        HTTPException: 400 for blank text or an unsupported language pair; 413 when
            the stripped text exceeds ``state.max_chars``; 504 when translation
            exceeds ``state.timeout_seconds``; 500 for any other translation failure.
    """
    text = payload.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail={"error": "empty_text", "hint": _TRANSLATE_HINT})
    if len(text) > state.max_chars:
        raise HTTPException(status_code=413, detail=f"Input exceeds MAX_INPUT_CHARS={state.max_chars}")

    source_lang = payload.source_lang.strip()
    target_lang = payload.target_lang.strip()
    try:
        validate_lang_pair(source_lang, target_lang)
    except ValueError as exc:
        supported = ", ".join(f"{src}->{tgt}" for src, tgt in sorted(SUPPORTED_PAIRS))
        raise HTTPException(status_code=400, detail=f"{exc}. Supported pairs: {supported}") from exc

    start = time.perf_counter()
    try:
        # translate_one is a blocking call, so it runs in a worker thread to keep the
        # event loop responsive. NOTE: on timeout, wait_for stops waiting, but the
        # worker thread cannot be interrupted and keeps running until translate_one
        # returns.
        translation = await asyncio.wait_for(
            asyncio.to_thread(
                translate_one, state.translator, state.sentencepiece, text, source_lang, target_lang
            ),
            timeout=state.timeout_seconds,
        )
    except asyncio.TimeoutError as exc:
        log.warning("request_id=%s status=timeout", request_id)
        raise HTTPException(status_code=504, detail="Translation timed out") from exc
    except Exception as exc:
        log.exception("request_id=%s status=error", request_id)
        # Chain explicitly so the HTTP error is not reported as "raised while handling".
        raise HTTPException(status_code=500, detail="Internal server error") from exc

    latency_ms = (time.perf_counter() - start) * 1000.0
    log.info("request_id=%s status=ok chars=%d latency_ms=%.1f", request_id, len(text), latency_ms)
    return TranslateResponse(
        translation=translation,
        latency_ms=round(latency_ms, 1),
        request_id=request_id,
        source_lang=source_lang,
        target_lang=target_lang,
        model_variant=state.model_variant,
    )
@app.get("/", response_class=HTMLResponse)
def root(request: Request) -> HTMLResponse:
    """Serve the bundled web UI with live server settings substituted in.

    ``hf_space_ui.html`` is read from this module's directory on every request. If it
    cannot be read, a warning is logged and a minimal placeholder page is served.
    """
    state: AppState = request.app.state.nmt
    path = Path(__file__).resolve().parent / "hf_space_ui.html"
    try:
        html = path.read_text(encoding="utf-8")
    except OSError:
        # Previously swallowed silently; log so a broken deployment is visible.
        log.warning("UI file could not be read: %s", path)
        html = "<!DOCTYPE html><html><body><p>UI file missing.</p></body></html>"
    # Replace literal placeholder tokens in the template with the current settings.
    html = html.replace("__INJECT_REQUIRE_API_KEY__", "true" if state.require_api_key else "false")
    html = html.replace("__INJECT_MAX_CHARS__", str(state.max_chars))
    return HTMLResponse(content=html)
@app.get("/healthz")
def healthz() -> dict[str, str]:
    """Health check: always returns ``{"status": "ok"}`` without touching the model."""
    return {"status": "ok"}
@app.get("/translate", response_model=TranslateResponse)
async def translate_get(
    request: Request,
    text: str,
    source_lang: str = DEFAULT_SRC_LANG,
    target_lang: str = DEFAULT_TGT_LANG,
    x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> TranslateResponse:
    """Translate ``text`` supplied as a query parameter.

    The API key is checked first when required. The client's X-Request-Id header is
    echoed in the response and logs; a UUID4 is generated when it is absent or empty.

    Raises:
        HTTPException: 401 for a bad or missing key; 422 when the parameters fail
            TranslateRequest validation; plus the errors of ``_translate_core``.
    """
    from pydantic import ValidationError

    state: AppState = request.app.state.nmt
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    if state.require_api_key:
        _require_api_key(x_api_key, state.api_key)
    try:
        payload = TranslateRequest(text=text, source_lang=source_lang, target_lang=target_lang)
    except ValidationError as exc:
        # This model is built by hand, outside FastAPI's request validation, so an
        # escaping ValidationError would surface as a 500; report a client error.
        errors = [
            {"loc": list(err.get("loc", ())), "msg": err.get("msg", ""), "type": err.get("type", "")}
            for err in exc.errors()
        ]
        raise HTTPException(
            status_code=422,
            detail={"error": "validation_error", "detail": errors, "hint": _TRANSLATE_HINT},
        ) from exc
    return await _translate_core(state, payload, request_id)
@app.post("/translate", response_model=TranslateResponse)
async def translate_post(
    request: Request,
    x_api_key: str | None = Header(default=None, alias="X-API-Key"),
) -> TranslateResponse:
    """Translate text from the request body (JSON, form fields, or text/plain).

    The API key is checked before the body is read. The client's X-Request-Id header
    is echoed in the response and logs; a UUID4 is generated when it is absent or
    empty.
    """
    state: AppState = request.app.state.nmt
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
    if state.require_api_key:
        _require_api_key(x_api_key, state.api_key)
    payload = await _extract_payload(request)
    return await _translate_core(state, payload, request_id)