Commit ·
172bc37
0
Parent(s):
Initial commit
Browse files- .dockerignore +9 -0
- Dockerfile +35 -0
- app/main.py +270 -0
- requirements.txt +3 -0
- test_ws_file.py +85 -0
.dockerignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
__pycache__
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.env
|
| 7 |
+
.venv
|
| 8 |
+
dist
|
| 9 |
+
build
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image for the speech-to-speech websocket endpoint: a FastAPI proxy on PORT
# that launches the upstream speech-to-speech pipeline as a subprocess.
FROM pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PORT=7860 \
    INTERNAL_WS_HOST=127.0.0.1 \
    INTERNAL_WS_PORT=9000 \
    S2S_REPO_DIR=/opt/speech-to-speech

RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    ffmpeg \
    libsndfile1 \
    curl \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt .
RUN pip install --upgrade pip setuptools wheel && \
    pip install -r requirements.txt && \
    pip install uv

# Optional commit/tag/branch of speech-to-speech to build against.
# Empty (the default) keeps the previous behaviour: tip of the default branch.
# Pin it for reproducible builds: --build-arg S2S_REF=<commit-sha>
ARG S2S_REF=""

# Clone speech-to-speech and install its dependencies the way the repo expects.
# A shallow clone is enough: only the working tree is needed, not the history.
RUN git clone --depth 1 https://github.com/huggingface/speech-to-speech.git ${S2S_REPO_DIR} && \
    cd ${S2S_REPO_DIR} && \
    if [ -n "${S2S_REF}" ]; then \
        git fetch --depth 1 origin "${S2S_REF}" && git checkout FETCH_HEAD; \
    fi && \
    uv sync --no-dev

COPY app /app/app

EXPOSE 7860

CMD ["uv", "run", "--directory", "/app", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
app/main.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import signal
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
| 11 |
+
from fastapi.responses import JSONResponse
|
| 12 |
+
import websockets
|
| 13 |
+
from websockets.exceptions import ConnectionClosed
|
| 14 |
+
|
| 15 |
+
# --- Logging -----------------------------------------------------------------
# Level comes from LOG_LEVEL (case-insensitive), defaulting to INFO.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=os.environ.get("LOG_LEVEL", "INFO").upper(),
)
logger = logging.getLogger("s2s-endpoint")

# --- Public HTTP listener ----------------------------------------------------
HOST = "0.0.0.0"
PORT = int(os.environ.get("PORT", "7860"))

# --- Internal websocket served by the speech-to-speech subprocess -------------
INTERNAL_WS_HOST = os.environ.get("INTERNAL_WS_HOST", "127.0.0.1")
INTERNAL_WS_PORT = int(os.environ.get("INTERNAL_WS_PORT", "9000"))
INTERNAL_WS_URL = f"ws://{INTERNAL_WS_HOST}:{INTERNAL_WS_PORT}"

# Checkout of the speech-to-speech repository the subprocess is run from.
S2S_REPO_DIR = os.environ.get("S2S_REPO_DIR", "/opt/speech-to-speech")

# --- Baseline model choices ---------------------------------------------------
# Kept simple for a first deployment; every value can be overridden through
# the endpoint's environment variables.
LM_MODEL_NAME = os.environ.get("LM_MODEL_NAME", "Qwen/Qwen2.5-3B-Instruct")
TTS = os.environ.get("TTS", "pocket")
POCKET_TTS_VOICE = os.environ.get("POCKET_TTS_VOICE", "jean")
DEVICE = os.environ.get("DEVICE", "cuda")
LANGUAGE = os.environ.get("LANGUAGE", "en")
CHAT_SIZE = os.environ.get("CHAT_SIZE", "10")  # kept as str: passed straight to the CLI
STT_COMPILE_MODE = os.environ.get("STT_COMPILE_MODE", "reduce-overhead")

# Optional extra CLI arguments appended to the speech-to-speech command line.
# Example:
#   EXTRA_S2S_ARGS="--stt_model_name large-v3 --temperature 0.7"
EXTRA_S2S_ARGS = os.environ.get("EXTRA_S2S_ARGS", "").strip()

# --- Optional OpenAI-compatible LLM backend -----------------------------------
# Set USE_OPENAI_API_LLM=1 (exactly "1") and fill in the related variables to
# use an API-backed LLM instead of a local LM.
USE_OPENAI_API_LLM = os.environ.get("USE_OPENAI_API_LLM", "0") == "1"
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", "")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_MODEL = os.environ.get("OPENAI_API_MODEL", "")

# Handle of the running speech-to-speech subprocess; None until it is started.
pipeline_process: Optional[subprocess.Popen] = None
|
| 53 |
+
|
| 54 |
+
def build_s2s_command() -> list[str]:
    """Assemble the argv used to launch the speech-to-speech pipeline.

    The command runs ``s2s_pipeline.py`` through ``uv run`` inside
    ``S2S_REPO_DIR`` in websocket mode, bound to the internal host/port.
    Optional flags are appended from the module-level configuration.

    Returns:
        The argument list, suitable for ``subprocess.Popen`` without a shell.

    Raises:
        ValueError: if ``EXTRA_S2S_ARGS`` contains unbalanced quotes.
    """
    import shlex

    cmd = [
        "uv",
        "run",
        "--directory",
        S2S_REPO_DIR,
        "python",
        "s2s_pipeline.py",
        "--mode", "websocket",
        "--ws_host", INTERNAL_WS_HOST,
        "--ws_port", str(INTERNAL_WS_PORT),
        "--device", DEVICE,
        "--language", LANGUAGE,
        "--chat_size", CHAT_SIZE,
        "--tts", TTS,
    ]

    if STT_COMPILE_MODE:
        cmd += ["--stt_compile_mode", STT_COMPILE_MODE]

    # The voice flag is only passed when the pocket TTS backend is selected.
    if TTS == "pocket" and POCKET_TTS_VOICE:
        cmd += ["--pocket_tts_voice", POCKET_TTS_VOICE]

    if USE_OPENAI_API_LLM:
        cmd += [
            "--llm", "open-api",
            "--open_api_base_url", OPENAI_API_BASE,
            "--open_api_key", OPENAI_API_KEY,
            "--open_api_model_name", OPENAI_API_MODEL,
        ]
    else:
        cmd += ["--lm_model_name", LM_MODEL_NAME]

    if EXTRA_S2S_ARGS:
        # shlex keeps quoted values together (e.g. --init_chat_prompt "You are
        # helpful"), which a plain str.split() would break into several args.
        cmd += shlex.split(EXTRA_S2S_ARGS)

    return cmd
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
async def wait_for_internal_ws(timeout_s: float = 900.0) -> None:
    """Block until the internal speech-to-speech websocket accepts a connection.

    A connection to ``INTERNAL_WS_URL`` is attempted every two seconds. The
    first model load can take a while on endpoint startup, hence the long
    default timeout.

    Args:
        timeout_s: Maximum time to keep retrying, in seconds.

    Raises:
        RuntimeError: if the pipeline subprocess has already exited, or if no
            connection succeeded within ``timeout_s`` (the last connection
            error is attached as the cause).
    """
    # get_running_loop() is the supported way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    start = loop.time()
    last_error = None

    while True:
        # No point retrying once the subprocess is gone.
        if pipeline_process is not None and pipeline_process.poll() is not None:
            raise RuntimeError(
                f"speech-to-speech process exited early with code {pipeline_process.returncode}"
            )

        try:
            async with websockets.connect(
                INTERNAL_WS_URL,
                open_timeout=5,
                ping_interval=None,
                max_size=None,
            ):
                logger.info("Internal speech-to-speech websocket is ready at %s", INTERNAL_WS_URL)
                return
        except Exception as exc:
            # Any failure just means "not ready yet"; remember it for the report.
            last_error = exc

        if loop.time() - start > timeout_s:
            raise RuntimeError(
                f"Timed out waiting for internal websocket server at {INTERNAL_WS_URL}. "
                f"Last error: {last_error}"
            ) from last_error

        await asyncio.sleep(2.0)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def start_pipeline() -> None:
    """Launch the speech-to-speech subprocess unless one is already running.

    The child inherits the current environment and this process's
    stdout/stderr, and is stored in the module-level ``pipeline_process``.
    On POSIX it is started in its own session so that ``stop_pipeline`` can
    signal the whole process group.
    """
    global pipeline_process

    if pipeline_process is not None and pipeline_process.poll() is None:
        logger.info("speech-to-speech process already running")
        return

    cmd = build_s2s_command()

    # Never write the API key to the logs: mask the value that follows the flag.
    loggable = list(cmd)
    for idx, arg in enumerate(cmd[:-1]):
        if arg == "--open_api_key":
            loggable[idx + 1] = "***"
    logger.info("Starting speech-to-speech subprocess:\n%s", " ".join(loggable))

    env = os.environ.copy()

    pipeline_process = subprocess.Popen(
        cmd,
        cwd=S2S_REPO_DIR,
        env=env,
        stdout=sys.stdout,
        stderr=sys.stderr,
        # Same effect as preexec_fn=os.setsid, without running Python code
        # between fork and exec (preexec_fn is unsafe when threads exist).
        start_new_session=os.name != "nt",
    )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def stop_pipeline() -> None:
    """Terminate the speech-to-speech subprocess, escalating to SIGKILL.

    Sends SIGTERM to the child's process group (``terminate()`` on Windows)
    and waits up to 20 seconds. If that fails or times out, the group is
    killed and the child is reaped. ``pipeline_process`` is reset to ``None``
    whenever a stop was attempted; it is left untouched if the process had
    already exited, so its exit code stays available.
    """
    global pipeline_process

    proc = pipeline_process
    if proc is None:
        return

    if proc.poll() is not None:
        logger.info("speech-to-speech process already stopped")
        return

    logger.info("Stopping speech-to-speech subprocess")

    try:
        if os.name != "nt":
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        else:
            proc.terminate()
        proc.wait(timeout=20)
    except Exception:
        logger.exception("Graceful shutdown failed, killing subprocess")
        try:
            if os.name != "nt":
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
            else:
                proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            proc.wait(timeout=5)
        except Exception:
            logger.exception("Failed to kill subprocess")
    finally:
        pipeline_process = None
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Tie the speech-to-speech subprocess to the application's lifetime.

    The subprocess is launched before the application starts serving and is
    stopped when the application shuts down, including on error. Startup does
    not wait for the internal websocket to become ready.
    """
    # Startup phase.
    start_pipeline()
    try:
        # Hand control to the application while it serves requests.
        yield
    finally:
        # Shutdown phase: always runs, even if serving ended with an exception.
        stop_pipeline()


# The ASGI application; lifespan manages the subprocess around serving.
app = FastAPI(lifespan=lifespan)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@app.get("/")
async def root():
    """Return a small index document pointing at the other routes."""
    index = dict(
        message="s2s endpoint is up",
        health="/health",
        websocket="/ws",
    )
    return index
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@app.get("/health")
async def health():
    """Report whether the speech-to-speech pipeline can be reached.

    Returns:
        A JSON body ``{"status": "ok", "internal_ws": <url>}`` when the
        subprocess is alive and the internal websocket accepts a connection.

    Raises:
        HTTPException: 503 if the subprocess was never started, has exited,
            or the internal websocket cannot be reached in time.
    """
    # Work on one snapshot of the handle for the whole check.
    proc = pipeline_process
    if proc is None:
        raise HTTPException(status_code=503, detail="speech-to-speech process not started")

    exit_code = proc.poll()
    if exit_code is not None:
        raise HTTPException(
            status_code=503,
            detail=f"speech-to-speech process exited with code {exit_code}",
        )

    try:
        await asyncio.wait_for(wait_for_internal_ws(timeout_s=5), timeout=6)
    except asyncio.TimeoutError as exc:
        # The outer timeout carries no message of its own; without this branch
        # the detail would end in an empty string.
        raise HTTPException(
            status_code=503,
            detail=f"internal websocket not ready: no connection to {INTERNAL_WS_URL} within 6s",
        ) from exc
    except Exception as exc:
        raise HTTPException(status_code=503, detail=f"internal websocket not ready: {exc}") from exc

    return JSONResponse({"status": "ok", "internal_ws": INTERNAL_WS_URL})
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@app.websocket("/ws")
async def websocket_proxy(client_ws: WebSocket):
    """Relay websocket traffic between a client and the internal pipeline.

    After accepting the client, a connection to ``INTERNAL_WS_URL`` is opened
    and two pumps forward binary and text messages in each direction. As soon
    as either direction ends, the other pump is cancelled and the upstream
    connection is closed.

    Outcomes:
        * client disconnect: logged, nothing else to do;
        * upstream disconnect: logged, client socket closed;
        * any other failure: logged with traceback, client closed with 1011.
    """
    await client_ws.accept()
    logger.info("Client websocket connected")

    try:
        async with websockets.connect(
            INTERNAL_WS_URL,
            open_timeout=30,
            ping_interval=20,
            ping_timeout=20,
            max_size=None,
        ) as upstream_ws:

            async def client_to_upstream():
                # Forward client frames upstream until the client disconnects.
                while True:
                    message = await client_ws.receive()

                    if message["type"] == "websocket.disconnect":
                        raise WebSocketDisconnect()

                    if message.get("bytes") is not None:
                        await upstream_ws.send(message["bytes"])
                    elif message.get("text") is not None:
                        await upstream_ws.send(message["text"])

            async def upstream_to_client():
                # Forward upstream frames to the client, keeping the frame type.
                while True:
                    msg = await upstream_ws.recv()
                    if isinstance(msg, bytes):
                        await client_ws.send_bytes(msg)
                    else:
                        await client_ws.send_text(msg)

            pumps = [
                asyncio.create_task(client_to_upstream()),
                asyncio.create_task(upstream_to_client()),
            ]
            try:
                done, _pending = await asyncio.wait(
                    pumps, return_when=asyncio.FIRST_COMPLETED
                )
            finally:
                # gather() alone would leave the surviving pump running on a
                # dead connection; cancel it and wait for it to unwind.
                for pump in pumps:
                    pump.cancel()
                await asyncio.gather(*pumps, return_exceptions=True)

            # The pumps only finish by raising; surface the first failure so
            # the handlers below can classify it.
            for pump in pumps:
                if pump in done:
                    pump.result()

    except WebSocketDisconnect:
        logger.info("Client websocket disconnected")
    except ConnectionClosed:
        logger.info("Upstream websocket disconnected")
        try:
            await client_ws.close()
        except Exception:
            # Best effort: the client may already be gone.
            pass
    except Exception:
        logger.exception("Websocket proxy failed")
        try:
            await client_ws.close(code=1011, reason="Proxy failure")
        except Exception:
            # Best effort: the client may already be gone.
            pass
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.116.1
|
| 2 |
+
uvicorn[standard]==0.35.0
|
| 3 |
+
websockets==15.0.1
|
test_ws_file.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import sys
|
| 3 |
+
import wave
|
| 4 |
+
import websockets
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Audio format used by this client for both the uploaded and the saved audio.
SAMPLE_RATE = 16000  # samples per second
SAMPLE_WIDTH = 2  # bytes per sample (16-bit PCM)
CHANNELS = 1  # mono

# Samples sent per websocket message; matches the old endpoint handler
# chunking pattern nicely.
CHUNK_SAMPLES = 512
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def read_wav_pcm16_mono(path: str) -> bytes:
    """Load every frame of a WAV file as raw PCM bytes.

    Args:
        path: Location of the WAV file to read.

    Returns:
        The raw frame data of the whole file.

    Raises:
        ValueError: if the file's sample rate, sample width or channel count
            differ from ``SAMPLE_RATE``, ``SAMPLE_WIDTH`` and ``CHANNELS``.
    """
    with wave.open(path, "rb") as wav_in:
        sr, sw, ch = wav_in.getframerate(), wav_in.getsampwidth(), wav_in.getnchannels()

        # Reject anything that is not in the expected format.
        if (sr, sw, ch) != (SAMPLE_RATE, SAMPLE_WIDTH, CHANNELS):
            raise ValueError(
                f"Expected WAV mono/16kHz/16-bit PCM, got sr={sr}, sw={sw}, ch={ch}"
            )

        frame_count = wav_in.getnframes()
        return wav_in.readframes(frame_count)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def write_wav_pcm16_mono(path: str, pcm_bytes: bytes) -> None:
    """Save raw PCM bytes as a WAV file.

    The header is written with ``SAMPLE_RATE``, ``SAMPLE_WIDTH`` and
    ``CHANNELS``; the payload is stored as given.

    Args:
        path: Destination file; overwritten if it exists.
        pcm_bytes: Raw frame data to store.
    """
    with wave.open(path, "wb") as wav_out:
        # Header fields are independent of each other; set them before the data.
        wav_out.setframerate(SAMPLE_RATE)
        wav_out.setsampwidth(SAMPLE_WIDTH)
        wav_out.setnchannels(CHANNELS)
        wav_out.writeframes(pcm_bytes)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
async def main():
    """Stream a WAV file to the websocket endpoint and save the audio reply.

    Command line: ``<ws_url> <input.wav> [hf_token]``. The input is sent in
    chunks of ``CHUNK_SAMPLES`` samples, paced at real-time speed. Binary
    replies are collected until the server stays silent for 8 seconds or
    closes the connection, then written to ``response.wav``. Text replies are
    printed.
    """
    if len(sys.argv) < 3:
        print("Usage:")
        print("  python test_ws_file.py <ws_url> <input.wav> [hf_token]")
        print("Example:")
        print("  python test_ws_file.py ws://localhost:7860/ws input.wav")
        sys.exit(1)

    ws_url = sys.argv[1]
    input_wav = sys.argv[2]
    hf_token = sys.argv[3] if len(sys.argv) > 3 else None

    headers = {}
    if hf_token:
        headers["Authorization"] = f"Bearer {hf_token}"

    audio = read_wav_pcm16_mono(input_wav)
    bytes_per_chunk = CHUNK_SAMPLES * SAMPLE_WIDTH

    received = bytearray()

    try:
        async with websockets.connect(
            ws_url,
            additional_headers=headers if headers else None,
            max_size=None,
            ping_interval=20,
            ping_timeout=20,
        ) as ws:
            # Sender: sleep one chunk duration per chunk to mimic live capture.
            for i in range(0, len(audio), bytes_per_chunk):
                await ws.send(audio[i : i + bytes_per_chunk])
                await asyncio.sleep(CHUNK_SAMPLES / SAMPLE_RATE)

            # Give the server some time to answer.
            # For a real app you'd use a smarter turn-ending signal or UI behavior.
            try:
                while True:
                    msg = await asyncio.wait_for(ws.recv(), timeout=8.0)
                    if isinstance(msg, bytes):
                        received.extend(msg)
                    else:
                        print("TEXT EVENT:", msg)
            except asyncio.TimeoutError:
                # Silence for 8 seconds is treated as the end of the reply.
                pass
    except websockets.ConnectionClosed as exc:
        # A server-side close must not discard audio that already arrived.
        print("Connection closed by server:", exc)

    write_wav_pcm16_mono("response.wav", bytes(received))
    print("Wrote response.wav")


if __name__ == "__main__":
    asyncio.run(main())
|