import asyncio
import io
import nltk
import logging
import re
import tempfile
from fastapi import FastAPI, HTTPException, Response, Request
from pydantic import BaseModel, Field
from typing import List, Dict, Any
import httpx
from pydub import AudioSegment
# --- 1. Configuration ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
WORKER_URLS = [
    "https://awsl1111ddd-114514test.hf.space/tts",
    "https://awsl1111ddd-114514.hf.space/tts",
    "https://awsl1111ddd-dev.hf.space/tts",
    "https://awsl1111ddd-cialloapi.hf.space/tts",
    "https://qwed2025-devfork.hf.space/tts",
    "https://qwed2025-Qwer.hf.space/tts",
    "https://qwed2025-MY-TESTSPACE.hf.space/tts",
    "https://johnali202508-aitts-test.hf.space/tts",
    "https://johnali202508-ciallo.hf.space/tts",
    "https://johnali202508-ai.hf.space/tts",
]
MAX_CONCURRENT_REQUESTS = 8
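# Chunks are dispatched round-robin: chunk i goes to WORKER_URLS[i % len(WORKER_URLS)],
# with at most MAX_CONCURRENT_REQUESTS chunk requests in flight at once (see process_tts_request).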
# --- 2. Pydantic Models ---
class WeightsPaths(BaseModel):
    sovits_path: str
    gpt_path: str
# --- START OF NEW SECTION: OpenAI Audio API Models ---
class OpenAIAudioRequest(BaseModel):
    model: str = "tts-1"  # compatibility field
    input: str  # the text we want to synthesize
    voice: str = "alloy"  # compatibility field; may later be used to select a different reference audio
    response_format: str = Field(default="wav", alias="response_format")
    speed: float = 1.0  # compatibility field

    class Config:
        extra = "allow"
        populate_by_name = True
# --- END OF NEW SECTION ---
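# A request body this model accepts looks like (illustrative):
#   {"model": "tts-1", "input": "你好,世界。", "voice": "alloy", "response_format": "wav", "speed": 1.0}
# Because Config.extra = "allow", unknown fields are tolerated rather than rejected.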
# --- 3. Initialization and Helper Functions (this whole section unchanged) ---
app = FastAPI()
client = httpx.AsyncClient(timeout=180.0, http2=True)  # http2=True requires the optional 'h2' package
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    logging.info("NLTK 'punkt' tokenizer not found. Downloading...")
    nltk.download('punkt', quiet=True)
    logging.info("'punkt' downloaded successfully.")
def split_by_punctuation(text: str, max_len: int = 20):
    # (this function unchanged)
    fragments = re.split(r'([,.:;!?。,、;:!?.…【】])', text)
    sentences = []
    temp_frag = ""
    for frag in fragments:
        if frag in ",.:;!?。,、;:!?.…【】":
            temp_frag += frag
            sentences.append(temp_frag)
            temp_frag = ""
        else:
            if temp_frag:
                sentences.append(temp_frag)
            temp_frag = frag
    if temp_frag:
        sentences.append(temp_frag)
    chunks, current_chunk = [], ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if not current_chunk or len(current_chunk) + len(sentence) <= max_len:
            current_chunk += (" " if current_chunk else "") + sentence
        else:
            chunks.append(current_chunk)
            current_chunk = sentence
        # Break down any chunk that is still longer than max_len, preferring to
        # split at the last punctuation mark (or space) within the limit.
        while len(current_chunk) > max_len:
            split_pos = -1
            for punc in ",;。!?…【】 ":
                pos = current_chunk.rfind(punc, 0, max_len)
                if pos > -1:
                    split_pos = pos
                    break
            if split_pos == -1:
                split_pos = max_len - 1
            chunks.append(current_chunk[:split_pos + 1])
            current_chunk = current_chunk[split_pos + 1:]
    if current_chunk:
        chunks.append(current_chunk)
    return [c.strip() for c in chunks if c.strip()]
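# Illustrative behavior (worked out by hand, not from a test suite):
#   split_by_punctuation("你好。世界!", max_len=3)   -> ["你好。", "世界!"]
#   split_by_punctuation("你好。世界!", max_len=20)  -> ["你好。 世界!"]   (short sentences get merged)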
def sanitize_and_default_params(params: dict) -> dict:
    # (this function unchanged)
    allowed_keys = {"text", "text_lang", "ref_audio_path", "prompt_lang", "prompt_text", "media_type", "streaming_mode"}
    default_values = {"text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "streaming_mode": False}
    sanitized_params = {key: params[key] for key in allowed_keys if key in params}
    for key, default_value in default_values.items():
        if key not in sanitized_params:
            sanitized_params[key] = default_value
    required_keys = {"text", "ref_audio_path", "prompt_text"}
    missing_keys = required_keys - set(sanitized_params.keys())
    if missing_keys:
        raise ValueError(f"Missing required fields after sanitization: {', '.join(missing_keys)}")
    logging.info(f"Sanitized params. Final keys sent to worker: {list(sanitized_params.keys())}")
    return sanitized_params
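# Illustrative behavior (worked out by hand): unknown keys are dropped, defaults filled in:
#   sanitize_and_default_params({"text": "你好", "ref_audio_path": "a.wav", "prompt_text": "hi", "foo": 1})
#   -> {"text": "你好", "ref_audio_path": "a.wav", "prompt_text": "hi",
#       "text_lang": "zh", "prompt_lang": "zh", "media_type": "wav", "streaming_mode": False}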
async def send_task_to_worker(worker_url, payload, index, semaphore):
    # (this function unchanged)
    async with semaphore:
        try:
            final_payload = sanitize_and_default_params(payload)
            logging.info(f"Task {index}: Sending chunk to {worker_url} (text length: {len(final_payload.get('text', ''))})")
            response = await client.post(worker_url, json=final_payload, timeout=180.0)
            if response.status_code == 200:
                logging.info(f"Task {index}: Successfully received audio data from {worker_url}")
                return index, response.content
            error_body = response.text  # httpx's Response.text is a property, not a coroutine
            logging.error(f"Task {index}: Worker {worker_url} returned error status {response.status_code}. Body: {error_body}")
            return index, None
        except ValueError as ve:
            logging.error(f"Task {index}: Parameter validation failed before sending. Error: {ve}")
            return index, None
        except Exception as e:
            logging.error(f"Task {index}: Request to {worker_url} failed: {e}")
            return index, None
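# Note: send_task_to_worker always returns (index, bytes-or-None) instead of raising, so one
# failed chunk never aborts the surrounding asyncio.gather; callers sort by index and drop Nones.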
async def process_tts_request(request_data: dict):
    # (this function unchanged)
    text = request_data.get("text", "")
    logging.info(f"Processing new TTS request. Full text length: {len(text)}")
    chunks = split_by_punctuation(text)
    if not chunks:
        raise HTTPException(status_code=400, detail="Input text is empty or resulted in no chunks.")
    if len(chunks) <= 1:
        logging.info("Text resulted in one chunk. Forwarding to a single worker...")
        payload = request_data
        payload["text"] = chunks[0] if chunks else ""
        semaphore = asyncio.Semaphore(1)
        _, audio_bytes = await send_task_to_worker(WORKER_URLS[0], payload, 0, semaphore)
        if audio_bytes:
            logging.info("Single chunk processed successfully. Returning audio.")
            return Response(content=audio_bytes, media_type=f"audio/{request_data.get('media_type', 'wav')}")
        raise HTTPException(status_code=500, detail="Downstream worker failed to process the request.")
    logging.info(f"Text split into {len(chunks)} chunks. Starting parallel processing with a limit of {MAX_CONCURRENT_REQUESTS} requests at a time...")
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    tasks = []
    base_payload = request_data
    num_workers = len(WORKER_URLS)
    for i, chunk in enumerate(chunks):
        task_payload = {**base_payload, "text": chunk}
        worker_url = WORKER_URLS[i % num_workers]
        tasks.append(send_task_to_worker(worker_url, task_payload, i, semaphore))
    results = await asyncio.gather(*tasks)
    results.sort(key=lambda x: x[0])
    valid_audio_bytes_list = [audio for _, audio in results if audio]
    if not valid_audio_bytes_list:
        raise HTTPException(status_code=500, detail="All downstream worker tasks failed, could not generate audio.")
    logging.info(f"Starting to splice {len(valid_audio_bytes_list)} valid audio segments using a temporary file to save memory...")
    with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as temp_f:
        media_type = request_data.get("media_type", "wav")
        first_segment = AudioSegment.from_file(io.BytesIO(valid_audio_bytes_list[0]), format=media_type)
        first_segment.export(temp_f.name, format=media_type)
        for i, audio_bytes in enumerate(valid_audio_bytes_list[1:]):
            try:
                combined_audio = AudioSegment.from_file(temp_f.name, format=media_type)
                next_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=media_type)
                (combined_audio + next_segment).export(temp_f.name, format=media_type)
            except Exception as e:
                logging.warning(f"Splicing failed for audio segment {i + 1}. Skipping. Error: {e}")
        temp_f.seek(0)
        final_audio_content = temp_f.read()
    logging.info("Audio splicing complete. Returning final audio file to user.")
    return Response(content=final_audio_content, media_type=f"audio/{media_type}")
# --- 4. API Endpoints ---
@app.get("/")
def read_root():
    return {"status": "Master TTS Accelerator is running"}
@app.post("/update-all-workers")
async def update_all_workers_endpoint(paths: WeightsPaths):
# (此函数不变)
logging.info(f"Received request to update all workers. SOVITS='{paths.sovits_path}', GPT='{paths.gpt_path}'")
async def send_update_to_worker(worker_url, sovits_path, gpt_path):
base_url = worker_url.replace("/tts", "");
try:
sovits_resp = await client.get(f"{base_url}/set_sovits_weights", params={"weights_path": sovits_path}, timeout=60); sovits_resp.raise_for_status()
gpt_resp = await client.get(f"{base_url}/set_gpt_weights", params={"weights_path": gpt_path}, timeout=60); gpt_resp.raise_for_status()
return base_url, "Success"
except Exception as e: return base_url, f"Failed: {e}"
tasks = [send_update_to_worker(url, paths.sovits_path, paths.gpt_path) for url in WORKER_URLS]
results = await asyncio.gather(*tasks); status_report = {url: status for url, status in results}
return {"message": "Update commands sent to all workers.", "status_report": status_report}
@app.post("/tts")
async def tts_post_endpoint(request: Request):
# (此函数不变)
logging.info("--- HIT POST /tts endpoint (Sanitization Mode) ---")
try:
request_data = await request.json()
logging.info(f"Received native JSON payload with keys: {list(request_data.keys())}")
return await process_tts_request(request_data)
except Exception as e:
logging.error(f"Failed to parse JSON body or process request: {e}")
raise HTTPException(status_code=400, detail=f"Invalid JSON body or processing error: {e}")
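# Example native /tts call (a sketch; the host is an assumption, the reference audio values are
# the same defaults used by the OpenAI-compatible endpoint below):
#   import httpx
#   r = httpx.post("http://localhost:7860/tts", json={
#       "text": "今天天气不错。",
#       "ref_audio_path": "/app/reference_audio/ref_shantianliang_1.wav",
#       "prompt_text": "这是一条参考音频,将此音频拖入参考内在添加文本即可合成音色",
#   }, timeout=180.0)
#   open("out.wav", "wb").write(r.content)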
# --- START OF NEW SECTION: OpenAI Audio API Endpoint ---
@app.post("/v1/audio/speech")
async def openai_audio_speech_endpoint(request: OpenAIAudioRequest):
    logging.info("--- HIT POST /v1/audio/speech endpoint (OpenAI Audio API Mode) ---")
    # 1. Extract the core text from the request
    text_to_speak = request.input
    logging.info(f"Extracted text from OpenAI audio format: '{text_to_speak}'")
    # 2. TODO (optional): in the future, select a different reference audio based on request.voice
    #    e.g.: if request.voice == "shimmer": ref_audio = "path/to/shimmer.wav"
    # 3. Build the complete default TTS parameter set.
    #    These are parameters this service requires but which the OpenAI format does not carry.
    tts_params = {
        "text": text_to_speak,
        "ref_audio_path": "/app/reference_audio/ref_shantianliang_1.wav",  # default reference audio
        "prompt_text": "这是一条参考音频,将此音频拖入参考内在添加文本即可合成音色",  # default prompt text for the reference audio
        "media_type": request.response_format,  # honor the client's requested format
        # remaining parameters are filled in automatically by sanitize_and_default_params
    }
    # 4. Delegate to the core processing logic
    try:
        return await process_tts_request(tts_params)
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred while processing OpenAI audio request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error during TTS processing.")
# --- END OF NEW SECTION ---
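# Example client usage through the official openai SDK (a sketch: base_url, api_key, and the
# output filename are assumptions; assumes openai>=1.x, where client.audio.speech.create exists):
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")
#   resp = client.audio.speech.create(model="tts-1", voice="alloy", input="你好,世界。")
#   resp.write_to_file("speech.wav")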