Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- Dockerfile +27 -0
- app.py +207 -0
- requirements.txt +6 -0
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use an official, lightweight Python base image
FROM python:3.9-slim

# Create /nltk_data at build time as root and open its permissions (777) so the
# non-root user that runs the app on Hugging Face Spaces can write downloaded
# corpora there. mkdir -p is idempotent if the directory already exists.
RUN mkdir -p /nltk_data && chmod 777 /nltk_data

# BUG FIX: NLTK only searches/downloads into /nltk_data when the NLTK_DATA
# environment variable points at it. Without this, nltk.download() in app.py
# falls back to ~/nltk_data and the prepared directory above is never used.
ENV NLTK_DATA=/nltk_data

# Working directory inside the container
WORKDIR /code

# Copy the dependency list first so the pip layer is cached across code changes.
COPY ./requirements.txt /code/requirements.txt

# Install dependencies. --no-cache-dir keeps the image small.
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the application code (mainly app.py) into the working directory
COPY ./ /code/

# Hugging Face Spaces expects the app to listen on port 7860
EXPOSE 7860

# Start the uvicorn server, listening on all interfaces (0.0.0.0) at port 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import io
|
| 3 |
+
import nltk
|
| 4 |
+
import logging
|
| 5 |
+
import re
|
| 6 |
+
import tempfile
|
| 7 |
+
from fastapi import FastAPI, HTTPException, Response, Request
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
import httpx
|
| 12 |
+
from pydub import AudioSegment
|
| 13 |
+
|
| 14 |
+
# --- 1. Configuration ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Pool of downstream TTS worker Spaces. Text chunks are assigned round-robin
# across this list by process_tts_request.
WORKER_URLS = [
    "https://snsbhg-1111.hf.space/tts", "https://snsbhg-111102.hf.space/tts",
    "https://snsbhg-111103.hf.space/tts", "https://snsbhg-111104.hf.space/tts",
    "https://snsbhg-111105.hf.space/tts", "https://11edx-111106.hf.space/tts",
    "https://11edx-111107.hf.space/tts", "https://11edx-111108.hf.space/tts",
    "https://11edx-111109.hf.space/tts", "https://11edx-111110.hf.space/tts",
    "https://11edx-111111.hf.space/tts", "https://11edx-111112.hf.space/tts",
    "https://11edx-111113.hf.space/tts", "https://11edx-111114.hf.space/tts",
    "https://11edx-111115.hf.space/tts", "https://11edx-111116.hf.space/tts",
    "https://11edx-111117.hf.space/tts", "https://11edx-111118.hf.space/tts",
    "https://11edx-111119.hf.space/tts", "https://11edx-111120.hf.space/tts",
]
# Upper bound on simultaneous in-flight worker requests (enforced via an
# asyncio.Semaphore in process_tts_request).
MAX_CONCURRENT_REQUESTS = 8
| 29 |
+
|
| 30 |
+
# --- 2. Pydantic models ---
# Request body for /update-all-workers: filesystem paths (on the workers) of
# the SoVITS and GPT weight files to switch to.
class WeightsPaths(BaseModel):
    sovits_path: str
    gpt_path: str
|
| 33 |
+
|
| 34 |
+
# --- START OF NEW SECTION: OpenAI Audio API Models ---
# Mirrors the OpenAI /v1/audio/speech request schema so OpenAI SDK clients work.
class OpenAIAudioRequest(BaseModel):
    model: str = "tts-1"  # compatibility field, currently ignored
    input: str  # the text to synthesize
    voice: str = "alloy"  # compatibility field; may later select a reference audio
    response_format: str = Field(default="wav", alias="response_format")
    speed: float = 1.0  # compatibility field, currently ignored

    class Config:
        # Accept and ignore any extra OpenAI fields; allow field-name population
        # alongside aliases.
        extra = "allow"
        populate_by_name = True
# --- END OF NEW SECTION ---
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# --- 3. Initialization and helper functions ---
app = FastAPI()
# Single shared async HTTP client reused for all worker calls; the generous
# 180 s timeout accommodates slow TTS synthesis on the workers.
client = httpx.AsyncClient(timeout=180.0, http2=True)
# Ensure the NLTK sentence tokenizer is present at startup.
# NOTE(review): nltk.download() writes to NLTK's default search path (e.g.
# ~/nltk_data or $NLTK_DATA); the Dockerfile prepares /nltk_data but NLTK only
# uses it when NLTK_DATA points there — verify the env var is set in the image.
try: nltk.data.find('tokenizers/punkt')
except LookupError:
    logging.info("NLTK 'punkt' tokenizer not found. Downloading..."); nltk.download('punkt', quiet=True); logging.info("'punkt' downloaded successfully.")
|
| 54 |
+
|
| 55 |
+
def split_by_punctuation(text: str, max_len: int = 150):
    """Split *text* into chunks of roughly at most *max_len* characters.

    Sentences are cut at CJK/ASCII punctuation, then greedily packed into
    chunks; any chunk still longer than max_len is force-split, preferring a
    punctuation or space boundary.
    """
    _PUNCT = ",.:;!?。,、;:!?.…【】"

    # Pass 1: re-attach each captured punctuation mark to the text before it.
    sentences = []
    pending = ""
    for piece in re.split(r'([,.:;!?。,、;:!?.…【】])', text):
        if piece in _PUNCT:  # NB: also true for '' (empty fragments) — kept from original
            sentences.append(pending + piece)
            pending = ""
        else:
            if pending:
                sentences.append(pending)
            pending = piece
    if pending:
        sentences.append(pending)

    # Pass 2: greedily pack sentences into chunks no longer than max_len.
    chunks = []
    current = ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if not current or len(current) + len(sentence) <= max_len:
            current = f"{current} {sentence}" if current else sentence
        else:
            chunks.append(current)
            current = sentence
        # Force-split an oversize chunk, cutting at the last punctuation/space
        # inside the window, or hard-cutting at max_len if none exists.
        while len(current) > max_len:
            cut = -1
            for mark in ",;。!?…【】 ":
                found = current.rfind(mark, 0, max_len)
                if found > -1:
                    cut = found
                    break
            if cut == -1:
                cut = max_len - 1
            chunks.append(current[:cut + 1])
            current = current[cut + 1:]
    if current:
        chunks.append(current)
    return [c.strip() for c in chunks if c.strip()]
|
| 80 |
+
|
| 81 |
+
def sanitize_and_default_params(params: dict) -> dict:
    """Whitelist worker-facing keys, fill in defaults, and check required fields.

    Returns a new dict containing only the keys the worker API understands.
    Raises ValueError when text, ref_audio_path, or prompt_text is missing.
    """
    allowed = {"text", "text_lang", "ref_audio_path", "prompt_lang", "prompt_text", "media_type", "streaming_mode"}
    defaults = {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "streaming_mode": False}
    required = {"text", "ref_audio_path", "prompt_text"}

    clean = {k: v for k, v in params.items() if k in allowed}
    for key, value in defaults.items():
        clean.setdefault(key, value)

    missing_keys = required - clean.keys()
    if missing_keys:
        raise ValueError(f"Missing required fields after sanitization: {', '.join(missing_keys)}")
    logging.info(f"Sanitized params. Final keys sent to worker: {list(clean.keys())}")
    return clean
|
| 93 |
+
|
| 94 |
+
async def send_task_to_worker(worker_url, payload, index, semaphore):
    """Send one text chunk to a worker and return (index, audio_bytes or None).

    The semaphore caps global in-flight requests. All failures are logged and
    mapped to (index, None) so asyncio.gather never receives an exception.
    """
    async with semaphore:
        try:
            final_payload = sanitize_and_default_params(payload)
            logging.info(f"Task {index}: Sending chunk to {worker_url} (text length: {len(final_payload.get('text', ''))})")
            response = await client.post(worker_url, json=final_payload, timeout=180.0)
            if response.status_code == 200:
                logging.info(f"Task {index}: Successfully received audio data from {worker_url}")
                return index, response.content
            # BUG FIX: httpx.Response.text is a plain str property, not a coroutine;
            # `await response.text()` raised TypeError and hid the worker's error body.
            error_body = response.text
            logging.error(f"Task {index}: Worker {worker_url} returned error status {response.status_code}. Body: {error_body}")
            return index, None
        except ValueError as ve:
            # Raised by sanitize_and_default_params when required fields are missing.
            logging.error(f"Task {index}: Parameter validation failed before sending. Error: {ve}")
            return index, None
        except Exception as e:
            logging.error(f"Task {index}: Request to {worker_url} failed: {e}")
            return index, None
|
| 105 |
+
|
| 106 |
+
async def process_tts_request(request_data: dict):
    """Split the request text, fan chunks out to the worker pool, splice the
    returned audio segments in order, and return the final audio Response.

    Raises HTTPException 400 when the text yields no chunks and 500 when the
    worker(s) fail to produce any audio.
    """
    text = request_data.get("text", "")
    logging.info(f"Processing new TTS request. Full text length: {len(text)}")
    chunks = split_by_punctuation(text)
    if not chunks: raise HTTPException(status_code=400, detail="Input text is empty or resulted in no chunks.")
    # Fast path: a single chunk goes to one worker, no splicing needed.
    if len(chunks) <= 1:
        logging.info("Text resulted in one chunk. Forwarding to a single worker...")
        # NOTE(review): this mutates the caller's dict in place — confirm intended.
        payload = request_data; payload["text"] = chunks[0] if chunks else ""
        semaphore = asyncio.Semaphore(1)
        _, audio_bytes = await send_task_to_worker(WORKER_URLS[0], payload, 0, semaphore)
        if audio_bytes: logging.info("Single chunk processed successfully. Returning audio."); return Response(content=audio_bytes, media_type=f"audio/{request_data.get('media_type', 'wav')}")
        raise HTTPException(status_code=500, detail="Downstream worker failed to process the request.")
    logging.info(f"Text split into {len(chunks)} chunks. Starting parallel processing with a limit of {MAX_CONCURRENT_REQUESTS} requests at a time...")
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    tasks = []
    base_payload = request_data; num_workers = len(WORKER_URLS)
    # Round-robin chunks across workers; the semaphore caps in-flight requests.
    for i, chunk in enumerate(chunks):
        task_payload = {**base_payload, "text": chunk}; worker_url = WORKER_URLS[i % num_workers]
        tasks.append(send_task_to_worker(worker_url, task_payload, i, semaphore))
    # Sort by chunk index so audio is spliced back in the original text order.
    results = await asyncio.gather(*tasks); results.sort(key=lambda x: x[0])
    # Failed tasks return None and are silently dropped from the final audio.
    valid_audio_bytes_list = [audio for _, audio in results if audio]
    if not valid_audio_bytes_list: raise HTTPException(status_code=500, detail="All downstream worker tasks failed, could not generate audio.")
    logging.info(f"Starting to splice {len(valid_audio_bytes_list)} valid audio segments using a temporary file to save memory...")
    # Splice incrementally through a temp file so at most two decoded segments
    # are held in memory at once (re-reads the growing file each iteration).
    with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as temp_f:
        media_type = request_data.get("media_type", "wav")
        first_segment = AudioSegment.from_file(io.BytesIO(valid_audio_bytes_list[0]), format=media_type)
        first_segment.export(temp_f.name, format=media_type)
        for i, audio_bytes in enumerate(valid_audio_bytes_list[1:]):
            try:
                combined_audio = AudioSegment.from_file(temp_f.name, format=media_type)
                next_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=media_type)
                (combined_audio + next_segment).export(temp_f.name, format=media_type)
            except Exception as e: logging.warning(f"Splicing failed for audio segment {i+1}. Skipping. Error: {e}")
        # Read the spliced result back before the temp file is deleted on exit.
        temp_f.seek(0)
        final_audio_content = temp_f.read()
    logging.info("Audio splicing complete. Returning final audio file to user.")
    return Response(content=final_audio_content, media_type=f"audio/{media_type}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# --- 4. API endpoints ---
@app.get("/")
def read_root():
    """Health-check endpoint."""
    return {"status": "Master TTS Accelerator is running"}
|
| 149 |
+
|
| 150 |
+
@app.post("/update-all-workers")
async def update_all_workers_endpoint(paths: WeightsPaths):
    """Push new SoVITS/GPT weight paths to every worker and report per-worker status."""
    logging.info(f"Received request to update all workers. SOVITS='{paths.sovits_path}', GPT='{paths.gpt_path}'")

    async def push_weights(worker_url, sovits_path, gpt_path):
        # The weight-switch routes live on the worker's base URL, not under /tts.
        base_url = worker_url.replace("/tts", "")
        try:
            resp = await client.get(f"{base_url}/set_sovits_weights", params={"weights_path": sovits_path}, timeout=60)
            resp.raise_for_status()
            resp = await client.get(f"{base_url}/set_gpt_weights", params={"weights_path": gpt_path}, timeout=60)
            resp.raise_for_status()
            return base_url, "Success"
        except Exception as e:
            return base_url, f"Failed: {e}"

    results = await asyncio.gather(*(push_weights(url, paths.sovits_path, paths.gpt_path) for url in WORKER_URLS))
    status_report = dict(results)
    return {"message": "Update commands sent to all workers.", "status_report": status_report}
|
| 164 |
+
|
| 165 |
+
@app.post("/tts")
async def tts_post_endpoint(request: Request):
    """Native /tts endpoint: accepts an arbitrary JSON payload and forwards it
    to the TTS pipeline (sanitization happens downstream in the workers path).
    """
    logging.info("--- HIT POST /tts endpoint (Sanitization Mode) ---")
    try:
        request_data = await request.json()
        logging.info(f"Received native JSON payload with keys: {list(request_data.keys())}")
        return await process_tts_request(request_data)
    except HTTPException:
        # BUG FIX: previously every HTTPException raised by process_tts_request
        # (e.g. a 500 when all workers fail, or the 400 for empty text) was
        # swallowed and re-wrapped as a misleading 400 "Invalid JSON body".
        # Propagate it unchanged so clients see the real status and detail.
        raise
    except Exception as e:
        logging.error(f"Failed to parse JSON body or process request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON body or processing error: {e}")
|
| 176 |
+
|
| 177 |
+
# --- START OF NEW SECTION: OpenAI Audio API Endpoint ---
@app.post("/v1/audio/speech")
async def openai_audio_speech_endpoint(request: OpenAIAudioRequest):
    """OpenAI-compatible speech endpoint: translates the OpenAI request shape
    into this service's native parameters and reuses the TTS pipeline.
    """
    logging.info("--- HIT POST /v1/audio/speech endpoint (OpenAI Audio API Mode) ---")

    # The OpenAI schema carries the text to synthesize in `input`.
    text_to_speak = request.input
    logging.info(f"Extracted text from OpenAI audio format: '{text_to_speak}'")

    # TODO (optional): map request.voice to different reference audio files,
    # e.g. if request.voice == "shimmer": ref_audio = "path/to/shimmer.wav"

    # Supply the service-specific parameters the OpenAI format lacks; any
    # remaining defaults are filled by sanitize_and_default_params downstream.
    tts_params = {
        "text": text_to_speak,
        "ref_audio_path": "/app/reference_audio/ref_shantianliang_1.wav",  # default reference audio
        "prompt_text": "这是一条参考音频,将此音频拖入参考内在添加文本即可合成音色",  # default prompt text
        "media_type": request.response_format,  # honor the client-requested format
    }

    # Delegate to the shared pipeline, keeping HTTP errors intact.
    try:
        return await process_tts_request(tts_params)
    except HTTPException as e:
        raise e
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing OpenAI audio request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error during TTS processing.")
# --- END OF NEW SECTION ---
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
httpx[http2]
|
| 4 |
+
pydub
|
| 5 |
+
pydantic
|
| 6 |
+
nltk
|