| | |
| | |
| | """ |
| | Light-weight Sherpa-ONNX HTTP ASR server (thread-safe). |
| | """ |
| |
|
| | from __future__ import annotations |
| | import socket |
| | import argparse |
| | import io |
| | import json |
| | import logging |
| | import os |
| | import sys |
| | import time |
| | import wave |
| | from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer |
| | from typing import Final, List, Tuple |
| |
|
| | import numpy as np |
| |
|
| | |
| | |
| | |
# Configure the root logger once at import time.  The threshold is taken from
# the LOGLEVEL environment variable and falls back to INFO when it is unset.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s.%(msecs)03d [%(levelname)8s] %(name)s - %(message)s",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)

# Logger shared by the rest of this module.
log = logging.getLogger("ASR")
| |
|
| | |
| | |
| | |
# Audio format accepted by the server: 16 kHz sample rate, one channel,
# two bytes (16 bits) per sample.
EXPECTED_RATE: Final[int] = 16_000
EXPECTED_CHANNELS: Final[int] = 1
EXPECTED_WIDTH_B: Final[int] = 2

# Half a second of float32 silence, expressed in samples at EXPECTED_RATE.
TAIL_PADDING: Final[np.ndarray] = np.zeros(EXPECTED_RATE // 2, dtype=np.float32)
| |
|
# Work out the directory that contains the "assets" folder.
# A frozen build that exposes sys._MEIPASS keeps its files there; otherwise
# the assets are expected next to this source file.
if not getattr(sys, "frozen", False) or not hasattr(sys, "_MEIPASS"):
    try:
        base_path = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # __file__ is not defined in this execution context: fall back to
        # the current working directory.
        base_path = os.path.abspath(".")
else:
    base_path = sys._MEIPASS

# Model files shipped with the server.
MODEL_DIR = os.path.join(base_path, "assets", "sensevoicesmallonnx")
MODEL_PATH, TOKENS_PATH = (
    os.path.join(MODEL_DIR, filename) for filename in ("model.onnx", "tokens.txt")
)
| |
|
| | |
| | |
| | |
# sherpa_onnx is imported separately from the other dependencies so that a
# missing installation produces a logged install hint and exit status 1,
# with the original ImportError chained as the cause.
try:
    import sherpa_onnx
except ImportError as exc:
    log.critical("sherpa_onnx not found – install it via `pip install sherpa-onnx`")
    raise SystemExit(1) from exc
| |
|
| |
|
def load_recognizer() -> "sherpa_onnx.OfflineRecognizer":
    """Create the SenseVoice offline recognizer from the bundled model files.

    Returns
    -------
    sherpa_onnx.OfflineRecognizer
        Recognizer built by ``OfflineRecognizer.from_sense_voice`` with the
        CPU provider and roughly half of the available CPU cores as threads.

    Raises
    ------
    SystemExit
        With status 1 when ``MODEL_PATH`` or ``TOKENS_PATH`` is not an
        existing file.
    """
    if not (os.path.isfile(MODEL_PATH) and os.path.isfile(TOKENS_PATH)):
        log.critical("Model assets missing under %s", MODEL_DIR)
        raise SystemExit(1)

    # os.cpu_count() returns None when the core count cannot be determined;
    # treat that as one core instead of crashing on ``None // 2``.
    num_threads = max(1, (os.cpu_count() or 1) // 2)

    ts0 = time.perf_counter()
    log.info("Begin loading model...")
    rec = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=MODEL_PATH,
        tokens=TOKENS_PATH,
        language="",
        use_itn=True,
        num_threads=num_threads,
        provider="cpu",
        debug=False,
    )
    cost = time.perf_counter() - ts0
    log.info("Model loaded in %.2f s", cost)
    return rec
| |
|
| |
|
# Load the model once, at import time.  transcribe_pcm16le() uses this single
# module-level recognizer for every request.
RECOGNIZER = load_recognizer()
| |
|
| |
|
| | |
| | |
| | |
def transcribe_pcm16le(pcm_bytes: bytes) -> str:
    """Transcribe raw PCM audio with the module-level recognizer.

    Parameters
    ----------
    pcm_bytes : bytes
        16-bit little endian, mono, 16 kHz PCM audio.  A dangling trailing
        byte that does not form a complete sample is ignored.

    Returns
    -------
    str
        Transcribed text.
    """
    ts0 = time.perf_counter()

    # np.frombuffer raises ValueError when the buffer size is not a multiple
    # of the sample width (e.g. a truncated upload), so trim any partial
    # final sample before decoding.
    usable = len(pcm_bytes) - (len(pcm_bytes) % EXPECTED_WIDTH_B)

    # "<i2" pins little-endian decoding independent of the host byte order.
    audio_i16 = np.frombuffer(memoryview(pcm_bytes)[:usable], dtype="<i2")
    # Scale the int16 range to float32 in [-1.0, 1.0).
    audio_f32 = audio_i16.astype(np.float32) / 32768.0

    stream = RECOGNIZER.create_stream()
    stream.accept_waveform(EXPECTED_RATE, audio_f32)
    # Follow the utterance with the silence buffer defined in TAIL_PADDING.
    stream.accept_waveform(EXPECTED_RATE, TAIL_PADDING)

    RECOGNIZER.decode_stream(stream)
    text = stream.result.text

    cost = time.perf_counter() - ts0
    log.info("ASR processed %d samples in %.2f s, result: %s", audio_f32.shape[0], cost, text)
    return text
| |
|
| |
|
| | |
| | |
| | |
class ASRHandler(BaseHTTPRequestHandler):
    """HTTP request handler that exposes the recognizer over two routes.

    ``GET /`` returns a small JSON status document with a usage hint.
    ``POST /asr`` expects a 16 kHz / 16-bit / mono WAV file as the request
    body and answers with JSON carrying the transcription under ``"result"``.
    Every response, including errors, is a JSON object with a ``"status"``
    key.
    """

    server_version = "SherpaASR/1.0"

    def log_message(self, fmt, *args):
        """Send http.server's per-request log lines to the module logger."""
        log.info("%s – " + fmt, self.address_string(), *args)

    def _json(self, code: int, payload: dict):
        """Write *payload* as a UTF-8 JSON response with HTTP status *code*."""
        body = json.dumps(payload, ensure_ascii=False).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _bad_request(self, message: str):
        """Reply with HTTP 400 and *message* in the standard error envelope."""
        self._json(400, {"status": "error", "message": message})

    @staticmethod
    def _parse_content_length(value) -> "int | None":
        """Return *value* as a non-negative int, or None when it is unusable.

        Negative values are rejected because ``rfile.read(-1)`` would read
        until the peer closes the connection.
        """
        try:
            length = int(value)
        except (TypeError, ValueError):
            return None
        return length if length >= 0 else None

    def do_GET(self):
        """Serve the status document on ``/``; any other path is a 404."""
        if self.path != "/":
            self._json(404, {"status": "error", "message": "Not Found"})
            return

        self._json(
            200,
            {
                "status": "ok",
                "model_loaded": True,
                "usage": "POST a 16 kHz / 16-bit / mono WAV file to /asr",
            },
        )

    def do_POST(self):
        """Transcribe the WAV file posted to ``/asr``.

        Responds with 404 for other paths, 400 for a missing or invalid
        Content-Length, an unreadable body, a malformed WAV or an unsupported
        audio format, 500 when recognition itself fails, and 200 with the
        text on success.
        """
        if self.path != "/asr":
            self._json(404, {"status": "error", "message": "Not Found"})
            return

        raw_length = self.headers.get("Content-Length")
        if raw_length is None:
            return self._bad_request("Missing Content-Length header")

        length = self._parse_content_length(raw_length)
        if length is None:
            return self._bad_request("Invalid Content-Length header")

        try:
            body = self.rfile.read(length)
        except OSError as exc:
            return self._bad_request(f"Failed to read body: {exc}")

        # Validate the container and extract the raw PCM frames.
        try:
            with wave.open(io.BytesIO(body), "rb") as wf:
                if (
                    wf.getnchannels() != EXPECTED_CHANNELS
                    or wf.getsampwidth() != EXPECTED_WIDTH_B
                    or wf.getframerate() != EXPECTED_RATE
                ):
                    return self._bad_request(
                        "Audio must be 16 kHz, 16-bit, mono PCM WAV"
                    )
                pcm_bytes = wf.readframes(wf.getnframes())
        except (wave.Error, EOFError) as exc:
            # The wave module signals an empty or truncated file with
            # EOFError (which carries no message) rather than wave.Error.
            reason = str(exc) or "unexpected end of data"
            return self._bad_request(f"Invalid WAV file: {reason}")

        # Run recognition; any failure is logged with its traceback and
        # reported to the client as a 500.
        try:
            text = transcribe_pcm16le(pcm_bytes)
        except Exception as exc:
            log.exception("ASR failed")
            return self._json(500, {"status": "error", "message": str(exc)})

        self._json(200, {"status": "success", "result": text})
| |
|
| | |
| | |
| | |
class DualStackServer(ThreadingHTTPServer):
    """Threaded HTTP server that can listen on IPv4, IPv6, or both.

    Parameters
    ----------
    server_address : tuple of (str, int)
        Host and port to bind.
    RequestHandlerClass
        Handler class instantiated for each request.
    ipv4, ipv6 : bool
        Which protocols to serve.  With both enabled an IPv6 socket is opened
        with ``IPV6_V6ONLY`` cleared so it also accepts IPv4 clients; with
        only ``ipv6`` the option is set so IPv4 clients are refused.
    """

    def __init__(self, server_address: Tuple[str, int], RequestHandlerClass, ipv4: bool = True, ipv6: bool = True):
        self.ipv4 = ipv4
        self.ipv6 = ipv6
        # The address family has to be chosen before the base constructor
        # runs, because that is where the listening socket is created.
        self.address_family = self._select_family(server_address[0], ipv4, ipv6)
        super().__init__(server_address, RequestHandlerClass)

    @staticmethod
    def _select_family(host, ipv4: bool, ipv6: bool):
        """Pick the socket family for *host* given the enabled protocols.

        An IPv4 literal such as ``0.0.0.0`` cannot be bound on an AF_INET6
        socket, so when IPv4 is enabled such a host gets a plain IPv4 socket
        even if IPv6 was requested as well.
        """
        if not ipv6:
            return socket.AF_INET
        if ipv4:
            try:
                socket.inet_pton(socket.AF_INET, host)
            except (OSError, ValueError, TypeError):
                pass  # not an IPv4 literal: keep the IPv6 socket
            else:
                return socket.AF_INET
        return socket.AF_INET6

    def server_bind(self):
        """Set ``IPV6_V6ONLY`` to match the enabled protocols, then bind."""
        if self.address_family == socket.AF_INET6:
            # 0 = dual-stack (IPv4 accepted too), 1 = IPv6 clients only.
            v6only = 0 if self.ipv4 else 1
            try:
                self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, v6only)
            except (OSError, AttributeError) as e:
                # AttributeError: the constant is missing on this platform.
                log.warning("Could not set IPV6_V6ONLY=%d: %s", v6only, e)

        super().server_bind()
| |
|
| | |
| | |
| | |
def get_network_interfaces(ipv6: bool = False) -> List[str]:
    """Return this host's addresses of one family, for display purposes.

    Two sources are combined: the addresses the host name resolves to, and
    the local address of a UDP socket connected to a public DNS server.
    Discovery is best-effort: a failing source is logged at DEBUG level and
    skipped, so the function never raises because of it.

    Parameters
    ----------
    ipv6 : bool
        ``True`` to collect IPv6 addresses, ``False`` (default) for IPv4.

    Returns
    -------
    list of str
        Sorted, de-duplicated addresses.  IPv4 addresses starting with
        ``127.`` and IPv6 addresses that start with ``fe80`` or equal ``::1``
        are left out.
    """
    family = socket.AF_INET6 if ipv6 else socket.AF_INET

    def reportable(addr: str) -> bool:
        """Tell whether *addr* passes the loopback / fe80 filter."""
        if ipv6:
            return not (addr.startswith("fe80") or addr == "::1")
        return not addr.startswith("127.")

    addresses = set()

    # Source 1: whatever the local host name resolves to.
    try:
        infos = socket.getaddrinfo(socket.gethostname(), None)
    except Exception as exc:
        # The result is only used for informational output, so a failure
        # here must not propagate to the caller.
        log.debug("Host name lookup failed: %s", exc)
    else:
        addresses.update(
            info[4][0]
            for info in infos
            if info[0] == family and reportable(info[4][0])
        )

    # Source 2: the local address chosen for a datagram socket whose peer is
    # set to a public DNS server.
    probe_target = ("2001:4860:4860::8888", 80) if ipv6 else ("8.8.8.8", 80)
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as s:
            s.connect(probe_target)
            addr = s.getsockname()[0]
    except Exception as exc:
        log.debug("Outbound address probe failed: %s", exc)
    else:
        if reportable(addr):
            addresses.add(addr)

    return sorted(addresses)
| |
|
| | |
| | |
| | |
def main() -> None:
    """Parse the command line, start the HTTP server and serve until Ctrl-C.

    The bind address comes from ``--host`` when given, otherwise from
    ``--scope`` and ``--ip-version``.  Exits with status 1 when the server
    socket cannot be created or bound.
    """
    parser = argparse.ArgumentParser(description="Sherpa-ONNX ASR HTTP server")
    parser.add_argument("--host", default=None,
                        help="Specific host to bind to (e.g., '0.0.0.0', '::', 'localhost')")
    parser.add_argument("--port", default=7860, type=int, help="Port to listen on")
    parser.add_argument("--ip-version", choices=["4", "6", "dual"], default="dual",
                        help="IP version to use: 4=IPv4 only, 6=IPv6 only, dual=both")
    parser.add_argument("--scope", choices=["local", "all"], default="all",
                        help="Binding scope: local=loopback only, all=all interfaces")
    args = parser.parse_args()

    ipv4 = args.ip_version in ("4", "dual")
    ipv6 = args.ip_version in ("6", "dual")

    if args.host:
        bind_host = args.host

        try:
            socket.inet_pton(socket.AF_INET, bind_host)
        except (OSError, ValueError):
            host_is_ipv4_literal = False
        else:
            host_is_ipv4_literal = True

        if ":" in bind_host and not ipv6:
            log.warning("IPv6 host specified but IPv6 is disabled")
        elif host_is_ipv4_literal and not ipv4:
            log.warning("IPv4 host specified but IPv4 is disabled")
        elif host_is_ipv4_literal and ipv6:
            # An IPv4 literal cannot be bound on an IPv6 socket, so dual mode
            # with such a host can only ever serve IPv4.
            log.info("IPv4 host specified; serving IPv4 only")
            ipv6 = False

        bind_all = bind_host in ("0.0.0.0", "::")
    else:
        if args.scope == "local":
            bind_host = "::1" if ipv6 else "127.0.0.1"
            bind_all = False
        else:
            bind_host = "::" if ipv6 else "0.0.0.0"
            bind_all = True

    try:
        server = DualStackServer(
            server_address=(bind_host, args.port),
            RequestHandlerClass=ASRHandler,
            ipv4=ipv4,
            ipv6=ipv6,
        )
    except (OSError, OverflowError) as e:
        # OverflowError is what bind() raises for a port outside 0-65535.
        log.critical("Failed to start server: %s", e)
        sys.exit(1)

    log.info("Server started on port %d", args.port)
    log.info("Protocols: IPv4=%s, IPv6=%s", ipv4, ipv6)

    if bind_all:
        # Bound to a wildcard address: list the concrete addresses clients
        # can use instead of the wildcard itself.
        if ipv4:
            for addr in get_network_interfaces(ipv6=False):
                log.info("IPv4: http://%s:%d/", addr, args.port)
        if ipv6:
            for addr in get_network_interfaces(ipv6=True):
                log.info("IPv6: http://[%s]:%d/", addr, args.port)
    else:
        if ":" in bind_host:
            # IPv6 literals need brackets inside a URL.
            log.info("Listening on: http://[%s]:%d/", bind_host, args.port)
        else:
            log.info("Listening on: http://%s:%d/", bind_host, args.port)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        log.info("Shutting down...")
    finally:
        server.server_close()
| |
|
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
| |
|