File size: 3,356 Bytes
0584798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

from combined_inference import classify_query
from config import DEFAULT_API_HOST, DEFAULT_API_PORT, HEAD_CONFIGS, PROJECT_VERSION
from model_runtime import get_head
from schemas import (
    SchemaValidationError,
    default_version_payload,
    validate_classify_request,
    validate_classify_response,
    validate_health_response,
    validate_version_response,
)


class DemoHandler(BaseHTTPRequestHandler):
    """HTTP handler serving ``GET /health``, ``GET /version`` and ``POST /classify``.

    Every response body is a JSON document written by ``_send_json``.
    Routing looks only at the path component of the request target, so a
    query string (``/health?probe=1``) does not turn a known route into a 404.
    Per-request logging is disabled (see ``log_message``).
    """

    # Status codes for SchemaValidationError codes raised while handling
    # /classify; any code not listed here is reported as a 500.
    _CLASSIFY_ERROR_STATUS = {
        "invalid_json": 400,
        "request_validation_failed": 422,
    }

    def _send_json(self, status_code: int, payload: dict):
        """Send *payload* as an indented, UTF-8 encoded JSON response.

        Args:
            status_code: HTTP status code for the response line.
            payload: JSON-serializable object used as the response body.
        """
        body = json.dumps(payload, indent=2).encode("utf-8")
        self.send_response(status_code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _route_path(self) -> str:
        """Return the request target without its query string."""
        return self.path.partition("?")[0]

    @staticmethod
    def _parse_content_length(raw_value):
        """Parse a ``Content-Length`` header value.

        Args:
            raw_value: The raw header value, or ``None`` when the header is absent.

        Returns:
            ``0`` when the header is absent, the parsed length when it is a
            non-negative integer, and ``None`` when the value is malformed
            or negative.
        """
        if raw_value is None:
            return 0
        try:
            length = int(raw_value)
        except (TypeError, ValueError):
            return None
        # A negative length would turn rfile.read() into "read until EOF",
        # which blocks until the client closes the connection.
        return length if length >= 0 else None

    @staticmethod
    def _invalid_json_error(field: str, message: str):
        """Build the ``invalid_json`` SchemaValidationError for *field*."""
        return SchemaValidationError(
            "invalid_json",
            [{"field": field, "message": message, "type": "parse_error"}],
        )

    def _read_json_body(self) -> dict:
        """Read the request body and decode it as JSON.

        An absent or empty body is treated as ``{}``.

        Returns:
            The decoded JSON value.

        Raises:
            SchemaValidationError: with code ``invalid_json`` when the
                ``Content-Length`` header is not a non-negative integer, or
                when the body cannot be decoded or parsed as JSON.
        """
        content_length = self._parse_content_length(self.headers.get("Content-Length"))
        if content_length is None:
            raise self._invalid_json_error(
                "Content-Length",
                "Content-Length must be a non-negative integer",
            )
        raw_body = self.rfile.read(content_length)
        try:
            return json.loads(raw_body or b"{}")
        except json.JSONDecodeError as exc:
            raise self._invalid_json_error("body", f"invalid JSON: {exc.msg}") from exc
        except UnicodeDecodeError as exc:
            # json.loads decodes bytes before parsing; undecodable input
            # raises UnicodeDecodeError, which is not a JSONDecodeError.
            raise self._invalid_json_error(
                "body",
                f"invalid JSON: body could not be decoded ({exc.reason})",
            ) from exc

    def _handle_classify(self):
        """Handle ``POST /classify``.

        The request body is passed through ``validate_classify_request``, its
        ``"text"`` entry is given to ``classify_query``, and the result is
        passed through ``validate_classify_response`` before being sent with
        status 200. A SchemaValidationError produces an error document with
        status 400 (``invalid_json``), 422 (``request_validation_failed``) or
        500 (any other code).
        """
        try:
            request_payload = validate_classify_request(self._read_json_body())
            response_payload = validate_classify_response(classify_query(request_payload["text"]))
        except SchemaValidationError as exc:
            status_code = self._CLASSIFY_ERROR_STATUS.get(exc.code, 500)
            self._send_json(status_code, {"error": exc.code, "details": exc.details})
            return

        self._send_json(200, response_payload)

    def _handle_health(self):
        """Handle ``GET /health``.

        Reports ``PROJECT_VERSION`` and the ``status()`` of every head named
        in ``HEAD_CONFIGS``. Responds 500 if the payload fails
        ``validate_health_response``.
        """
        payload = {
            "status": "ok",
            "system_version": PROJECT_VERSION,
            "heads": [get_head(head_name).status() for head_name in HEAD_CONFIGS],
        }
        try:
            response_payload = validate_health_response(payload)
        except SchemaValidationError as exc:
            self._send_json(500, {"error": exc.code, "details": exc.details})
            return
        self._send_json(200, response_payload)

    def _handle_version(self):
        """Handle ``GET /version``.

        Sends the result of ``default_version_payload()`` after it passes
        ``validate_version_response``; responds 500 if validation fails.
        """
        try:
            response_payload = validate_version_response(default_version_payload())
        except SchemaValidationError as exc:
            self._send_json(500, {"error": exc.code, "details": exc.details})
            return
        self._send_json(200, response_payload)

    def do_GET(self):
        """Dispatch GET requests; unknown paths get a JSON 404."""
        route = self._route_path()
        if route == "/health":
            self._handle_health()
            return
        if route == "/version":
            self._handle_version()
            return
        self._send_json(404, {"error": "not_found"})

    def do_POST(self):
        """Dispatch POST requests; only ``/classify`` is served."""
        if self._route_path() != "/classify":
            self._send_json(404, {"error": "not_found"})
            return
        self._handle_classify()

    def log_message(self, format: str, *args):
        """Discard the base class's per-request log output."""
        return


def main():
    """Start the demo API server and serve requests until interrupted.

    The server listens on ``DEFAULT_API_HOST``:``DEFAULT_API_PORT`` and
    handles requests with ``DemoHandler``.
    """
    # The context manager calls server_close() on exit, so the listening
    # socket is released even when serve_forever() ends with an exception
    # such as KeyboardInterrupt.
    with HTTPServer((DEFAULT_API_HOST, DEFAULT_API_PORT), DemoHandler) as server:
        print(f"Serving demo API on http://{DEFAULT_API_HOST}:{DEFAULT_API_PORT}")
        server.serve_forever()


# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()