|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import tempfile |
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer |
|
|
from urllib.parse import parse_qs |
|
|
|
|
|
from whisper import Whisper |
|
|
import cgi |
|
|
|
|
|
|
|
|
|
|
|
# Process-wide cache of loaded models, keyed by
# (model_type, model_path, language, task). Entries are never evicted.
_model_cache = {}


def get_model(model_type, model_path, language, task):
    """Return the Whisper instance for this configuration, loading it on first use.

    The four arguments together form the cache key, so a model is constructed
    at most once per distinct combination for the lifetime of the process.

    Args:
        model_type: Passed through to ``Whisper`` as its first argument.
        model_path: Passed through to ``Whisper`` as its second argument.
        language: Passed through to ``Whisper`` as its third argument.
        task: Passed through to ``Whisper`` as its fourth argument.

    Returns:
        The cached (or freshly constructed) ``Whisper`` object.
    """
    cache_key = (model_type, model_path, language, task)

    # Fast path: this exact configuration has already been loaded.
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    print(f"Loading model: type={model_type}, path={model_path}, lang={language}, task={task}")
    loaded = Whisper(model_type, model_path, language, task)
    _model_cache[cache_key] = loaded
    return loaded
|
|
|
|
|
|
|
|
class WhisperHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing a health check and a speech-recognition endpoint.

    Routes:
        GET  /health -> ``{"status": "ok"}``
        POST /asr    -> ``{"text": ...}`` for a ``multipart/form-data`` upload
                        whose ``wav`` field carries the audio file.

    Every response, including errors, is a JSON object; errors use the shape
    ``{"error": "<message>"}`` with a 4xx/5xx status.
    """

    def _send_json(self, obj, status=200):
        """Serialize *obj* as UTF-8 JSON and send it as the complete response.

        Args:
            obj: JSON-serializable payload.
            status: HTTP status code for the response line.
        """
        data = json.dumps(obj, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        # Content-Length must count encoded bytes; non-ASCII text is longer
        # in UTF-8 than its character count.
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self):
        """Answer ``/health`` with an OK status; any other path is a 404."""
        if self.path == "/health":
            self._send_json({"status": "ok"})
        else:
            self._send_json({"error": "not found"}, 404)

    def _read_multipart_form(self):
        """Validate the request headers and parse the multipart body.

        Returns:
            The dict produced by ``cgi.parse_multipart`` (field name -> list
            of values), or ``None`` when the request was rejected. In the
            ``None`` case a 400 response has already been sent.
        """
        if not self.headers.get("Content-Type"):
            self._send_json({"error": "Missing Content-Type"}, 400)
            return None

        # Media types are case-insensitive; get_content_type() normalizes
        # to lower case and strips parameters.
        if self.headers.get_content_type() != "multipart/form-data":
            self._send_json({"error": "Only multipart/form-data is supported"}, 400)
            return None

        # Without a boundary the body cannot be split into parts.
        boundary = self.headers.get_boundary()
        if not boundary:
            self._send_json({"error": "Missing multipart boundary"}, 400)
            return None

        try:
            content_length = int(self.headers.get("Content-Length", ""))
        except (TypeError, ValueError):
            content_length = -1
        if content_length < 0:
            self._send_json({"error": "Missing or invalid Content-Length"}, 400)
            return None

        pdict = {
            "boundary": boundary.encode("utf-8"),
            "CONTENT-LENGTH": content_length,
        }
        try:
            return cgi.parse_multipart(self.rfile, pdict)
        except (TypeError, ValueError) as e:
            self._send_json({"error": f"Malformed multipart body: {e}"}, 400)
            return None

    def do_POST(self):
        """Handle ``POST /asr``: transcribe an uploaded audio file.

        Form fields:
            wav (required): the audio file, sent as a file upload.
            model_type: defaults to ``"tiny"``.
            model_path: defaults to ``"../models/models-ax650"``.
            language: defaults to ``"zh"``.
            task: ``"transcribe"`` (default) or ``"translate"``.

        Responses:
            200 ``{"text": ...}`` on success.
            400 for a malformed request or invalid field.
            404 for any path other than ``/asr``.
            500 ``{"error": ...}`` when loading or running the model fails.
        """
        if self.path != "/asr":
            self._send_json({"error": "not found"}, 404)
            return

        form = self._read_multipart_form()
        if form is None:
            # An error response has already been sent.
            return

        if "wav" not in form:
            self._send_json({"error": "Field 'wav' is required"}, 400)
            return

        model_type = form.get("model_type", ["tiny"])[0]
        # NOTE(security): model_path is supplied by the client and forwarded
        # to the model loader unchanged. Restrict it (e.g. to an allow-list)
        # before exposing this server to untrusted callers.
        model_path = form.get("model_path", ["../models/models-ax650"])[0]
        language = form.get("language", ["zh"])[0]
        task = form.get("task", ["transcribe"])[0]

        if task not in ("transcribe", "translate"):
            self._send_json({"error": "task must be 'transcribe' or 'translate'"}, 400)
            return

        wav_bytes = form["wav"][0]
        # A part sent without a filename is decoded to str by the multipart
        # parser; only real file uploads arrive as bytes.
        if not isinstance(wav_bytes, (bytes, bytearray)):
            self._send_json({"error": "Field 'wav' must be a file upload"}, 400)
            return

        wav_path = None
        try:
            # The model is given a filesystem path, so spool the upload to a
            # named temp file. delete=False keeps it alive after the handle
            # is closed; the finally clause removes it on every exit path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                wav_path = tmp.name
                tmp.write(wav_bytes)

            model = get_model(model_type, model_path, language, task)
            result_text = model.run(wav_path)
        except Exception as e:
            self._send_json({"error": str(e)}, 500)
            return
        finally:
            if wav_path is not None and os.path.exists(wav_path):
                os.remove(wav_path)

        self._send_json({"text": result_text})
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command line: the listening port is the only configurable option.
    cli = argparse.ArgumentParser(description="Whisper Server")
    cli.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    port = cli.parse_args().port

    # Listen on all interfaces and block, handling requests until the
    # process is stopped.
    bind_address = ("0.0.0.0", port)
    httpd = HTTPServer(bind_address, WhisperHandler)
    print(f"Server started at http://0.0.0.0:{port}")
    httpd.serve_forever()