# (Hugging Face Spaces page header residue — "Spaces: Sleeping" — from the capture; not part of the code)
import asyncio
import io
import nltk
import logging
import re
import tempfile
from fastapi import FastAPI, HTTPException, Response, Request
from pydantic import BaseModel, Field
from typing import List, Dict, Any
import httpx
from pydub import AudioSegment

# --- 1. Configuration ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Downstream TTS worker endpoints (Hugging Face Spaces). Each URL points at a
# worker's /tts route; chunks are distributed round-robin across this pool.
WORKER_URLS = [
    "https://awsl1111ddd-114514test.hf.space/tts",
    "https://awsl1111ddd-114514.hf.space/tts",
    "https://awsl1111ddd-dev.hf.space/tts",
    "https://awsl1111ddd-cialloapi.hf.space/tts",
    "https://qwed2025-devfork.hf.space/tts",
    "https://qwed2025-Qwer.hf.space/tts",
    "https://qwed2025-MY-TESTSPACE.hf.space/tts",
    "https://johnali202508-aitts-test.hf.space/tts",
    "https://johnali202508-ciallo.hf.space/tts",
    "https://johnali202508-ai.hf.space/tts",
]

# Upper bound on simultaneous in-flight requests to the worker pool.
MAX_CONCURRENT_REQUESTS = 8
# --- 2. Pydantic models ---
class WeightsPaths(BaseModel):
    """Request body for the bulk weight-update endpoint.

    Carries the server-side paths of the SoVITS and GPT weight files that
    every worker should switch to.
    """
    sovits_path: str
    gpt_path: str
# --- START OF NEW SECTION: OpenAI Audio API Models ---
class OpenAIAudioRequest(BaseModel):
    # Request body mirroring OpenAI's /v1/audio/speech API shape.
    model: str = "tts-1"  # compatibility field (accepted but unused here)
    input: str  # the text to synthesize
    voice: str = "alloy"  # compatibility field; could later select a reference audio
    response_format: str = Field(default="wav", alias="response_format")
    speed: float = 1.0  # compatibility field (accepted but unused here)
    class Config:
        extra = "allow"  # tolerate extra fields sent by OpenAI-style clients
        populate_by_name = True  # accept the field name as well as its alias
# --- END OF NEW SECTION ---
# --- 3. Initialization and helper functions ---
app = FastAPI()
# Shared async HTTP client for all worker traffic; long timeout because TTS jobs are slow.
client = httpx.AsyncClient(timeout=180.0, http2=True)
# Ensure the NLTK 'punkt' tokenizer data is present (downloaded once at startup).
try: nltk.data.find('tokenizers/punkt')
except LookupError:
    logging.info("NLTK 'punkt' tokenizer not found. Downloading..."); nltk.download('punkt', quiet=True); logging.info("'punkt' downloaded successfully.")
def split_by_punctuation(text: str, max_len: int = 20):
    """Split *text* into chunks of roughly *max_len* characters, cutting at punctuation.

    The text is first divided into punctuation-terminated units (the delimiter
    stays attached to the text before it), which are then greedily packed into
    chunks; any unit still longer than max_len is force-split at the last
    punctuation/space before the limit, or hard-cut if none exists.
    """
    punct = ",.:;!?。,、;:!?.…【】"
    # Capturing group keeps the delimiters in the split output.
    pieces = re.split(r'([,.:;!?。,、;:!?.…【】])', text)

    # Re-attach each delimiter to the fragment that precedes it.
    units = []
    pending = ""
    for piece in pieces:
        if piece in punct:  # NB: also true for "" (empty pieces from re.split)
            pending += piece
            units.append(pending)
            pending = ""
        else:
            if pending:
                units.append(pending)
            pending = piece
    if pending:
        units.append(pending)

    # Greedily pack units into chunks of at most max_len characters.
    chunks = []
    buf = ""
    for unit in units:
        unit = unit.strip()
        if not unit:
            continue
        if not buf or len(buf) + len(unit) <= max_len:
            buf = unit if not buf else buf + " " + unit
        else:
            chunks.append(buf)
            buf = unit
        # Force-split an over-long buffer at the last break character before the limit.
        while len(buf) > max_len:
            cut = -1
            for brk in ",;。!?…【】 ":
                pos = buf.rfind(brk, 0, max_len)
                if pos > -1:
                    cut = pos
                    break
            if cut == -1:
                cut = max_len - 1
            chunks.append(buf[:cut + 1])
            buf = buf[cut + 1:]
    if buf:
        chunks.append(buf)
    return [c.strip() for c in chunks if c.strip()]
def sanitize_and_default_params(params: dict) -> dict:
    """Whitelist-filter *params* for the worker /tts API and fill in defaults.

    Only keys the downstream endpoint understands are kept; optional keys that
    are absent receive defaults. Raises ValueError when any required key
    ("text", "ref_audio_path", "prompt_text") is missing after filtering.
    """
    allowed = {"text", "text_lang", "ref_audio_path", "prompt_lang", "prompt_text", "media_type", "streaming_mode"}
    defaults = {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "streaming_mode": False}

    clean = {key: params[key] for key in allowed if key in params}
    for key, value in defaults.items():
        clean.setdefault(key, value)

    required = {"text", "ref_audio_path", "prompt_text"}
    missing = required - set(clean.keys())
    if missing:
        raise ValueError(f"Missing required fields after sanitization: {', '.join(missing)}")
    logging.info(f"Sanitized params. Final keys sent to worker: {list(clean.keys())}")
    return clean
async def send_task_to_worker(worker_url, payload, index, semaphore):
    """Send one text chunk to a worker's /tts endpoint.

    Concurrency is bounded by *semaphore*. Returns (index, audio_bytes) on
    success, or (index, None) on validation/transport/worker failure, so the
    caller can reassemble chunks in order and skip failures.
    """
    async with semaphore:
        try:
            final_payload = sanitize_and_default_params(payload)
            logging.info(f"Task {index}: Sending chunk to {worker_url} (text length: {len(final_payload.get('text', ''))})")
            response = await client.post(worker_url, json=final_payload, timeout=180.0)
            if response.status_code == 200:
                logging.info(f"Task {index}: Successfully received audio data from {worker_url}")
                return index, response.content
            # BUG FIX: httpx's Response.text is a plain str property, not an
            # awaitable method — `await response.text()` raised TypeError and the
            # real worker error was masked by the generic `except` below.
            error_body = response.text
            logging.error(f"Task {index}: Worker {worker_url} returned error status {response.status_code}. Body: {error_body}")
            return index, None
        except ValueError as ve:
            logging.error(f"Task {index}: Parameter validation failed before sending. Error: {ve}")
            return index, None
        except Exception as e:
            logging.error(f"Task {index}: Request to {worker_url} failed: {e}")
            return index, None
async def process_tts_request(request_data: dict):
    """Core TTS pipeline: split the text, fan chunks out to workers, splice audio.

    request_data: raw JSON payload; must contain "text" plus the keys required
    downstream by sanitize_and_default_params ("ref_audio_path", "prompt_text").
    Returns a FastAPI Response carrying the final audio, or raises
    HTTPException (400 for empty input, 500 when all worker tasks fail).
    """
    text = request_data.get("text", "")
    logging.info(f"Processing new TTS request. Full text length: {len(text)}")
    chunks = split_by_punctuation(text)
    if not chunks: raise HTTPException(status_code=400, detail="Input text is empty or resulted in no chunks.")
    if len(chunks) <= 1:
        # Fast path: a single chunk is forwarded to the FIRST worker only.
        logging.info("Text resulted in one chunk. Forwarding to a single worker...")
        # NOTE(review): `payload` aliases request_data, so the caller's dict is mutated in place.
        payload = request_data; payload["text"] = chunks[0] if chunks else ""
        semaphore = asyncio.Semaphore(1)
        _, audio_bytes = await send_task_to_worker(WORKER_URLS[0], payload, 0, semaphore)
        if audio_bytes: logging.info("Single chunk processed successfully. Returning audio."); return Response(content=audio_bytes, media_type=f"audio/{request_data.get('media_type', 'wav')}")
        raise HTTPException(status_code=500, detail="Downstream worker failed to process the request.")
    logging.info(f"Text split into {len(chunks)} chunks. Starting parallel processing with a limit of {MAX_CONCURRENT_REQUESTS} requests at a time...")
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    tasks = []
    # Round-robin the chunks across the worker pool; the semaphore caps concurrency.
    base_payload = request_data; num_workers = len(WORKER_URLS)
    for i, chunk in enumerate(chunks):
        task_payload = {**base_payload, "text": chunk}; worker_url = WORKER_URLS[i % num_workers]
        tasks.append(send_task_to_worker(worker_url, task_payload, i, semaphore))
    # Results are (index, bytes-or-None); sorting by index restores chunk order.
    results = await asyncio.gather(*tasks); results.sort(key=lambda x: x[0])
    valid_audio_bytes_list = [audio for _, audio in results if audio]
    if not valid_audio_bytes_list: raise HTTPException(status_code=500, detail="All downstream worker tasks failed, could not generate audio.")
    logging.info(f"Starting to splice {len(valid_audio_bytes_list)} valid audio segments using a temporary file to save memory...")
    # Incremental splice through a temp file so only one combined segment lives
    # in memory at a time. Re-decoding the whole file each pass is O(n^2) in
    # total audio length — presumably a deliberate memory-for-CPU trade.
    # NOTE(review): reopening a NamedTemporaryFile by name (pydub does this)
    # works on POSIX but not on Windows — confirm deployment target.
    with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as temp_f:
        media_type = request_data.get("media_type", "wav")
        first_segment = AudioSegment.from_file(io.BytesIO(valid_audio_bytes_list[0]), format=media_type)
        first_segment.export(temp_f.name, format=media_type)
        for i, audio_bytes in enumerate(valid_audio_bytes_list[1:]):
            try:
                combined_audio = AudioSegment.from_file(temp_f.name, format=media_type)
                next_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=media_type)
                (combined_audio + next_segment).export(temp_f.name, format=media_type)
            # Best-effort: a bad segment is skipped rather than failing the whole request.
            except Exception as e: logging.warning(f"Splicing failed for audio segment {i+1}. Skipping. Error: {e}")
        temp_f.seek(0)
        final_audio_content = temp_f.read()
    logging.info("Audio splicing complete. Returning final audio file to user.")
    return Response(content=final_audio_content, media_type=f"audio/{media_type}")
# --- 4. API endpoints ---
def read_root():
    """Return the health-check payload for the service root.

    NOTE(review): no @app.get(...) decorator is visible on this handler, so it
    is never registered with FastAPI — confirm whether a route decorator
    (e.g. @app.get("/")) was lost in transit.
    """
    return {"status": "Master TTS Accelerator is running"}
async def update_all_workers_endpoint(paths: WeightsPaths):
    """Broadcast new SoVITS/GPT weight paths to every worker; return a per-worker status report.

    NOTE(review): no @app.post(...) decorator is visible on this handler —
    confirm whether the route registration was lost in transit.
    """
    logging.info(f"Received request to update all workers. SOVITS='{paths.sovits_path}', GPT='{paths.gpt_path}'")

    async def push_weights(worker_url, sovits_path, gpt_path):
        # Workers expose /set_sovits_weights and /set_gpt_weights next to /tts.
        base_url = worker_url.replace("/tts", "")
        try:
            resp = await client.get(f"{base_url}/set_sovits_weights", params={"weights_path": sovits_path}, timeout=60)
            resp.raise_for_status()
            resp = await client.get(f"{base_url}/set_gpt_weights", params={"weights_path": gpt_path}, timeout=60)
            resp.raise_for_status()
            return base_url, "Success"
        except Exception as e:
            return base_url, f"Failed: {e}"

    outcomes = await asyncio.gather(*(push_weights(url, paths.sovits_path, paths.gpt_path) for url in WORKER_URLS))
    return {"message": "Update commands sent to all workers.", "status_report": dict(outcomes)}
async def tts_post_endpoint(request: Request):
    """Native /tts entry point: accept an arbitrary JSON body and delegate to the pipeline.

    The body is parsed leniently (no Pydantic model) and sanitized downstream.
    NOTE(review): no @app.post(...) decorator is visible here, though the log
    line says POST /tts — confirm whether the decorator was lost in transit.
    """
    logging.info("--- HIT POST /tts endpoint (Sanitization Mode) ---")
    try:
        request_data = await request.json()
        logging.info(f"Received native JSON payload with keys: {list(request_data.keys())}")
        return await process_tts_request(request_data)
    except HTTPException:
        # BUG FIX: deliberate HTTP errors raised by process_tts_request (400 empty
        # input, 500 worker failure) were previously caught by the generic handler
        # below and re-wrapped as a misleading 400. Re-raise them untouched —
        # this matches the handling in openai_audio_speech_endpoint.
        raise
    except Exception as e:
        logging.error(f"Failed to parse JSON body or process request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON body or processing error: {e}")
# --- START OF NEW SECTION: OpenAI Audio API Endpoint ---
async def openai_audio_speech_endpoint(request: OpenAIAudioRequest):
    """OpenAI-compatible speech endpoint: map an OpenAI-style request onto the native pipeline.

    Only `input` (the text) and `response_format` are used; the reference audio
    and prompt text are fixed defaults. NOTE(review): no @app.post(...)
    decorator is visible here, though the log line says POST /v1/audio/speech —
    confirm whether the decorator was lost in transit.
    """
    logging.info("--- HIT POST /v1/audio/speech endpoint (OpenAI Audio API Mode) ---")
    # 1. Pull the text to synthesize out of the OpenAI-shaped request.
    text_to_speak = request.input
    logging.info(f"Extracted text from OpenAI audio format: '{text_to_speak}'")
    # 2. TODO: request.voice could later select among different reference audios,
    #    e.g. if request.voice == "shimmer": ref_audio = "path/to/shimmer.wav"
    # 3. Build the full parameter set the native service needs; fields the OpenAI
    #    format lacks get fixed defaults here, and the remaining optional keys are
    #    filled in later by sanitize_and_default_params.
    tts_params = {
        "text": text_to_speak,
        "ref_audio_path": "/app/reference_audio/ref_shantianliang_1.wav",  # default reference audio
        "prompt_text": "这是一条参考音频,将此音频拖入参考内在添加文本即可合成音色",  # default prompt matching the reference audio
        "media_type": request.response_format,  # honor the client's requested format
    }
    # 4. Delegate to the shared fan-out/splice pipeline.
    try:
        return await process_tts_request(tts_params)
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing OpenAI audio request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error during TTS processing.")
# --- END OF NEW SECTION ---