File size: 3,500 Bytes
ca02ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import json
import os
import tempfile
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs

from whisper import Whisper
import cgi


# 模型缓存:避免每次请求都重新加载
_model_cache = {}

def get_model(model_type, model_path, language, task):
    key = (model_type, model_path, language, task)
    if key not in _model_cache:
        print(f"Loading model: type={model_type}, path={model_path}, lang={language}, task={task}")
        _model_cache[key] = Whisper(model_type, model_path, language, task)
    return _model_cache[key]


class WhisperHandler(BaseHTTPRequestHandler):

    def _send_json(self, obj, status=200):
        data = json.dumps(obj, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self):
        if self.path == "/health":
            self._send_json({"status": "ok"})
        else:
            self._send_json({"error": "not found"}, 404)

    def do_POST(self):
        if self.path != "/asr":
            self._send_json({"error": "not found"}, 404)
            return

        # 解析 multipart/form-data
        content_type = self.headers.get('Content-Type')
        if not content_type:
            self._send_json({"error": "Missing Content-Type"}, 400)
            return

        ctype, pdict = cgi.parse_header(content_type)

        if ctype != 'multipart/form-data':
            self._send_json({"error": "Only multipart/form-data is supported"}, 400)
            return

        pdict['boundary'] = bytes(pdict['boundary'], "utf-8")
        pdict['CONTENT-LENGTH'] = int(self.headers['Content-Length'])

        form = cgi.parse_multipart(self.rfile, pdict)

        # 必须包含 wav 文件
        if "wav" not in form:
            self._send_json({"error": "Field 'wav' is required"}, 400)
            return

        # 获取参数(如果缺省则使用默认值)
        model_type = form.get("model_type", ["tiny"])[0]
        model_path = form.get("model_path", ["../models/models-ax650"])[0]
        language = form.get("language", ["zh"])[0]
        task = form.get("task", ["transcribe"])[0]

        if task not in ("transcribe", "translate"):
            self._send_json({"error": "task must be 'transcribe' or 'translate'"}, 400)
            return

        wav_bytes = form["wav"][0]

        # 写入临时文件
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(wav_bytes)
            wav_path = tmp.name

        # 加载模型并运行
        try:
            model = get_model(model_type, model_path, language, task)
            result_text = model.run(wav_path)
        except Exception as e:
            self._send_json({"error": str(e)}, 500)
            return
        finally:
            if os.path.exists(wav_path):
                os.remove(wav_path)

        self._send_json({"text": result_text})


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Whisper Server")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    args = parser.parse_args()
    port = args.port
    server = HTTPServer(("0.0.0.0", port), WhisperHandler)
    print(f"Server started at http://0.0.0.0:{port}")
    server.serve_forever()