|
|
import asyncio
|
|
|
import json
|
|
|
import random
|
|
|
import math
|
|
|
import os
|
|
|
import logging
|
|
|
import time
|
|
|
import glob
|
|
|
import hashlib
|
|
|
from typing import Optional, List, Dict, Any, AsyncGenerator
|
|
|
|
|
|
import aiofiles
|
|
|
import aiohttp
|
|
|
import websockets
|
|
|
from fastapi import FastAPI, HTTPException, Header, Request, BackgroundTasks
|
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
from pydantic import BaseModel, Field
|
|
|
import uvicorn
|
|
|
from contextlib import asynccontextmanager
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
# Load settings from a local .env file (if any) before the os.getenv() lookup
# further down reads PASSWORD.
load_dotenv()

# Process-wide logging: INFO and above, timestamped, tagged with logger name.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("DoubaoTTS")

# Address the server binds to when this module is run as a script.
PORT = 1547
HOST = "0.0.0.0"

# On-disk locations, all relative to the current working directory.
MODELS_FILE = "models.json"  # cached voice catalogue
COOKIE_DIR = "cookie"        # one cookie per *.txt file
AUDIO_DIR = "saved_audio"    # copy of the most recently streamed audio

# Shared secret checked by the GET streaming endpoint.
# NOTE(review): when PASSWORD is unset this falls back to a value that is
# visible in the source; deployments should always set PASSWORD explicitly.
AUTH_PASSWORD = os.getenv("PASSWORD", "sk-wei123")
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: warm the voice list, then refresh it.

    On startup the cached voice catalogue in ``MODELS_FILE`` (if present and
    well-formed) is loaded into ``engine.voices`` so ``/v1/models`` can answer
    immediately. A background task then refreshes the catalogue via
    ``engine.fetch_voices()``. On shutdown a still-running refresh is
    cancelled.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    if os.path.exists(MODELS_FILE):
        try:
            with open(MODELS_FILE, "r", encoding="utf-8") as f:
                cached = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both malformed JSON and undecodable bytes.
            logger.error(f"Failed to load cached models: {e}")
        else:
            if isinstance(cached, list):
                engine.voices = cached
                logger.info(f"Loaded {len(engine.voices)} voices from cache.")
            else:
                # /v1/models serves this value as a list; ignore anything else.
                logger.error(
                    "Failed to load cached models: expected a JSON list, "
                    f"got {type(cached).__name__}"
                )

    def _log_refresh_result(task: "asyncio.Task") -> None:
        # Surface a crash of the background refresh instead of losing it.
        if not task.cancelled() and task.exception() is not None:
            logger.error(f"Background voice refresh failed: {task.exception()}")

    # Keep a strong reference for the whole application lifetime: the event
    # loop only holds weak references to tasks, so a fire-and-forget task can
    # be garbage-collected before it finishes.
    refresh_task = asyncio.create_task(engine.fetch_voices())
    refresh_task.add_done_callback(_log_refresh_result)

    try:
        yield
    finally:
        if not refresh_task.done():
            refresh_task.cancel()
|
|
|
|
|
|
# ASGI application; `lifespan` runs the startup/shutdown logic defined above.
app = FastAPI(title="Doubao TTS OpenAI API Server", lifespan=lifespan)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CookieManager:
    """Pool of cookies read from ``*.txt`` files, with failure-based rotation.

    Exactly one cookie is "current" at a time. Consecutive failures reported
    against it are counted; once the count reaches ``max_failures`` (or when
    rotation is forced) the next cookie in file-name order becomes current,
    wrapping around at the end of the list.

    Attributes:
        cookie_dir: Directory scanned for ``*.txt`` cookie files.
        max_failures: Consecutive failures tolerated before rotating.
        cookies: Non-empty cookie strings, in sorted file-path order.
        current_index: Index into ``cookies`` of the current cookie.
        failure_count: Consecutive failures recorded for the current cookie.
    """

    def __init__(self, cookie_dir: str, max_failures: int = 2):
        """Load every cookie found in ``cookie_dir``.

        Args:
            cookie_dir: Directory holding the cookie files; created if absent.
            max_failures: Failures in a row that trigger rotation. Values
                below 1 are treated as 1. Defaults to 2.
        """
        self.cookie_dir = cookie_dir
        self.max_failures = max(1, max_failures)
        self.cookies: List[str] = []
        self.current_index = 0
        self.failure_count = 0
        self.load_cookies()

    def load_cookies(self):
        """Load all .txt files from the cookie directory.

        Replaces ``self.cookies`` with the stripped content of each file,
        skipping empty and unreadable files. Creates the directory when it
        does not exist yet.
        """
        self.cookies = []

        if not os.path.exists(self.cookie_dir):
            try:
                os.makedirs(self.cookie_dir, exist_ok=True)
                logger.info(f"Created cookie directory: {self.cookie_dir}")
            except OSError as e:
                logger.error(f"Failed to create cookie directory: {e}")
                return

        # Sorted so that rotation order is stable across reloads.
        for f_path in sorted(glob.glob(os.path.join(self.cookie_dir, "*.txt"))):
            try:
                with open(f_path, "r", encoding="utf-8") as f:
                    content = f.read().strip()
            except (OSError, UnicodeError) as e:
                # One bad file must not prevent the others from loading.
                logger.error(f"Error reading cookie file {f_path}: {e}")
                continue
            if content:
                self.cookies.append(content)

        logger.info(f"Loaded {len(self.cookies)} cookies from {self.cookie_dir}")

    def get_cookie(self) -> Optional[str]:
        """Get the current active cookie.

        Returns:
            The current cookie string, or ``None`` when no cookie is available
            even after re-scanning the cookie directory.
        """
        if not self.cookies:
            # Files may have been added since the last scan.
            self.load_cookies()
            if not self.cookies:
                return None

        # A reload can shrink the list below the remembered index.
        if self.current_index >= len(self.cookies):
            self.current_index = 0

        return self.cookies[self.current_index]

    def report_failure(self):
        """Record a failure for the current cookie; rotate at the threshold."""
        self.failure_count += 1
        logger.warning(
            f"Cookie index {self.current_index} failed. "
            f"Count: {self.failure_count}/{self.max_failures}"
        )

        if self.failure_count >= self.max_failures:
            self._rotate()

    def force_rotate(self):
        """Force switch to the next cookie (e.g. when blocked)."""
        logger.warning(f"Cookie index {self.current_index} blocked/rate-limited. Forcing rotation.")
        self._rotate()

    def report_success(self):
        """Reset failure count on success."""
        self.failure_count = 0

    def _rotate(self):
        """Switch to the next cookie and clear the failure count.

        Does nothing when no cookies are loaded.
        """
        if not self.cookies:
            return

        prev_index = self.current_index
        self.current_index = (prev_index + 1) % len(self.cookies)
        self.failure_count = 0
        logger.info(f"Rotating cookie: {prev_index} -> {self.current_index}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DoubaoTTS:
    """Client for the Doubao web text-to-speech service.

    Fetches the voice catalogue over HTTPS and streams synthesized AAC audio
    over a WebSocket. Every request is authenticated with the cookie currently
    selected by ``cookie_manager``.

    Attributes:
        ws_url: WebSocket endpoint used for synthesis.
        voice_url: HTTPS endpoint used to list recommended voices.
        voices: Voice entries in OpenAI "model" shape (list of dicts).
        req_count: Counter initialised to 0; not updated inside this class.
        cookie_manager: Cookie pool used for authentication and rotation.
    """

    # Browser User-Agent sent with both the HTTPS and the WebSocket requests.
    _USER_AGENT = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0"
    )

    # Server error codes handled as "this cookie is blocked / rate-limited".
    _BLOCKED_CODES = frozenset({710022002, 671000003})

    def __init__(self):
        """Set the service endpoints and load cookies from ``COOKIE_DIR``."""
        self.ws_url = "wss://ws-samantha.doubao.com/samantha/audio/tts"
        self.voice_url = "https://www.doubao.com/alice/user_voice/recommend"
        self.voices = []
        self.req_count = 0
        self.cookie_manager = CookieManager(COOKIE_DIR)

    def _get_device_id(self, cookie: str) -> str:
        """Generate a deterministic device ID based on the cookie.

        Args:
            cookie: Cookie string the ID is derived from.

        Returns:
            A 19-character decimal string: the first 15 hex digits of the
            cookie's MD5 digest read as an integer, left-padded with zeros.
        """
        # MD5 is used only to get a stable pseudo-random number, not for security.
        hex_dig = hashlib.md5(cookie.encode()).hexdigest()
        num_val = int(hex_dig[:15], 16)
        return str(num_val).zfill(19)[:19]

    def _get_common_params(self, cookie: str) -> str:
        """Generate the common query parameters for requests.

        Args:
            cookie: Cookie string used to derive the device identifiers.

        Returns:
            A query-string fragment that starts with ``&`` so it can be
            appended directly after other parameters.
        """
        cid = self._get_device_id(cookie)
        return (
            "&mode=0&language=zh&browser_language=zh-CN&device_platform=web"
            "&aid=586861&real_aid=586861&pkg_type=release_version"
            f"&device_id={cid}&tea_uuid={cid}&web_id={cid}"
            # Was "®ion=CN": "&reg" had been mangled into the (R) sign, which
            # dropped the `region` parameter and corrupted `is_new_user`.
            "&is_new_user=0&region=CN&sys_region=CN"
            "&use-olympus-account=1&samantha_web=1"
            "&version=1.20.1&version_code=20800&pc_version=2.47.2"
        )

    @staticmethod
    def _build_voice_entry(v: Dict[str, Any], vid: Any) -> Dict[str, Any]:
        """Convert one upstream voice record into an OpenAI-style model dict."""
        tags = "|".join(
            str(t.get("tag_value") or "") for t in (v.get("tag_list") or [])
        )
        return {
            "id": vid,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "doubao",
            "name": f"{v.get('name')} {tags}".strip(),
            "language": v.get("language_code"),
            # `icon` may be missing or null; neither should drop the voice.
            "icon": (v.get("icon") or {}).get("url"),
        }

    async def fetch_voices(self):
        """Fetch available voices from Doubao and save to models.json.

        Queries each voice tab, de-duplicates voices by ``style_id`` and, if
        at least one voice was obtained, replaces ``self.voices`` and rewrites
        ``MODELS_FILE``. When nothing could be fetched the current voice list
        and the cache file are left untouched.
        """
        cookie = self.cookie_manager.get_cookie()
        if not cookie:
            logger.error("Cannot fetch voices: No cookies available.")
            return

        headers = {"Cookie": cookie, "User-Agent": self._USER_AGENT}

        # (recommend_type, tab_key) pairs to query.
        tabs = [
            (10, "female"),
            (10, "male"),
            (10, "characters"),
            (10, "accent"),
        ]

        # The URL is the same for every tab; only the POST payload differs.
        url = f"{self.voice_url}?{self._get_common_params(cookie).strip('&')}"

        all_voices_map = {}

        async with aiohttp.ClientSession() as session:
            for rec_type, tab_key in tabs:
                payload = {
                    "page_index": 1,
                    "page_size": 200,
                    "recommend_type": rec_type,
                    "tab_key": tab_key,
                }

                try:
                    async with session.post(url, json=payload, headers=headers) as resp:
                        if resp.status != 200:
                            logger.warning(f"HTTP {resp.status} for tab '{tab_key}'")
                            continue
                        data = await resp.json()

                    if data.get("code") != 0:
                        logger.warning(f"API Error for tab '{tab_key}': {data.get('msg')}")
                        continue

                    v_list = (data.get("data") or {}).get("ugc_voice_list") or []
                    for v in v_list:
                        vid = v.get("style_id")
                        if vid and vid not in all_voices_map:
                            all_voices_map[vid] = self._build_voice_entry(v, vid)
                    logger.info(f"Fetched {len(v_list)} voices for tab '{tab_key}'")
                except Exception as e:
                    # Best effort: a failing tab must not abort the others.
                    logger.error(f"Exception fetching voices for tab '{tab_key}': {e}")

        if not all_voices_map:
            # Do not wipe a previously loaded catalogue (or the cache file)
            # just because this refresh failed.
            logger.warning("No voices fetched; keeping the existing voice list.")
            return

        self.voices = list(all_voices_map.values())

        try:
            with open(MODELS_FILE, "w", encoding="utf-8") as f:
                json.dump(self.voices, f, ensure_ascii=False, indent=2)
            logger.info(f"Successfully saved {len(self.voices)} voices to {MODELS_FILE}")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to save models.json: {e}")

    def _purge_saved_audio(self) -> None:
        """Delete every regular file in ``AUDIO_DIR`` (best effort).

        NOTE(review): this also removes a file that a concurrent
        ``stream_audio`` call may still be writing.
        """
        try:
            for entry in os.listdir(AUDIO_DIR):
                entry_path = os.path.join(AUDIO_DIR, entry)
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
        except OSError as e:
            logger.warning(f"Failed to clean up old audio files: {e}")

    async def stream_audio(
        self,
        text: str,
        voice: str,
        speed: float = 1.0,
        pitch: float = 1.0,
    ) -> AsyncGenerator[bytes, None]:
        """Connect to WebSocket and yield audio chunks with retry logic.

        Each attempt opens a WebSocket with the current cookie, sends the text
        and yields every binary frame received. The frames are also written to
        a file in ``AUDIO_DIR``. Attempts that fail before any audio was
        produced are retried (up to three attempts in total); once audio has
        been yielded no retry is made, because restarting would send the
        beginning of the audio to the caller a second time.

        Args:
            text: Text to synthesize.
            voice: Upstream speaker identifier.
            speed: Speed multiplier; mapped to ``(speed - 1) * 100`` and
                clamped to [-100, 100].
            pitch: Pitch value; truncated to an integer and passed through.

        Yields:
            Raw audio bytes as received from the service. A single ``b""`` is
            yielded when no cookie is available.
        """
        # Local import: keeps this block self-contained.
        from urllib.parse import quote

        try:
            os.makedirs(AUDIO_DIR, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create audio directory: {e}")

        doubao_rate = max(-100, min(100, int((speed - 1) * 100)))
        doubao_pitch = int(pitch)

        # `voice` comes from the caller: escape it before it is placed in the
        # URL and in a file name. Plain alphanumeric ids are unchanged.
        safe_voice = quote(str(voice), safe="")

        max_retries = 3

        for attempt in range(max_retries):
            cookie = self.cookie_manager.get_cookie()
            if not cookie:
                logger.error("No cookies available for streaming.")
                yield b""
                return

            params = self._get_common_params(cookie)
            ws_url = (
                f"{self.ws_url}?format=aac&speaker={safe_voice}"
                f"&speech_rate={doubao_rate}&pitch={doubao_pitch}{params}"
            )
            headers = {
                "Cookie": cookie,
                "Origin": "chrome-extension://capohkkfagimodmlpnahjoijgoocdjhd",
                "User-Agent": self._USER_AGENT,
            }

            audio_sent = False      # any audio already handed to the caller?
            cookie_rotated = False  # did this attempt already force a rotation?

            try:
                async with websockets.connect(ws_url, extra_headers=headers) as ws:
                    msg = {
                        "event": "text",
                        "podcast_extra": {"role": ""},
                        "text": text,
                    }
                    await ws.send(json.dumps(msg))
                    await ws.send(json.dumps({"event": "finish"}))

                    # Only the most recent audio is kept on disk.
                    self._purge_saved_audio()

                    timestamp = int(time.time() * 1000)
                    filepath = os.path.join(AUDIO_DIR, f"{timestamp}_{safe_voice}.aac")

                    file_written = False
                    try:
                        async with aiofiles.open(filepath, "wb") as f_out:
                            async for message in ws:
                                if isinstance(message, bytes):
                                    if not audio_sent:
                                        audio_sent = True
                                        self.cookie_manager.report_success()
                                        logger.info(f"Streaming and saving audio to: {filepath}")

                                    yield message
                                    await f_out.write(message)
                                    file_written = True

                                elif isinstance(message, str):
                                    # Text frames carry JSON status messages.
                                    try:
                                        msg_json = json.loads(message)
                                    except json.JSONDecodeError:
                                        continue
                                    if not isinstance(msg_json, dict):
                                        continue

                                    code = msg_json.get("code")
                                    if not code:
                                        continue

                                    error_msg = msg_json.get("message", "Unknown Error")
                                    logger.error(f"Doubao API Error (Code {code}): {error_msg}")

                                    if "block" in str(error_msg).lower() or code in self._BLOCKED_CODES:
                                        self.cookie_manager.force_rotate()
                                        cookie_rotated = True
                                        raise RuntimeError(
                                            f"Blocked/Limited by server ({code}): {error_msg}"
                                        )

                        if file_written:
                            return

                        logger.warning("Connection closed without receiving audio data.")
                    finally:
                        # Do not leave a zero-byte file behind.
                        if not file_written and os.path.exists(filepath):
                            try:
                                os.remove(filepath)
                                logger.info(f"Removed empty audio file: {filepath}")
                            except OSError as e:
                                logger.error(f"Failed to remove empty file: {e}")

            except websockets.exceptions.InvalidStatusCode as e:
                logger.warning(f"WebSocket Handshake failed (Attempt {attempt+1}): {e.status_code}")
                self.cookie_manager.report_failure()

            except Exception as e:
                logger.error(f"WebSocket error (Attempt {attempt+1}): {e}")
                # After a forced rotation the current cookie is a fresh one;
                # do not charge this failure to it.
                if not cookie_rotated:
                    self.cookie_manager.report_failure()

            if audio_sent:
                logger.error(
                    "Stream interrupted after audio was sent; "
                    "not retrying to avoid duplicated audio."
                )
                return

            # No point in waiting after the final attempt.
            if attempt < max_retries - 1:
                await asyncio.sleep(3)

        logger.error("All retry attempts failed.")
        return
|
|
|
|
|
|
|
|
|
# Shared TTS client used by every route. Constructing it builds a
# CookieManager, which reads the cookie files from COOKIE_DIR at import time.
engine = DoubaoTTS()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIRequest(BaseModel):
    """JSON body accepted by ``POST /v1/audio/speech``.

    Only ``input`` and ``voice`` are mandatory; every other field has a
    default and may be omitted.
    """

    # Accepted for OpenAI client compatibility.
    model: Optional[str] = "tts-1"

    # Text to synthesize.
    input: str

    # Speaker identifier forwarded to the TTS engine.
    voice: str

    # "mp3" selects the audio/mpeg content type; anything else is audio/aac.
    response_format: Optional[str] = "aac"

    # Playback speed multiplier (1.0 = normal).
    speed: Optional[float] = 1.0

    # Pitch value forwarded to the TTS engine.
    pitch: Optional[float] = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Report that the service is up and how many cookies are loaded."""
    loaded = len(engine.cookie_manager.cookies)
    status_report = {
        "status": "running",
        "service": "Doubao TTS OpenAI Server",
        "cookies_loaded": loaded,
    }
    return status_report
|
|
|
|
|
|
@app.get("/v1/models")
async def get_models():
    """Return the known voices as an OpenAI-style model list."""
    listing = {"object": "list"}
    listing["data"] = engine.voices
    return listing
|
|
|
|
|
|
@app.get("/v1/audio/speech")
async def check_speech_endpoint():
    """Answer GET probes on the speech route with a fixed readiness payload.

    No audio is produced here; synthesis is done by the POST handler on the
    same path.
    """
    readiness = dict(status="ok", message="Speech endpoint is ready")
    return readiness
|
|
|
|
|
|
@app.get("/v1/audio/speech/stream")
async def stream_speech_get(
    input: str,
    voice: str,
    speed: float = 1.0,
    pitch: float = 0.0,
    response_format: str = "aac",
    token: Optional[str] = None,
):
    """GET endpoint for direct streaming (e.g. for <audio src="...">).

    Args:
        input: Text to synthesize; must not be empty.
        voice: Speaker identifier forwarded to the TTS engine.
        speed: Speed multiplier (1.0 = normal).
        pitch: Pitch value forwarded to the TTS engine.
        response_format: "mp3" selects the audio/mpeg content type; any other
            value yields audio/aac.
        token: Shared secret; must equal ``AUTH_PASSWORD`` when one is set.

    Returns:
        A ``StreamingResponse`` wrapping ``engine.stream_audio(...)``.

    Raises:
        HTTPException: 401 for a missing/wrong token, 400 for empty input.
    """
    # Local import: keeps this block self-contained.
    import hmac

    if AUTH_PASSWORD:
        # Constant-time comparison so response timing does not leak how much
        # of the secret matched. Bytes are compared so that non-ASCII input
        # is handled too.
        supplied = (token or "").encode("utf-8")
        if not hmac.compare_digest(supplied, AUTH_PASSWORD.encode("utf-8")):
            raise HTTPException(status_code=401, detail="Invalid authentication token")

    if not input:
        raise HTTPException(status_code=400, detail="Input text is required")

    # NOTE(review): engine.stream_audio always requests format=aac upstream,
    # so "mp3" only changes the declared Content-Type, not the payload.
    media_type = "audio/mpeg" if response_format == "mp3" else "audio/aac"

    return StreamingResponse(
        engine.stream_audio(input, voice, speed, pitch),
        media_type=media_type,
    )
|
|
|
|
|
|
@app.post("/v1/audio/speech")
async def create_speech(req: OpenAIRequest):
    """OpenAI-compatible speech generation endpoint.

    Rejects an empty ``input`` with HTTP 400; otherwise streams the audio
    produced by ``engine.stream_audio`` for the requested voice, speed and
    pitch.

    NOTE(review): unlike the GET streaming route, this handler performs no
    authentication check.
    """
    text = req.input
    if not text:
        raise HTTPException(status_code=400, detail="Input text is required")

    # "mp3" switches the declared content type; everything else is AAC.
    wants_mp3 = req.response_format == "mp3"
    content_type = "audio/mpeg" if wants_mp3 else "audio/aac"

    audio_stream = engine.stream_audio(text, req.voice, req.speed, req.pitch)
    return StreamingResponse(audio_stream, media_type=content_type)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print two startup hints, then serve until stopped.
    startup_notes = (
        f"Starting server on port {PORT}...",
        f"Please place cookie files in '{COOKIE_DIR}/' directory (e.g. 1.txt, 2.txt)",
    )
    for note in startup_notes:
        print(note)

    uvicorn.run(app, host=HOST, port=PORT)