andito HF Staff commited on
Commit
172bc37
·
0 Parent(s):

Initial commit

Browse files
Files changed (5) hide show
  1. .dockerignore +9 -0
  2. Dockerfile +35 -0
  3. app/main.py +270 -0
  4. requirements.txt +3 -0
  5. test_ws_file.py +85 -0
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ __pycache__
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .env
7
+ .venv
8
+ dist
9
+ build
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# GPU runtime image for the speech-to-speech websocket endpoint.
# CUDA-enabled PyTorch base so the pipeline can run on GPU out of the box.
FROM pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime

# PORT is the public uvicorn port; INTERNAL_WS_* is where the s2s subprocess
# listens (proxied by app/main.py); S2S_REPO_DIR is the pipeline checkout.
ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PORT=7860 \
    INTERNAL_WS_HOST=127.0.0.1 \
    INTERNAL_WS_PORT=9000 \
    S2S_REPO_DIR=/opt/speech-to-speech

# git: clone the pipeline repo; ffmpeg/libsndfile1: audio decode/encode;
# curl/ca-certificates: model downloads over HTTPS.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    ffmpeg \
    libsndfile1 \
    curl \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install the thin FastAPI proxy's deps first so this layer caches well.
COPY requirements.txt .
RUN pip install --upgrade pip setuptools wheel && \
    pip install -r requirements.txt && \
    pip install uv

# Clone speech-to-speech and install its dependencies the way the repo expects
# (uv sync against the repo's own lockfile, skipping dev-only packages).
RUN git clone https://github.com/huggingface/speech-to-speech.git ${S2S_REPO_DIR} && \
    cd ${S2S_REPO_DIR} && \
    uv sync --no-dev

COPY app /app/app

EXPOSE 7860

# Serve the proxy app; it spawns the s2s pipeline subprocess on startup.
CMD ["uv", "run", "--directory", "/app", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/main.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import logging
import os
import shlex
import signal
import subprocess
import sys
from contextlib import asynccontextmanager
from typing import Optional

import websockets
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from websockets.exceptions import ConnectionClosed
14
+
15
# Root logging configuration; LOG_LEVEL env var controls verbosity (default INFO).
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("s2s-endpoint")

# Public bind address/port for the FastAPI app (uvicorn also reads PORT in the Dockerfile CMD).
HOST = "0.0.0.0"
PORT = int(os.getenv("PORT", "7860"))

# Internal websocket server opened by the speech-to-speech subprocess;
# the /ws route proxies client traffic to this address.
INTERNAL_WS_HOST = os.getenv("INTERNAL_WS_HOST", "127.0.0.1")
INTERNAL_WS_PORT = int(os.getenv("INTERNAL_WS_PORT", "9000"))
INTERNAL_WS_URL = f"ws://{INTERNAL_WS_HOST}:{INTERNAL_WS_PORT}"

# Checkout location of the huggingface/speech-to-speech repo (cloned in the Dockerfile).
S2S_REPO_DIR = os.getenv("S2S_REPO_DIR", "/opt/speech-to-speech")

# Baseline model choices. Keep them simple for a first deployment.
# You can override any of these in the endpoint env vars.
LM_MODEL_NAME = os.getenv("LM_MODEL_NAME", "Qwen/Qwen2.5-3B-Instruct")
TTS = os.getenv("TTS", "pocket")
POCKET_TTS_VOICE = os.getenv("POCKET_TTS_VOICE", "jean")
DEVICE = os.getenv("DEVICE", "cuda")
LANGUAGE = os.getenv("LANGUAGE", "en")
CHAT_SIZE = os.getenv("CHAT_SIZE", "10")  # passed through as a CLI string
STT_COMPILE_MODE = os.getenv("STT_COMPILE_MODE", "reduce-overhead")

# Optional extra CLI args for speech-to-speech, space-separated.
# Example:
# EXTRA_S2S_ARGS="--stt_model_name large-v3 --temperature 0.7"
EXTRA_S2S_ARGS = os.getenv("EXTRA_S2S_ARGS", "").strip()

# If you later want to use an OpenAI-compatible API-backed LLM instead of a local LM,
# set USE_OPENAI_API_LLM=1 and configure the related env vars.
USE_OPENAI_API_LLM = os.getenv("USE_OPENAI_API_LLM", "0") == "1"
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
OPENAI_API_MODEL = os.getenv("OPENAI_API_MODEL", "")

# Handle to the running pipeline subprocess; None until start_pipeline() runs.
pipeline_process: Optional[subprocess.Popen] = None
54
def build_s2s_command() -> list[str]:
    """Assemble the argv list that launches the speech-to-speech pipeline.

    Reads the module-level configuration (websocket host/port, device,
    language, TTS choice, LLM selection, extra CLI args) and returns the
    command to hand to ``subprocess.Popen``.

    Returns:
        The full ``uv run ... python s2s_pipeline.py ...`` command as a list.
    """
    cmd = [
        "uv",
        "run",
        "--directory",
        S2S_REPO_DIR,
        "python",
        "s2s_pipeline.py",
        "--mode", "websocket",
        "--ws_host", INTERNAL_WS_HOST,
        "--ws_port", str(INTERNAL_WS_PORT),
        "--device", DEVICE,
        "--language", LANGUAGE,
        "--chat_size", CHAT_SIZE,
        "--tts", TTS,
    ]

    if STT_COMPILE_MODE:
        cmd += ["--stt_compile_mode", STT_COMPILE_MODE]

    # Voice selection only applies to the pocket TTS backend.
    if TTS == "pocket" and POCKET_TTS_VOICE:
        cmd += ["--pocket_tts_voice", POCKET_TTS_VOICE]

    if USE_OPENAI_API_LLM:
        cmd += [
            "--llm", "open-api",
            "--open_api_base_url", OPENAI_API_BASE,
            "--open_api_key", OPENAI_API_KEY,
            "--open_api_model_name", OPENAI_API_MODEL,
        ]
    else:
        cmd += ["--lm_model_name", LM_MODEL_NAME]

    if EXTRA_S2S_ARGS:
        # shlex.split honors shell-style quoting, so a value containing
        # spaces (e.g. --prompt "hello world") stays one argument;
        # plain str.split() would break it apart.
        cmd += shlex.split(EXTRA_S2S_ARGS)

    return cmd
91
+
92
+
93
async def wait_for_internal_ws(timeout_s: float = 900.0) -> None:
    """
    Wait until the internal speech-to-speech websocket server accepts connections.
    First model load can take a while on endpoint startup.

    Args:
        timeout_s: Maximum seconds to keep retrying before giving up.

    Raises:
        RuntimeError: If the pipeline subprocess exits before the server comes
            up, or if the server is still unreachable after ``timeout_s``.
    """
    # get_running_loop() is the non-deprecated way to reach the loop clock
    # from inside a coroutine (get_event_loop() is deprecated here since 3.10).
    loop = asyncio.get_running_loop()
    start = loop.time()
    last_error: Optional[BaseException] = None

    while True:
        # Fail fast if the subprocess crashed rather than waiting out the timeout.
        if pipeline_process is not None and pipeline_process.poll() is not None:
            raise RuntimeError(
                f"speech-to-speech process exited early with code {pipeline_process.returncode}"
            )

        try:
            # A successful connect (immediately closed) proves readiness.
            async with websockets.connect(
                INTERNAL_WS_URL,
                open_timeout=5,
                ping_interval=None,
                max_size=None,
            ):
                logger.info("Internal speech-to-speech websocket is ready at %s", INTERNAL_WS_URL)
                return
        except Exception as exc:
            last_error = exc

        if loop.time() - start > timeout_s:
            raise RuntimeError(
                f"Timed out waiting for internal websocket server at {INTERNAL_WS_URL}. "
                f"Last error: {last_error}"
            )

        await asyncio.sleep(2.0)
126
+
127
+
128
def start_pipeline() -> None:
    """Launch the speech-to-speech pipeline as a child process (idempotent).

    Does nothing if a previous launch is still alive. The child inherits the
    current environment and writes directly to this process's stdout/stderr.
    """
    global pipeline_process

    if pipeline_process is not None and pipeline_process.poll() is None:
        logger.info("speech-to-speech process already running")
        return

    cmd = build_s2s_command()
    logger.info("Starting speech-to-speech subprocess:\n%s", " ".join(cmd))

    env = os.environ.copy()

    pipeline_process = subprocess.Popen(
        cmd,
        cwd=S2S_REPO_DIR,
        env=env,
        stdout=sys.stdout,
        stderr=sys.stderr,
        # start_new_session=True runs setsid() in the child on POSIX so the
        # whole process group can be signalled in stop_pipeline(). It is the
        # documented, thread-safe replacement for preexec_fn=os.setsid
        # (preexec_fn is unsafe in multithreaded programs). No-op on Windows.
        start_new_session=(os.name != "nt"),
    )
148
+
149
+
150
def stop_pipeline() -> None:
    """Terminate the pipeline subprocess, escalating SIGTERM -> SIGKILL.

    Signals the whole process group on POSIX so any children the pipeline
    spawned are stopped too. Always clears ``pipeline_process`` afterwards.
    """
    global pipeline_process

    if pipeline_process is None:
        return

    if pipeline_process.poll() is not None:
        logger.info("speech-to-speech process already stopped")
        return

    logger.info("Stopping speech-to-speech subprocess")

    try:
        if os.name != "nt":
            os.killpg(os.getpgid(pipeline_process.pid), signal.SIGTERM)
        else:
            pipeline_process.terminate()
        pipeline_process.wait(timeout=20)
    except Exception:
        logger.exception("Graceful shutdown failed, killing subprocess")
        try:
            if os.name != "nt":
                os.killpg(os.getpgid(pipeline_process.pid), signal.SIGKILL)
            else:
                pipeline_process.kill()
            # Reap the killed child so it does not linger as a zombie;
            # the original never wait()ed after SIGKILL.
            pipeline_process.wait(timeout=5)
        except Exception:
            logger.exception("Failed to kill subprocess")
    finally:
        pipeline_process = None
179
+
180
+
181
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the speech-to-speech subprocess for the application's lifetime.

    The pipeline is launched before the app starts serving and is shut down
    when the server exits — the ``finally`` guarantees cleanup even if the
    app body raises.
    """
    start_pipeline()
    try:
        yield
    finally:
        stop_pipeline()
188
+
189
+
190
+ app = FastAPI(lifespan=lifespan)
191
+
192
+
193
+ @app.get("/")
194
+ async def root():
195
+ return {
196
+ "message": "s2s endpoint is up",
197
+ "health": "/health",
198
+ "websocket": "/ws",
199
+ }
200
+
201
+
202
+ @app.get("/health")
203
+ async def health():
204
+ if pipeline_process is None:
205
+ raise HTTPException(status_code=503, detail="speech-to-speech process not started")
206
+
207
+ if pipeline_process.poll() is not None:
208
+ raise HTTPException(
209
+ status_code=503,
210
+ detail=f"speech-to-speech process exited with code {pipeline_process.returncode}",
211
+ )
212
+
213
+ try:
214
+ await asyncio.wait_for(wait_for_internal_ws(timeout_s=5), timeout=6)
215
+ except Exception as exc:
216
+ raise HTTPException(status_code=503, detail=f"internal websocket not ready: {exc}") from exc
217
+
218
+ return JSONResponse({"status": "ok", "internal_ws": INTERNAL_WS_URL})
219
+
220
+
221
+ @app.websocket("/ws")
222
+ async def websocket_proxy(client_ws: WebSocket):
223
+ await client_ws.accept()
224
+ logger.info("Client websocket connected")
225
+
226
+ try:
227
+ async with websockets.connect(
228
+ INTERNAL_WS_URL,
229
+ open_timeout=30,
230
+ ping_interval=20,
231
+ ping_timeout=20,
232
+ max_size=None,
233
+ ) as upstream_ws:
234
+
235
+ async def client_to_upstream():
236
+ while True:
237
+ message = await client_ws.receive()
238
+
239
+ if message["type"] == "websocket.disconnect":
240
+ raise WebSocketDisconnect()
241
+
242
+ if "bytes" in message and message["bytes"] is not None:
243
+ await upstream_ws.send(message["bytes"])
244
+ elif "text" in message and message["text"] is not None:
245
+ await upstream_ws.send(message["text"])
246
+
247
+ async def upstream_to_client():
248
+ while True:
249
+ msg = await upstream_ws.recv()
250
+ if isinstance(msg, bytes):
251
+ await client_ws.send_bytes(msg)
252
+ else:
253
+ await client_ws.send_text(msg)
254
+
255
+ await asyncio.gather(client_to_upstream(), upstream_to_client())
256
+
257
+ except WebSocketDisconnect:
258
+ logger.info("Client websocket disconnected")
259
+ except ConnectionClosed:
260
+ logger.info("Upstream websocket disconnected")
261
+ try:
262
+ await client_ws.close()
263
+ except Exception:
264
+ pass
265
+ except Exception:
266
+ logger.exception("Websocket proxy failed")
267
+ try:
268
+ await client_ws.close(code=1011, reason="Proxy failure")
269
+ except Exception:
270
+ pass
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fastapi==0.116.1
2
+ uvicorn[standard]==0.35.0
3
+ websockets==15.0.1
test_ws_file.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import sys
3
+ import wave
4
+ import websockets
5
+
6
+
7
CHUNK_SAMPLES = 512  # matches the old endpoint handler chunking pattern nicely
SAMPLE_RATE = 16000
SAMPLE_WIDTH = 2
CHANNELS = 1


def read_wav_pcm16_mono(path: str) -> bytes:
    """Load a WAV file and return its raw PCM payload.

    The file must be mono, 16 kHz, 16-bit PCM; anything else raises
    ValueError so the caller never streams audio in the wrong format.
    """
    with wave.open(path, "rb") as wav_file:
        sr = wav_file.getframerate()
        sw = wav_file.getsampwidth()
        ch = wav_file.getnchannels()

        # Compare the whole format tuple at once instead of three separate checks.
        if (sr, sw, ch) != (SAMPLE_RATE, SAMPLE_WIDTH, CHANNELS):
            raise ValueError(
                f"Expected WAV mono/16kHz/16-bit PCM, got sr={sr}, sw={sw}, ch={ch}"
            )

        return wav_file.readframes(wav_file.getnframes())
25
+
26
+
27
def write_wav_pcm16_mono(path: str, pcm_bytes: bytes) -> None:
    """Save raw PCM bytes as a mono 16 kHz 16-bit WAV file at *path*."""
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(CHANNELS)
        wav_file.setsampwidth(SAMPLE_WIDTH)
        wav_file.setframerate(SAMPLE_RATE)
        wav_file.writeframes(pcm_bytes)
33
+
34
+
35
async def main():
    """Stream a 16 kHz mono WAV over the websocket and save the audio reply.

    CLI: ws_url, input WAV path, and an optional HF token for auth. Chunks
    are paced at real-time rate; everything received back within the idle
    timeout is written to ``response.wav``.
    """
    if len(sys.argv) < 3:
        print("Usage:")
        print("  python test_ws_file.py <ws_url> <input.wav> [hf_token]")
        print("Example:")
        print("  python test_ws_file.py ws://localhost:7860/ws input.wav")
        sys.exit(1)

    ws_url, input_wav = sys.argv[1], sys.argv[2]
    hf_token = sys.argv[3] if len(sys.argv) > 3 else None

    # Attach an Authorization header only when a token was supplied.
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}

    audio = read_wav_pcm16_mono(input_wav)
    bytes_per_chunk = CHUNK_SAMPLES * SAMPLE_WIDTH

    received = bytearray()

    async with websockets.connect(
        ws_url,
        additional_headers=headers if headers else None,
        max_size=None,
        ping_interval=20,
        ping_timeout=20,
    ) as ws:
        # sender: pace chunks at playback speed so the server sees
        # microphone-like timing.
        for offset in range(0, len(audio), bytes_per_chunk):
            await ws.send(audio[offset : offset + bytes_per_chunk])
            await asyncio.sleep(CHUNK_SAMPLES / SAMPLE_RATE)

        # Give the server some time to answer
        # For a real app you'd use a smarter turn-ending signal or UI behavior.
        try:
            while True:
                msg = await asyncio.wait_for(ws.recv(), timeout=8.0)
                if isinstance(msg, bytes):
                    received.extend(msg)
                else:
                    print("TEXT EVENT:", msg)
        except asyncio.TimeoutError:
            pass

    write_wav_pcm16_mono("response.wav", bytes(received))
    print("Wrote response.wav")
82
+
83
+
84
+ if __name__ == "__main__":
85
+ asyncio.run(main())