| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Word embeddings server class.""" |
| |
|
| | import http.server |
| | import json |
| | import socketserver |
| | import os |
| | import sys |
| | import threading |
| | import urllib.parse |
| |
|
| | import numpy as np |
| |
|
class WEmbeddingsServer(socketserver.ThreadingTCPServer):
    """Threaded HTTP server exposing a word-embeddings model.

    Endpoints (implemented by ``WEmbeddingsRequestHandler``):

    * ``POST /wembeddings`` -- the body is a JSON object with the keys
      ``"model"`` and ``"sentences"``; the response body is a sequence of
      arrays in NumPy ``.npy`` format, one per computed sentence embedding.
    * ``GET /status`` -- returns ``{"status": "UP"}``.

    All embedding computations are serialized through a single lock, so at
    most one request uses the model at any time.
    """

    class WEmbeddingsRequestHandler(http.server.BaseHTTPRequestHandler):
        """Handler of a single HTTP request.

        The methods name their first argument ``request`` instead of the
        usual ``self``; it is the handler instance.
        """

        protocol_version = "HTTP/1.1"

        def respond(request, content_type, code=200):
            """Send the status line and headers of a response.

            Args:
                content_type: value of the ``Content-Type`` header.
                code: HTTP status code.

            The connection is always closed after the response, which lets
            the handlers stream a body without a ``Content-Length`` header.
            """
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            request.end_headers()

        def respond_error(request, message, code=400):
            """Send a plain-text error response with the given message."""
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def _parse_url(request):
            """Return the parsed request URL, or None after replying with an error."""
            try:
                # http.server decodes the request line as ISO-8859-1;
                # reinterpret the raw bytes as UTF-8.
                request.path = request.path.encode("iso-8859-1").decode("utf-8")
                return urllib.parse.urlparse(request.path)
            except Exception:
                request.respond_error("Cannot parse request URL.")
                return None

        @staticmethod
        def _log_exception():
            """Print the exception currently being handled to stderr."""
            import traceback
            traceback.print_exc(file=sys.stderr)
            sys.stderr.flush()

        def do_POST(request):
            """Handle ``POST /wembeddings``; any other path yields a 404."""
            url = request._parse_url()
            if url is None:
                return

            if url.path != "/wembeddings":
                return request.respond_error("No handler for the given URL '{}'".format(url.path), code=404)

            if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

            if "Content-Length" not in request.headers:
                return request.respond_error("The Content-Length of payload is required.")

            try:
                length = int(request.headers["Content-Length"])
                if length < 0:
                    # rfile.read() with a negative size reads until EOF, which
                    # would block this thread until the client disconnects.
                    raise ValueError("Negative Content-Length {}".format(length))
                data = json.loads(request.rfile.read(length))
                model, sentences = data["model"], data["sentences"]
            except Exception:
                request._log_exception()
                return request.respond_error("Malformed request.")

            try:
                # The model is shared by all handler threads.
                with request.server._wembeddings_mutex:
                    sentences_embeddings = request.server._wembeddings.compute_embeddings(model, sentences)
            except Exception:
                request._log_exception()
                return request.respond_error("An error occurred during wembeddings computation.")

            # NOTE(review): the registered MIME type is "application/octet-stream"
            # (with a hyphen); the value is kept as-is because clients may
            # depend on it -- confirm before changing.
            request.respond("application/octet_stream")
            for sentence_embedding in sentences_embeddings:
                np.lib.format.write_array(request.wfile, sentence_embedding.astype(request.server._dtype), allow_pickle=False)

        def do_GET(request):
            """Handle ``GET /status``; any other path yields a 404."""
            url = request._parse_url()
            if url is None:
                return

            if url.path == "/status":
                request.respond("application/json")
                request.wfile.write(bytes("""{"status": "UP"}""", "utf-8"))
            else:
                request.respond_error("No handler for the given URL '{}'".format(url.path), code=404)

    daemon_threads = False

    def __init__(self, port, dtype, wembeddings_lambda):
        """Create the model and start listening on the given port.

        Args:
            port: TCP port to listen on (on all interfaces).
            dtype: NumPy dtype the embeddings are cast to before being sent.
            wembeddings_lambda: zero-argument callable returning an object
                with a ``compute_embeddings(model, sentences)`` method.
        """
        self._dtype = dtype

        # The model is created before the socket is bound.
        self._wembeddings = wembeddings_lambda()
        self._wembeddings_mutex = threading.Lock()

        super().__init__(("", port), self.WEmbeddingsRequestHandler)

    def server_bind(self):
        """Bind the socket with address (and, where available, port) reuse."""
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined by the socket module on every platform.
        if os.name != 'nt' and hasattr(socket, "SO_REUSEPORT"):
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def service_actions(self):
        """Drop finished handler threads once the bookkeeping list grows large.

        The list is pruned in place: depending on the Python version,
        ``_threads`` may be a ``list`` subclass with extra methods that
        ``socketserver`` calls on shutdown, so it must not be replaced by
        a plain list.
        """
        threads = getattr(self, "_threads", None)
        if isinstance(threads, list) and len(threads) >= 1024:
            threads[:] = [thread for thread in threads if thread.is_alive()]
| |
|