"""Master TTS accelerator service.

Splits long input text into punctuation-bounded chunks, fans the chunks out
in parallel to a pool of downstream GPT-SoVITS worker Spaces, then splices
the returned audio segments back together (via a temp file to bound memory).
Exposes a native /tts endpoint and an OpenAI-compatible /v1/audio/speech
endpoint.
"""

import asyncio
import io
import logging
import re
import tempfile
from typing import List, Dict, Any

import httpx
import nltk
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel, Field
from pydub import AudioSegment

# --- 1. Configuration ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Downstream worker endpoints; chunks are assigned round-robin by index.
WORKER_URLS = [
    "https://awsl1111ddd-114514test.hf.space/tts",
    "https://awsl1111ddd-114514.hf.space/tts",
    "https://awsl1111ddd-dev.hf.space/tts",
    "https://awsl1111ddd-cialloapi.hf.space/tts",
    "https://qwed2025-devfork.hf.space/tts",
    "https://qwed2025-Qwer.hf.space/tts",
    "https://qwed2025-MY-TESTSPACE.hf.space/tts",
    "https://johnali202508-aitts-test.hf.space/tts",
    "https://johnali202508-ciallo.hf.space/tts",
    "https://johnali202508-ai.hf.space/tts",
]
# Cap on simultaneously in-flight worker requests (enforced via a semaphore).
MAX_CONCURRENT_REQUESTS = 8


# --- 2. Pydantic models ---
class WeightsPaths(BaseModel):
    """Weight-file paths broadcast to every worker via /update-all-workers."""
    sovits_path: str
    gpt_path: str


class OpenAIAudioRequest(BaseModel):
    """OpenAI-compatible /v1/audio/speech request body.

    Only `input` and `response_format` are honored; `model`, `voice` and
    `speed` are accepted for client compatibility but currently ignored
    (`voice` may later select a reference audio).
    """
    model: str = "tts-1"          # compatibility field, ignored
    input: str                    # the text to synthesize
    voice: str = "alloy"          # compatibility field; reserved for voice selection
    response_format: str = Field(default="wav", alias="response_format")
    speed: float = 1.0            # compatibility field, ignored

    class Config:
        extra = "allow"
        populate_by_name = True


# --- 3. Initialization and helpers ---
app = FastAPI()
client = httpx.AsyncClient(timeout=180.0, http2=True)

# Ensure the NLTK sentence tokenizer data is present (best-effort bootstrap).
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    logging.info("NLTK 'punkt' tokenizer not found. Downloading...")
    nltk.download('punkt', quiet=True)
    logging.info("'punkt' downloaded successfully.")


def split_by_punctuation(text: str, max_len: int = 20) -> List[str]:
    """Split `text` into chunks of at most ~`max_len` chars, preferring to
    break at CJK/ASCII punctuation so each chunk is a natural phrase.

    Returns a list of non-empty, stripped chunks (empty list for empty input).
    """
    # Capturing split keeps the delimiters as their own fragments.
    fragments = re.split(r'([,.:;!?。,、;:!?.…【】])', text)
    sentences: List[str] = []
    pending = ""
    for frag in fragments:
        # NOTE: `frag in "<punct>"` is a substring test; guard against the
        # empty fragments re.split emits (`"" in s` is always True).
        if frag and frag in ",.:;!?。,、;:!?.…【】":
            # Attach the delimiter to the text that preceded it.
            pending += frag
            sentences.append(pending)
            pending = ""
        else:
            if pending:
                sentences.append(pending)
            pending = frag
    if pending:
        sentences.append(pending)

    # Greedily pack sentences into chunks no longer than max_len.
    chunks: List[str] = []
    current_chunk = ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if not current_chunk or len(current_chunk) + len(sentence) <= max_len:
            current_chunk += (" " if current_chunk else "") + sentence
        else:
            chunks.append(current_chunk)
            current_chunk = sentence
        # A single sentence may still exceed max_len: force-split it,
        # preferring the right-most punctuation/space inside the window.
        while len(current_chunk) > max_len:
            split_pos = -1
            for punc in ",;。!?…【】 ":
                pos = current_chunk.rfind(punc, 0, max_len)
                if pos > -1:
                    split_pos = pos
                    break
            if split_pos == -1:
                split_pos = max_len - 1  # no punctuation: hard cut
            chunks.append(current_chunk[:split_pos + 1])
            current_chunk = current_chunk[split_pos + 1:]
    if current_chunk:
        chunks.append(current_chunk)
    return [c.strip() for c in chunks if c.strip()]


def sanitize_and_default_params(params: dict) -> dict:
    """Whitelist the fields a worker accepts, fill in defaults, and verify
    the required fields are present.

    Raises:
        ValueError: if `text`, `ref_audio_path` or `prompt_text` is missing.
    """
    allowed_keys = {"text", "text_lang", "ref_audio_path", "prompt_lang",
                    "prompt_text", "media_type", "streaming_mode"}
    default_values = {"text_lang": "zh", "prompt_lang": "zh",
                      "media_type": "wav", "streaming_mode": False}
    sanitized_params = {key: params[key] for key in allowed_keys if key in params}
    for key, default_value in default_values.items():
        if key not in sanitized_params:
            sanitized_params[key] = default_value
    required_keys = {"text", "ref_audio_path", "prompt_text"}
    missing_keys = required_keys - set(sanitized_params.keys())
    if missing_keys:
        raise ValueError(f"Missing required fields after sanitization: {', '.join(missing_keys)}")
    logging.info(f"Sanitized params. Final keys sent to worker: {list(sanitized_params.keys())}")
    return sanitized_params


async def send_task_to_worker(worker_url: str, payload: dict, index: int,
                              semaphore: asyncio.Semaphore):
    """POST one chunk to a worker under the concurrency semaphore.

    Returns (index, audio_bytes) on success, (index, None) on any failure
    (bad params, HTTP error status, or transport error) — never raises.
    """
    async with semaphore:
        try:
            final_payload = sanitize_and_default_params(payload)
            logging.info(f"Task {index}: Sending chunk to {worker_url} (text length: {len(final_payload.get('text', ''))})")
            response = await client.post(worker_url, json=final_payload, timeout=180.0)
            if response.status_code == 200:
                logging.info(f"Task {index}: Successfully received audio data from {worker_url}")
                return index, response.content
            # BUGFIX: httpx's Response.text is a property, not a coroutine —
            # `await response.text()` raised TypeError and hid the real error.
            error_body = response.text
            logging.error(f"Task {index}: Worker {worker_url} returned error status {response.status_code}. Body: {error_body}")
            return index, None
        except ValueError as ve:
            logging.error(f"Task {index}: Parameter validation failed before sending. Error: {ve}")
            return index, None
        except Exception as e:
            logging.error(f"Task {index}: Request to {worker_url} failed: {e}")
            return index, None


async def process_tts_request(request_data: dict) -> Response:
    """Core pipeline: chunk the text, fan out to workers, splice the audio.

    Raises:
        HTTPException(400): empty input text.
        HTTPException(500): every worker task failed.
    """
    text = request_data.get("text", "")
    logging.info(f"Processing new TTS request. Full text length: {len(text)}")
    chunks = split_by_punctuation(text)
    if not chunks:
        raise HTTPException(status_code=400, detail="Input text is empty or resulted in no chunks.")

    if len(chunks) <= 1:
        logging.info("Text resulted in one chunk. Forwarding to a single worker...")
        # Copy so we never mutate the caller's dict in place.
        payload = {**request_data, "text": chunks[0]}
        semaphore = asyncio.Semaphore(1)
        _, audio_bytes = await send_task_to_worker(WORKER_URLS[0], payload, 0, semaphore)
        if audio_bytes:
            logging.info("Single chunk processed successfully. Returning audio.")
            return Response(content=audio_bytes,
                            media_type=f"audio/{request_data.get('media_type', 'wav')}")
        raise HTTPException(status_code=500, detail="Downstream worker failed to process the request.")

    logging.info(f"Text split into {len(chunks)} chunks. Starting parallel processing with a limit of {MAX_CONCURRENT_REQUESTS} requests at a time...")
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    num_workers = len(WORKER_URLS)
    tasks = []
    for i, chunk in enumerate(chunks):
        task_payload = {**request_data, "text": chunk}
        worker_url = WORKER_URLS[i % num_workers]  # round-robin assignment
        tasks.append(send_task_to_worker(worker_url, task_payload, i, semaphore))

    results = await asyncio.gather(*tasks)
    results.sort(key=lambda x: x[0])  # restore original chunk order
    valid_audio_bytes_list = [audio for _, audio in results if audio]
    if not valid_audio_bytes_list:
        raise HTTPException(status_code=500, detail="All downstream worker tasks failed, could not generate audio.")

    logging.info(f"Starting to splice {len(valid_audio_bytes_list)} valid audio segments using a temporary file to save memory...")
    # Deliberate trade-off: re-read the growing file each iteration so only
    # one extra segment is decoded in memory at a time.
    with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as temp_f:
        media_type = request_data.get("media_type", "wav")
        first_segment = AudioSegment.from_file(io.BytesIO(valid_audio_bytes_list[0]), format=media_type)
        first_segment.export(temp_f.name, format=media_type)
        for i, audio_bytes in enumerate(valid_audio_bytes_list[1:]):
            try:
                combined_audio = AudioSegment.from_file(temp_f.name, format=media_type)
                next_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=media_type)
                (combined_audio + next_segment).export(temp_f.name, format=media_type)
            except Exception as e:
                # Best-effort: a bad segment is skipped, not fatal.
                logging.warning(f"Splicing failed for audio segment {i+1}. Skipping. Error: {e}")
        temp_f.seek(0)
        final_audio_content = temp_f.read()

    logging.info("Audio splicing complete. Returning final audio file to user.")
    return Response(content=final_audio_content, media_type=f"audio/{media_type}")


# --- 4. API endpoints ---
@app.get("/")
def read_root():
    """Liveness probe."""
    return {"status": "Master TTS Accelerator is running"}


@app.post("/update-all-workers")
async def update_all_workers_endpoint(paths: WeightsPaths):
    """Push new SoVITS/GPT weight paths to every worker, in parallel.

    Returns a per-worker status report; individual failures do not abort
    the broadcast.
    """
    logging.info(f"Received request to update all workers. SOVITS='{paths.sovits_path}', GPT='{paths.gpt_path}'")

    async def send_update_to_worker(worker_url: str, sovits_path: str, gpt_path: str):
        # Workers expose the weight-setting routes on their base URL.
        base_url = worker_url.replace("/tts", "")
        try:
            sovits_resp = await client.get(f"{base_url}/set_sovits_weights",
                                           params={"weights_path": sovits_path}, timeout=60)
            sovits_resp.raise_for_status()
            gpt_resp = await client.get(f"{base_url}/set_gpt_weights",
                                        params={"weights_path": gpt_path}, timeout=60)
            gpt_resp.raise_for_status()
            return base_url, "Success"
        except Exception as e:
            return base_url, f"Failed: {e}"

    tasks = [send_update_to_worker(url, paths.sovits_path, paths.gpt_path) for url in WORKER_URLS]
    results = await asyncio.gather(*tasks)
    status_report = {url: status for url, status in results}
    return {"message": "Update commands sent to all workers.", "status_report": status_report}


@app.post("/tts")
async def tts_post_endpoint(request: Request):
    """Native endpoint: accepts a raw JSON payload of worker TTS params."""
    logging.info("--- HIT POST /tts endpoint (Sanitization Mode) ---")
    try:
        request_data = await request.json()
        logging.info(f"Received native JSON payload with keys: {list(request_data.keys())}")
        return await process_tts_request(request_data)
    except Exception as e:
        logging.error(f"Failed to parse JSON body or process request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON body or processing error: {e}")


# --- OpenAI Audio API compatibility endpoint ---
@app.post("/v1/audio/speech")
async def openai_audio_speech_endpoint(request: OpenAIAudioRequest):
    """OpenAI-compatible endpoint: maps the OpenAI request shape onto the
    native TTS pipeline using a fixed default reference audio."""
    logging.info("--- HIT POST /v1/audio/speech endpoint (OpenAI Audio API Mode) ---")
    # 1. Extract the text to synthesize.
    text_to_speak = request.input
    logging.info(f"Extracted text from OpenAI audio format: '{text_to_speak}'")

    # 2. TODO: map request.voice to different reference audios,
    #    e.g. if request.voice == "shimmer": ref_audio = "path/to/shimmer.wav"

    # 3. Build the full parameter set our workers need but OpenAI's format lacks.
    tts_params = {
        "text": text_to_speak,
        "ref_audio_path": "/app/reference_audio/ref_shantianliang_1.wav",  # default reference audio
        "prompt_text": "这是一条参考音频,将此音频拖入参考内在添加文本即可合成音色",  # default prompt text
        "media_type": request.response_format,  # honor the client's requested format
        # Remaining params are filled in by sanitize_and_default_params.
    }

    # 4. Delegate to the core pipeline.
    try:
        return await process_tts_request(tts_params)
    except HTTPException as e:
        raise e
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing OpenAI audio request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error during TTS processing.")