File size: 3,930 Bytes
10aced5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
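ValidatorClient — typed HTTP client for the AI Response Validator API.

Retry policy: exponential backoff on network errors and on HTTP 500/502/503/504
responses; max_retries is the total number of attempts, including the first.
Timeouts: a single per-instance value handed to httpx, which applies it to each
phase (connect, read, write, pool) separately; timeouts are not retried.
Auth: optional Bearer token forwarded as Authorization header (for future use).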
"""

import httpx
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
    RetryError,
)

from client.exceptions import APIError, RetryExhaustedError, TimeoutError
from client.models import ConfigResponse, QueryRequest, QueryResponse

# Default per-instance timeout, passed to ``httpx.Client(timeout=...)`` (seconds).
DEFAULT_TIMEOUT: float = 30.0
# Default for ``max_retries``; fed to ``stop_after_attempt``, so it counts total
# attempts per request (the first try included), not retries after the first.
DEFAULT_MAX_RETRIES: int = 3
# Response status codes treated as transient and retried with backoff; any other
# error status is raised immediately as ``APIError``.
_RETRY_STATUS_CODES: set[int] = {500, 502, 503, 504}


class ValidatorClient:
    """Typed client for the AI Response Validator API.

    Wraps a single ``httpx.Client`` that carries the base URL, timeout and
    default headers, and adds retry handling for transient failures. Use it as
    a context manager (or call :meth:`close`) to release the underlying client.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        api_key: str | None = None,
    ) -> None:
        """Create a client bound to ``base_url``.

        Args:
            base_url: API root; any trailing ``/`` is stripped.
            timeout: Passed straight through to ``httpx.Client(timeout=...)``.
            max_retries: Total number of attempts per request (it is handed to
                ``stop_after_attempt``), so 3 means one try plus two retries.
            api_key: When truthy, sent as ``Authorization: Bearer <api_key>``.
        """
        headers = {"Accept": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        self._client = httpx.Client(
            base_url=base_url.rstrip("/"),
            timeout=timeout,
            headers=headers,
        )
        self._max_retries = max_retries

    def _request(self, method: str, path: str, **kwargs: object) -> httpx.Response:
        """Execute an HTTP request with retry on transient server errors.

        Retried with exponential backoff (waits start at 0.5 s, capped at 10 s):
        ``httpx.NetworkError`` and responses whose status is in
        ``_RETRY_STATUS_CODES``. Not retried: timeouts and any other error
        status, which fail on the first occurrence.

        Raises:
            TimeoutError: the request timed out (``client.exceptions`` type).
            APIError: the server returned a non-retryable error status.
            RetryExhaustedError: every attempt hit a transient error; the last
                underlying exception is passed along.
        """
        @retry(
            retry=retry_if_exception_type(_TransientError),
            stop=stop_after_attempt(self._max_retries),
            wait=wait_exponential(multiplier=0.5, min=0.5, max=10),
            # Let tenacity wrap the final failure in RetryError so it can be
            # re-raised below as the project's RetryExhaustedError.
            reraise=False,
        )
        def _attempt() -> httpx.Response:
            try:
                response = self._client.request(method, path, **kwargs)  # type: ignore[arg-type]
            except httpx.TimeoutException as exc:
                # Not a _TransientError, so tenacity propagates it unretried.
                raise TimeoutError(str(exc)) from exc
            except httpx.NetworkError as exc:
                raise _TransientError(str(exc)) from exc

            if response.status_code in _RETRY_STATUS_CODES:
                raise _TransientError(f"HTTP {response.status_code}")

            if response.is_error:
                detail = _extract_detail(response)
                raise APIError(response.status_code, detail)

            return response

        try:
            return _attempt()
        except RetryError as exc:
            last = exc.last_attempt.exception()
            raise RetryExhaustedError(self._max_retries, last) from exc

    def get_config(self) -> ConfigResponse:
        """Return domain and client configuration (unauthenticated).

        Raises whatever ``_request`` raises on failure.
        """
        response = self._request("GET", "/config")
        return ConfigResponse.model_validate(response.json())

    def query(self, question: str, client_id: str) -> QueryResponse:
        """Submit a question for a specific client and return a graded response.

        Args:
            question: Sent as the ``query`` field of ``QueryRequest``.
            client_id: Sent as the ``client`` field of ``QueryRequest``.

        Raises whatever ``_request`` raises on failure.
        """
        payload = QueryRequest(query=question, client=client_id)
        response = self._request(
            "POST",
            "/query",
            json=payload.model_dump(),
        )
        return QueryResponse.model_validate(response.json())

    def health(self) -> bool:
        """Return True if the API is reachable and reports ``{"status": "ok"}``.

        Never raises for an unhealthy server: request failures, non-JSON
        bodies and JSON bodies that are not objects all yield ``False``.
        """
        try:
            response = self._request("GET", "/health")
            body = response.json()
        except (
            APIError,
            RetryExhaustedError,
            TimeoutError,
            # Transport errors that _request does not translate (e.g. protocol errors).
            httpx.HTTPError,
            # response.json() on a body that is not valid JSON.
            ValueError,
        ):
            return False
        # A JSON array or scalar has no .get(); treat it as unhealthy.
        return isinstance(body, dict) and body.get("status") == "ok"

    def close(self) -> None:
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    def __enter__(self) -> "ValidatorClient":
        return self

    def __exit__(self, *_: object) -> None:
        # Always close; exceptions from the with-body are not suppressed.
        self.close()


class _TransientError(Exception):
    """Internal marker for errors that should trigger a retry."""


def _extract_detail(response: httpx.Response) -> str:
    """Best-effort error detail for a failed response.

    Returns ``str()`` of the ``"detail"`` key of a JSON object body. Falls back
    to the raw response text when the key is absent, the body is not valid
    JSON, or the decoded body has no ``.get`` (e.g. a JSON array).
    """
    fallback = response.text
    try:
        payload = response.json()
        return str(payload.get("detail", fallback))
    except Exception:
        return fallback