File size: 3,356 Bytes
0584798 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from combined_inference import classify_query
from config import DEFAULT_API_HOST, DEFAULT_API_PORT, HEAD_CONFIGS, PROJECT_VERSION
from model_runtime import get_head
from schemas import (
SchemaValidationError,
default_version_payload,
validate_classify_request,
validate_classify_response,
validate_health_response,
validate_version_response,
)
class DemoHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing the demo classifier as a small JSON API.

    Routes (query strings are ignored when matching):
        GET  /health   -> status "ok", ``PROJECT_VERSION`` and each head's status.
        GET  /version  -> the payload built by ``default_version_payload()``.
        POST /classify -> ``classify_query`` applied to the request's ``text``.
    Any other path gets a 404 ``{"error": "not_found"}`` response.
    """

    # Upper bound on accepted request bodies. The declared Content-Length is
    # handed directly to ``rfile.read``, so without a cap a client could ask
    # the server to read an arbitrarily large body. Override on a subclass if
    # larger requests are legitimately needed.
    max_body_bytes: int = 1024 * 1024

    # SchemaValidationError code -> HTTP status. Codes not listed here (for
    # example a response that fails its own schema) are server errors -> 500.
    _ERROR_STATUS = {
        "invalid_json": 400,
        "payload_too_large": 413,
        "request_validation_failed": 422,
    }

    def _send_json(self, status_code: int, payload: dict):
        """Write *payload* as an indented UTF-8 JSON response with *status_code*."""
        body = json.dumps(payload, indent=2).encode("utf-8")
        self.send_response(status_code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _route_path(raw_path: str) -> str:
        """Return the request target without its query string, for route matching."""
        return raw_path.partition("?")[0]

    @staticmethod
    def _body_error(code: str, field: str, message: str):
        """Build a SchemaValidationError describing one unreadable-request problem."""
        return SchemaValidationError(
            code,
            [{"field": field, "message": message, "type": "parse_error"}],
        )

    def _read_json_body(self) -> dict:
        """Read the request body and decode it as JSON.

        An absent or empty body is treated as ``{}``.

        Raises:
            SchemaValidationError: code ``"invalid_json"`` when Content-Length is
                not a non-negative integer or the body cannot be decoded as JSON;
                code ``"payload_too_large"`` when Content-Length exceeds
                ``max_body_bytes``.
        """
        raw_length = self.headers.get("Content-Length", "0")
        try:
            content_length = int(raw_length)
        except (TypeError, ValueError):
            content_length = -1
        if content_length < 0:
            # A negative length would make rfile.read() read until EOF, so both
            # negative and non-numeric values are rejected as a client error.
            raise self._body_error(
                "invalid_json",
                "Content-Length",
                f"invalid Content-Length header: {raw_length!r}",
            )
        if content_length > self.max_body_bytes:
            raise self._body_error(
                "payload_too_large",
                "body",
                f"request body of {content_length} bytes exceeds the "
                f"{self.max_body_bytes}-byte limit",
            )
        raw_body = self.rfile.read(content_length)
        try:
            return json.loads(raw_body or b"{}")
        except json.JSONDecodeError as exc:
            raise self._body_error("invalid_json", "body", f"invalid JSON: {exc.msg}") from exc
        except UnicodeDecodeError as exc:
            # json.loads() decodes bytes input itself; malformed byte sequences
            # surface as UnicodeDecodeError rather than JSONDecodeError.
            raise self._body_error("invalid_json", "body", f"invalid JSON: {exc.reason}") from exc

    def _handle_classify(self):
        """Validate the request, classify its ``text`` and send the validated result.

        Validation failures are answered with ``{"error": code, "details": ...}``
        and the status from ``_ERROR_STATUS`` (500 for unlisted codes).
        """
        try:
            request_payload = validate_classify_request(self._read_json_body())
            response_payload = validate_classify_response(classify_query(request_payload["text"]))
        except SchemaValidationError as exc:
            status_code = self._ERROR_STATUS.get(exc.code, 500)
            self._send_json(status_code, {"error": exc.code, "details": exc.details})
            return
        self._send_json(200, response_payload)

    def _handle_health(self):
        """Send overall status, system version and every configured head's status."""
        payload = {
            "status": "ok",
            "system_version": PROJECT_VERSION,
            "heads": [get_head(head_name).status() for head_name in HEAD_CONFIGS],
        }
        try:
            response_payload = validate_health_response(payload)
        except SchemaValidationError as exc:
            # The server built this payload itself, so a schema failure is a 500.
            self._send_json(500, {"error": exc.code, "details": exc.details})
            return
        self._send_json(200, response_payload)

    def _handle_version(self):
        """Send the validated version payload, or a 500 if it fails its schema."""
        try:
            response_payload = validate_version_response(default_version_payload())
        except SchemaValidationError as exc:
            self._send_json(500, {"error": exc.code, "details": exc.details})
            return
        self._send_json(200, response_payload)

    def do_GET(self):
        """Dispatch GET /health and GET /version; anything else is a 404."""
        route = self._route_path(self.path)
        if route == "/health":
            self._handle_health()
        elif route == "/version":
            self._handle_version()
        else:
            self._send_json(404, {"error": "not_found"})

    def do_POST(self):
        """Dispatch POST /classify; anything else is a 404."""
        if self._route_path(self.path) != "/classify":
            self._send_json(404, {"error": "not_found"})
            return
        self._handle_classify()

    def log_message(self, format: str, *args):
        """Suppress the base class's per-request logging (intentionally a no-op)."""
        return
def main():
    """Serve the demo API on DEFAULT_API_HOST:DEFAULT_API_PORT until interrupted."""
    # Using the server as a context manager guarantees server_close() runs and
    # the listening socket is released however serve_forever() exits.
    with HTTPServer((DEFAULT_API_HOST, DEFAULT_API_PORT), DemoHandler) as server:
        print(f"Serving demo API on http://{DEFAULT_API_HOST}:{DEFAULT_API_PORT}")
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl+C is the normal way to stop this demo server; exit quietly
            # instead of printing a traceback.
            pass


if __name__ == "__main__":
    main()
|