Awsl1111ddd commited on
Commit
759ae49
·
verified ·
1 Parent(s): 978eb52

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +27 -0
  2. app.py +207 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Official, lightweight Python base image.
FROM python:3.9-slim

# Create /nltk_data at build time as root and open its permissions (777) so the
# non-root user that runs the app can write downloaded NLTK resources there.
# mkdir -p: does not fail if the directory already exists.
# chmod 777: grants read/write/execute to all users (including the non-root runtime user).
RUN mkdir -p /nltk_data && chmod 777 /nltk_data

# Working directory inside the container.
WORKDIR /code

# Copy the dependency list first so the pip layer is cached across code-only changes.
COPY ./requirements.txt /code/requirements.txt

# Install dependencies. --no-cache-dir keeps the image smaller.
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the application files (mainly app.py) into the working directory.
COPY ./ /code/

# Expose the port. Hugging Face Spaces uses port 7860 by default.
EXPOSE 7860

# Container start command:
# run the uvicorn server, listening on all interfaces (0.0.0.0), port 7860.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import io
3
+ import nltk
4
+ import logging
5
+ import re
6
+ import tempfile
7
+ from fastapi import FastAPI, HTTPException, Response, Request
8
+ from pydantic import BaseModel, Field
9
+ from typing import List, Dict, Any
10
+
11
+ import httpx
12
+ from pydub import AudioSegment
13
+
14
# --- 1. Configuration ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Downstream GPT-SoVITS worker Spaces. Text chunks are assigned round-robin
# across this list by process_tts_request.
WORKER_URLS = [
    "https://snsbhg-1111.hf.space/tts",
    "https://snsbhg-111102.hf.space/tts",
    "https://snsbhg-111103.hf.space/tts",
    "https://snsbhg-111104.hf.space/tts",
    "https://snsbhg-111105.hf.space/tts",
    "https://11edx-111106.hf.space/tts",
    "https://11edx-111107.hf.space/tts",
    "https://11edx-111108.hf.space/tts",
    "https://11edx-111109.hf.space/tts",
    "https://11edx-111110.hf.space/tts",
    "https://11edx-111111.hf.space/tts",
    "https://11edx-111112.hf.space/tts",
    "https://11edx-111113.hf.space/tts",
    "https://11edx-111114.hf.space/tts",
    "https://11edx-111115.hf.space/tts",
    "https://11edx-111116.hf.space/tts",
    "https://11edx-111117.hf.space/tts",
    "https://11edx-111118.hf.space/tts",
    "https://11edx-111119.hf.space/tts",
    "https://11edx-111120.hf.space/tts",
]

# Upper bound on simultaneously in-flight worker requests (asyncio.Semaphore size).
MAX_CONCURRENT_REQUESTS = 8
29
+
30
# --- 2. Pydantic models ---
class WeightsPaths(BaseModel):
    """Request body for /update-all-workers: filesystem paths of the weight files to load."""
    sovits_path: str
    gpt_path: str
33
+
34
# --- START OF NEW SECTION: OpenAI Audio API Models ---
class OpenAIAudioRequest(BaseModel):
    # OpenAI-compatible /v1/audio/speech request body. Only `input` and
    # `response_format` are actually used downstream; the rest are accepted
    # for API compatibility.
    model: str = "tts-1"  # compatibility field (ignored)
    input: str  # the text to synthesize
    voice: str = "alloy"  # compatibility field; could later select different reference audio
    response_format: str = Field(default="wav", alias="response_format")
    speed: float = 1.0  # compatibility field (ignored)

    class Config:
        extra = "allow"  # tolerate unknown keys sent by OpenAI SDK clients
        populate_by_name = True
# --- END OF NEW SECTION ---
46
+
47
+
48
# --- 3. Initialization and helpers ---
app = FastAPI()

# One shared HTTP/2 client for all worker traffic; long timeout for slow syntheses.
client = httpx.AsyncClient(timeout=180.0, http2=True)

# Ensure the NLTK sentence tokenizer is present before serving any request.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    logging.info("NLTK 'punkt' tokenizer not found. Downloading...")
    nltk.download('punkt', quiet=True)
    logging.info("'punkt' downloaded successfully.")
54
+
55
def split_by_punctuation(text: str, max_len: int = 150):
    """Split *text* at punctuation marks and regroup the pieces into chunks.

    Each chunk is kept near *max_len* characters; chunks joined from several
    sentences use a single space as separator, and any chunk still longer than
    *max_len* is force-cut at the last punctuation/space before the limit
    (or hard at max_len - 1 when none exists). Returns a list of non-empty,
    stripped chunk strings.
    """
    delimiters = ",.:;!?。,、;:!?.…【】"
    pieces = re.split(r'([,.:;!?。,、;:!?.…【】])', text)

    # Stage 1: glue each delimiter onto the text fragment that precedes it.
    sentences = []
    pending = ""
    for piece in pieces:
        if piece in delimiters:  # note: "" also matches, flushing `pending`
            pending += piece
            sentences.append(pending)
            pending = ""
        else:
            if pending:
                sentences.append(pending)
            pending = piece
    if pending:
        sentences.append(pending)

    # Stage 2: pack sentences into chunks of roughly max_len characters.
    chunks = []
    current = ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if not current or len(current) + len(sentence) <= max_len:
            current = f"{current} {sentence}" if current else sentence
        else:
            chunks.append(current)
            current = sentence
        # Force-split anything that is still over the limit.
        while len(current) > max_len:
            cut = -1
            for mark in ",;。!?…【】 ":
                idx = current.rfind(mark, 0, max_len)
                if idx > -1:
                    cut = idx
                    break
            if cut == -1:
                cut = max_len - 1
            chunks.append(current[:cut + 1])
            current = current[cut + 1:]
    if current:
        chunks.append(current)
    return [c.strip() for c in chunks if c.strip()]
80
+
81
def sanitize_and_default_params(params: dict) -> dict:
    """Whitelist-filter a raw payload and fill in defaults for the worker API.

    Keeps only the keys the worker understands, applies defaults for the
    optional ones, and raises ValueError if any required key is absent
    after filtering.
    """
    allowed = {"text", "text_lang", "ref_audio_path", "prompt_lang", "prompt_text", "media_type", "streaming_mode"}
    defaults = {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "streaming_mode": False}

    sanitized_params = {key: params[key] for key in allowed if key in params}
    for name, value in defaults.items():
        sanitized_params.setdefault(name, value)

    missing_keys = {"text", "ref_audio_path", "prompt_text"} - set(sanitized_params.keys())
    if missing_keys:
        raise ValueError(f"Missing required fields after sanitization: {', '.join(missing_keys)}")

    logging.info(f"Sanitized params. Final keys sent to worker: {list(sanitized_params.keys())}")
    return sanitized_params
93
+
94
async def send_task_to_worker(worker_url, payload, index, semaphore):
    """Send one TTS chunk to a worker, bounded by *semaphore*.

    Args:
        worker_url: the worker's /tts endpoint URL.
        payload: raw params; passed through sanitize_and_default_params first.
        index: chunk position, returned so results can be re-ordered.
        semaphore: asyncio.Semaphore limiting concurrent in-flight requests.

    Returns:
        (index, audio_bytes) on success, (index, None) on any failure —
        callers never see an exception from this function.
    """
    async with semaphore:
        try:
            final_payload = sanitize_and_default_params(payload)
            logging.info(f"Task {index}: Sending chunk to {worker_url} (text length: {len(final_payload.get('text', ''))})")
            response = await client.post(worker_url, json=final_payload, timeout=180.0)
            if response.status_code == 200:
                logging.info(f"Task {index}: Successfully received audio data from {worker_url}")
                return index, response.content
            # BUG FIX: httpx.Response.text is a property, not a coroutine.
            # `await response.text()` raised TypeError here, so the worker's real
            # error body was never logged and the generic handler fired instead.
            error_body = response.text
            logging.error(f"Task {index}: Worker {worker_url} returned error status {response.status_code}. Body: {error_body}")
            return index, None
        except ValueError as ve:
            # Raised by sanitize_and_default_params when required keys are missing.
            logging.error(f"Task {index}: Parameter validation failed before sending. Error: {ve}")
            return index, None
        except Exception as e:
            logging.error(f"Task {index}: Request to {worker_url} failed: {e}")
            return index, None
105
+
106
async def process_tts_request(request_data: dict):
    """Core TTS pipeline: chunk the text, fan out to workers, splice the audio.

    Args:
        request_data: native TTS payload. Must contain "text"; the keys required
            by sanitize_and_default_params ("ref_audio_path", "prompt_text")
            must also be present for the workers to accept the job.

    Returns:
        Response with the final audio bytes and an "audio/<media_type>" MIME type.

    Raises:
        HTTPException: 400 when the text yields no chunks; 500 when the single
            worker (short text) or all workers (long text) fail.
    """
    text = request_data.get("text", "")
    logging.info(f"Processing new TTS request. Full text length: {len(text)}")
    chunks = split_by_punctuation(text)
    if not chunks: raise HTTPException(status_code=400, detail="Input text is empty or resulted in no chunks.")
    # Fast path: one chunk goes straight to the first worker, no splicing needed.
    if len(chunks) <= 1:
        logging.info("Text resulted in one chunk. Forwarding to a single worker...")
        # NOTE(review): `payload` aliases the caller's dict, so "text" is mutated
        # in place here — confirm callers do not reuse request_data afterwards.
        payload = request_data; payload["text"] = chunks[0] if chunks else ""
        semaphore = asyncio.Semaphore(1)
        _, audio_bytes = await send_task_to_worker(WORKER_URLS[0], payload, 0, semaphore)
        if audio_bytes: logging.info("Single chunk processed successfully. Returning audio."); return Response(content=audio_bytes, media_type=f"audio/{request_data.get('media_type', 'wav')}")
        raise HTTPException(status_code=500, detail="Downstream worker failed to process the request.")
    logging.info(f"Text split into {len(chunks)} chunks. Starting parallel processing with a limit of {MAX_CONCURRENT_REQUESTS} requests at a time...")
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    tasks = []
    # Chunks are assigned to workers round-robin; the semaphore inside
    # send_task_to_worker caps how many run at once.
    base_payload = request_data; num_workers = len(WORKER_URLS)
    for i, chunk in enumerate(chunks):
        task_payload = {**base_payload, "text": chunk}; worker_url = WORKER_URLS[i % num_workers]
        tasks.append(send_task_to_worker(worker_url, task_payload, i, semaphore))
    # Restore original chunk order; failed chunks (None) are silently dropped.
    results = await asyncio.gather(*tasks); results.sort(key=lambda x: x[0])
    valid_audio_bytes_list = [audio for _, audio in results if audio]
    if not valid_audio_bytes_list: raise HTTPException(status_code=500, detail="All downstream worker tasks failed, could not generate audio.")
    logging.info(f"Starting to splice {len(valid_audio_bytes_list)} valid audio segments using a temporary file to save memory...")
    # Splice via a temp file: the combined audio is re-read and re-exported for
    # every segment (quadratic I/O) — a deliberate trade to keep at most two
    # decoded segments in memory at a time.
    with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as temp_f:
        media_type = request_data.get("media_type", "wav")
        first_segment = AudioSegment.from_file(io.BytesIO(valid_audio_bytes_list[0]), format=media_type)
        first_segment.export(temp_f.name, format=media_type)
        for i, audio_bytes in enumerate(valid_audio_bytes_list[1:]):
            try:
                combined_audio = AudioSegment.from_file(temp_f.name, format=media_type)
                next_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=media_type)
                (combined_audio + next_segment).export(temp_f.name, format=media_type)
            # A bad segment is skipped rather than failing the whole request.
            except Exception as e: logging.warning(f"Splicing failed for audio segment {i+1}. Skipping. Error: {e}")
        temp_f.seek(0)
        final_audio_content = temp_f.read()
    logging.info("Audio splicing complete. Returning final audio file to user.")
    return Response(content=final_audio_content, media_type=f"audio/{media_type}")
144
+
145
+
146
# --- 4. API endpoints ---
@app.get("/")
def read_root():
    """Health-check endpoint."""
    return {"status": "Master TTS Accelerator is running"}
149
+
150
+ @app.post("/update-all-workers")
151
+ async def update_all_workers_endpoint(paths: WeightsPaths):
152
+ # (此函数不变)
153
+ logging.info(f"Received request to update all workers. SOVITS='{paths.sovits_path}', GPT='{paths.gpt_path}'")
154
+ async def send_update_to_worker(worker_url, sovits_path, gpt_path):
155
+ base_url = worker_url.replace("/tts", "");
156
+ try:
157
+ sovits_resp = await client.get(f"{base_url}/set_sovits_weights", params={"weights_path": sovits_path}, timeout=60); sovits_resp.raise_for_status()
158
+ gpt_resp = await client.get(f"{base_url}/set_gpt_weights", params={"weights_path": gpt_path}, timeout=60); gpt_resp.raise_for_status()
159
+ return base_url, "Success"
160
+ except Exception as e: return base_url, f"Failed: {e}"
161
+ tasks = [send_update_to_worker(url, paths.sovits_path, paths.gpt_path) for url in WORKER_URLS]
162
+ results = await asyncio.gather(*tasks); status_report = {url: status for url, status in results}
163
+ return {"message": "Update commands sent to all workers.", "status_report": status_report}
164
+
165
+ @app.post("/tts")
166
+ async def tts_post_endpoint(request: Request):
167
+ # (此函数不变)
168
+ logging.info("--- HIT POST /tts endpoint (Sanitization Mode) ---")
169
+ try:
170
+ request_data = await request.json()
171
+ logging.info(f"Received native JSON payload with keys: {list(request_data.keys())}")
172
+ return await process_tts_request(request_data)
173
+ except Exception as e:
174
+ logging.error(f"Failed to parse JSON body or process request: {e}")
175
+ raise HTTPException(status_code=400, detail=f"Invalid JSON body or processing error: {e}")
176
+
177
# --- START OF NEW SECTION: OpenAI Audio API Endpoint ---
@app.post("/v1/audio/speech")
async def openai_audio_speech_endpoint(request: OpenAIAudioRequest):
    """OpenAI-compatible speech endpoint: adapt the request and reuse the TTS pipeline."""
    logging.info("--- HIT POST /v1/audio/speech endpoint (OpenAI Audio API Mode) ---")

    # The OpenAI schema carries the text in `input`.
    text_to_speak = request.input
    logging.info(f"Extracted text from OpenAI audio format: '{text_to_speak}'")

    # TODO: map request.voice to different reference audio files once available
    # (e.g. if request.voice == "shimmer": ref_audio = "path/to/shimmer.wav").

    # Build the native payload with the service-specific fields the OpenAI
    # format lacks; remaining defaults are filled by sanitize_and_default_params.
    tts_params = {
        "text": text_to_speak,
        "ref_audio_path": "/app/reference_audio/ref_shantianliang_1.wav",  # default reference audio
        "prompt_text": "这是一条参考音频,将此音频拖入参考内在添加文本即可合成音色",  # default prompt text
        "media_type": request.response_format,  # honour the client's requested format
    }

    try:
        return await process_tts_request(tts_params)
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing OpenAI audio request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error during TTS processing.")
# --- END OF NEW SECTION ---
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
# Runtime dependencies for the master TTS accelerator (installed in the Dockerfile).
fastapi        # web framework (app.py endpoints)
uvicorn        # ASGI server (Dockerfile CMD)
httpx[http2]   # async HTTP/2 client used to call the workers
pydub          # audio splicing (requires ffmpeg in the image for non-wav formats)
pydantic       # request models (WeightsPaths, OpenAIAudioRequest)
nltk           # 'punkt' tokenizer downloaded at startup