| import asyncio
|
| import json
|
| import random
|
| import math
|
| import os
|
| import logging
|
| import time
|
| import glob
|
| from typing import Optional, List, Dict, Any, AsyncGenerator
|
|
|
| import aiohttp
|
| import websockets
|
| from fastapi import FastAPI, HTTPException, Header, Request, BackgroundTasks, Depends
|
| from fastapi.responses import StreamingResponse, JSONResponse
|
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| from pydantic import BaseModel, Field
|
| import uvicorn
|
| from contextlib import asynccontextmanager
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure the root logger once at import time: INFO and above, with a
# "timestamp - logger name - level - message" layout.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger shared by every component in this module.
logger = logging.getLogger("DoubaoTTS")
|
|
|
|
|
# Address and port handed to uvicorn.run() in the __main__ guard.
# NOTE(review): "0.0.0.0" listens on every interface -- confirm that is intended.
HOST = "0.0.0.0"
PORT = 1547

# Voice-list cache: read on startup by lifespan(), rewritten by fetch_voices().
MODELS_FILE = "models.json"

# Directory that CookieManager scans for *.txt cookie files.
COOKIE_DIR = "cookie"

# Bearer token that verify_token() compares requests against. Taken from the
# PASSWORD environment variable, falling back to a literal default.
# NOTE(review): the fallback is a secret committed to source; consider
# requiring PASSWORD to be set in deployments.
AUTH_PASSWORD = os.environ.get("PASSWORD", "sk-wei123")
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: warm the voice list, then refresh it in the background.

    On startup the cached voice list in ``MODELS_FILE`` (if present and valid)
    is loaded into ``engine.voices`` so ``/v1/models`` has data immediately,
    and ``engine.fetch_voices()`` is scheduled as a background task. On
    shutdown the refresh task is cancelled if it is still running.
    """
    if os.path.exists(MODELS_FILE):
        try:
            with open(MODELS_FILE, "r", encoding="utf-8") as f:
                cached = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Failed to load cached models: {e}")
        else:
            if isinstance(cached, list):
                engine.voices = cached
                logger.info(f"Loaded {len(engine.voices)} voices from cache.")
            else:
                # The cache is written as a JSON array; anything else is corrupt.
                logger.error(
                    f"Failed to load cached models: expected a list, got {type(cached).__name__}"
                )

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so a fire-and-forget task may be garbage-collected before it ends.
    refresh_task = asyncio.create_task(engine.fetch_voices())
    try:
        yield
    finally:
        # cancel() is a no-op on a finished task. gather(return_exceptions=True)
        # waits for the task and absorbs its CancelledError or any other
        # exception so shutdown never fails because of the refresh.
        refresh_task.cancel()
        await asyncio.gather(refresh_task, return_exceptions=True)
|
|
|
# ASGI application object; startup and shutdown work is delegated to lifespan().
app = FastAPI(lifespan=lifespan, title="Doubao TTS OpenAI API Server")

# Extracts the "Authorization: Bearer <token>" credentials for verify_token().
security = HTTPBearer()
|
|
|
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that validates the bearer token against AUTH_PASSWORD.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The supplied token string when it matches.

    Raises:
        HTTPException: 401 when the token does not match.
    """
    import hmac  # local import: the module's import block does not provide it

    # Compare as bytes in constant time so the response latency does not leak
    # how many leading characters of the token were correct. Encoding first
    # also lets non-ASCII tokens be compared without a TypeError.
    supplied = credentials.credentials.encode("utf-8")
    expected = AUTH_PASSWORD.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
|
|
|
|
|
|
|
|
|
|
|
class CookieManager:
    """Pool of cookie strings loaded from ``*.txt`` files, with failure-based rotation.

    One cookie is "current" at a time. Callers report the outcome of each use;
    after ``max_failures`` consecutive failures the manager advances to the
    next cookie, wrapping around at the end of the list.
    """

    def __init__(self, cookie_dir: str, max_failures: int = 2):
        """Load cookies from *cookie_dir*.

        Args:
            cookie_dir: Directory scanned for ``*.txt`` files; created if missing.
            max_failures: Consecutive failures tolerated before rotating to the
                next cookie. Values below 1 are treated as 1.
        """
        self.cookie_dir = cookie_dir
        self.max_failures = max(1, max_failures)
        self.cookies: List[str] = []
        self.current_index = 0
        self.failure_count = 0
        self.load_cookies()

    def load_cookies(self):
        """Load all .txt files from the cookie directory.

        Replaces ``self.cookies`` with the stripped, non-empty contents of each
        file, in sorted filename order. A missing directory is created and
        leaves the pool empty; unreadable files are logged and skipped.
        """
        self.cookies = []
        if not os.path.exists(self.cookie_dir):
            try:
                os.makedirs(self.cookie_dir, exist_ok=True)
                logger.info(f"Created cookie directory: {self.cookie_dir}")
            except OSError as e:
                logger.error(f"Failed to create cookie directory: {e}")
            return

        # Plain string sort, so "10.txt" orders before "2.txt".
        for f_path in sorted(glob.glob(os.path.join(self.cookie_dir, "*.txt"))):
            try:
                with open(f_path, "r", encoding="utf-8") as f:
                    content = f.read().strip()
            except (OSError, UnicodeDecodeError) as e:
                logger.error(f"Error reading cookie file {f_path}: {e}")
                continue
            if content:
                self.cookies.append(content)

        logger.info(f"Loaded {len(self.cookies)} cookies from {self.cookie_dir}")

    def get_cookie(self) -> Optional[str]:
        """Get the current active cookie.

        Returns:
            The current cookie string, or ``None`` when no cookie is available
            even after re-scanning the directory.
        """
        if not self.cookies:
            # The pool is empty: re-scan so files added after startup are seen.
            self.load_cookies()
            if not self.cookies:
                return None

        # The index can be stale if a reload produced a shorter list.
        if self.current_index >= len(self.cookies):
            self.current_index = 0

        return self.cookies[self.current_index]

    def report_failure(self):
        """Record a failure for the current cookie; rotate once the threshold is hit."""
        self.failure_count += 1
        logger.warning(
            f"Cookie index {self.current_index} failed. "
            f"Count: {self.failure_count}/{self.max_failures}"
        )

        if self.failure_count >= self.max_failures:
            self._rotate()

    def report_success(self):
        """Reset failure count on success."""
        self.failure_count = 0

    def _rotate(self):
        """Switch to the next cookie (wrapping around) and reset the failure count."""
        if not self.cookies:
            return

        prev_index = self.current_index
        self.current_index = (self.current_index + 1) % len(self.cookies)
        self.failure_count = 0
        logger.info(f"Rotating cookie: {prev_index} -> {self.current_index}")
|
|
|
|
|
|
|
|
|
|
|
class DoubaoTTS:
    """Client for the Doubao web TTS service.

    Fetches the voice catalogue over HTTP (``fetch_voices``) and streams
    synthesized audio over a WebSocket (``stream_audio``). Cookies for both
    come from a ``CookieManager`` reading ``COOKIE_DIR``.
    """

    # Browser User-Agent sent with both the HTTP and the WebSocket requests.
    _USER_AGENT = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0"
    )

    def __init__(self):
        self.ws_url = "wss://ws-samantha.doubao.com/samantha/audio/tts"
        self.voice_url = "https://www.doubao.com/alice/user_voice/recommend"
        self.voices = []
        self.req_count = 0
        self.current_id = self._generate_id()
        self.cookie_manager = CookieManager(COOKIE_DIR)

    def _generate_id(self) -> str:
        """Generate a random device/session ID: two 9-digit numbers concatenated."""
        num1 = math.floor(1e8 + 9e8 * random.random())
        num2 = math.floor(1e8 + 9e8 * random.random())
        return f"{num1}{num2}"

    def _get_common_params(self) -> str:
        """Generate the common query parameters for requests.

        Each call increments ``req_count``; once it exceeds 5 the counter is
        reset and a fresh ID is generated, so one ID is reused for a handful
        of requests.

        Returns:
            A query-string fragment that starts with ``&``.
        """
        self.req_count += 1
        if self.req_count > 5:
            self.req_count = 0
            self.current_id = self._generate_id()

        cid = self.current_id
        return (
            "&mode=0&language=zh&browser_language=zh-CN&device_platform=web"
            "&aid=586861&real_aid=586861&pkg_type=release_version"
            f"&device_id={cid}&tea_uuid={cid}&web_id={cid}"
            # "&region" had been corrupted into the registered-sign character
            # (an HTML-entity decoding artefact), which dropped the parameter.
            "&is_new_user=0&region=CN&sys_region=CN"
            "&use-olympus-account=1&samantha_web=1"
            "&version=1.20.1&version_code=20800&pc_version=2.47.2"
        )

    @staticmethod
    def _build_voice_entry(v: dict, created: int) -> dict:
        """Convert one upstream voice record into an OpenAI-style model entry."""
        # Upstream fields may be present but null, so default with `or`.
        tag_list = v.get("tag_list") or []
        tags = "|".join(str(t.get("tag_value") or "") for t in tag_list)
        return {
            "id": v.get("style_id"),
            "object": "model",
            "created": created,
            "owned_by": "doubao",
            "name": f"{v.get('name')} {tags}".strip(),
            "language": v.get("language_code"),
            "icon": (v.get("icon") or {}).get("url"),
        }

    async def fetch_voices(self):
        """Fetch available voices from Doubao and save to models.json.

        Queries each voice tab, de-duplicates by ``style_id`` and replaces
        ``self.voices``. If nothing could be fetched, the existing voice list
        and the cache file are left untouched.
        """
        cookie = self.cookie_manager.get_cookie()
        if not cookie:
            logger.error("Cannot fetch voices: No cookies available.")
            return

        headers = {"Cookie": cookie, "User-Agent": self._USER_AGENT}

        # (recommend_type, tab_key) pairs sent in the request payload.
        tabs = [
            (10, "female"),
            (10, "male"),
            (10, "characters"),
            (10, "accent"),
        ]

        all_voices_map = {}
        created = int(time.time())

        async with aiohttp.ClientSession() as session:
            for rec_type, tab_key in tabs:
                params = self._get_common_params().strip("&")
                url = f"{self.voice_url}?{params}"
                payload = {
                    "page_index": 1,
                    "page_size": 200,
                    "recommend_type": rec_type,
                    "tab_key": tab_key,
                }

                # Best effort per tab: one failing tab must not abort the others.
                try:
                    async with session.post(url, json=payload, headers=headers) as resp:
                        if resp.status != 200:
                            logger.warning(f"HTTP {resp.status} for tab '{tab_key}'")
                            continue
                        data = await resp.json()
                        if data.get("code") != 0:
                            logger.warning(f"API Error for tab '{tab_key}': {data.get('msg')}")
                            continue
                        v_list = (data.get("data") or {}).get("ugc_voice_list") or []
                        for v in v_list:
                            vid = v.get("style_id")
                            if vid and vid not in all_voices_map:
                                all_voices_map[vid] = self._build_voice_entry(v, created)
                        logger.info(f"Fetched {len(v_list)} voices for tab '{tab_key}'")
                except Exception as e:
                    logger.error(f"Exception fetching voices for tab '{tab_key}': {e}")

        if not all_voices_map:
            # Every tab failed or was empty: keep the previously loaded voices
            # instead of wiping them and the on-disk cache.
            logger.warning("No voices fetched; keeping the existing voice list.")
            return

        self.voices = list(all_voices_map.values())

        try:
            with open(MODELS_FILE, "w", encoding="utf-8") as f:
                json.dump(self.voices, f, ensure_ascii=False, indent=2)
            logger.info(f"Successfully saved {len(self.voices)} voices to {MODELS_FILE}")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to save models.json: {e}")

    async def stream_audio(self, text: str, voice: str, speed: float = 1.0, pitch: float = 1.0) -> AsyncGenerator[bytes, None]:
        """Connect to WebSocket and yield audio chunks with retry logic.

        Args:
            text: Text to synthesize.
            voice: Upstream speaker ID, placed in the ``speaker`` query parameter.
            speed: Speed multiplier; mapped to ``(speed - 1) * 100`` and clamped
                to [-100, 100].
            pitch: Truncated to an int and sent as the ``pitch`` query parameter.
                NOTE(review): this default (1.0) differs from the default of
                ``OpenAIRequest.pitch`` (0.0); confirm which one is intended.

        Yields:
            Binary audio frames as received. When no cookie is available a
            single empty chunk is yielded and the stream ends.

        Raises:
            HTTPException: 500 when every attempt fails before any audio
                was produced.
        """
        from urllib.parse import quote  # local import: not in the module's import block

        doubao_rate = max(-100, min(100, int((speed - 1) * 100)))
        doubao_pitch = int(pitch)
        # Percent-encode the caller-supplied voice so it cannot inject or
        # corrupt query parameters.
        speaker = quote(str(voice), safe="")

        max_retries = 3

        for attempt in range(max_retries):
            cookie = self.cookie_manager.get_cookie()
            if not cookie:
                logger.error("No cookies available for streaming.")
                yield b""
                return

            params = self._get_common_params()
            ws_url = (
                f"{self.ws_url}?format=aac&speaker={speaker}"
                f"&speech_rate={doubao_rate}&pitch={doubao_pitch}{params}"
            )
            headers = {
                "Cookie": cookie,
                "Origin": "chrome-extension://capohkkfagimodmlpnahjoijgoocdjhd",
                "User-Agent": self._USER_AGENT,
            }

            audio_started = False
            try:
                # NOTE(review): `extra_headers` and `InvalidStatusCode` look like
                # the legacy websockets client API; verify against the installed
                # websockets version.
                async with websockets.connect(ws_url, extra_headers=headers) as ws:
                    await ws.send(json.dumps({
                        "event": "text",
                        "podcast_extra": {"role": ""},
                        "text": text,
                    }))
                    await ws.send(json.dumps({"event": "finish"}))

                    async for message in ws:
                        # Only binary frames carry audio; text frames are ignored.
                        if isinstance(message, bytes):
                            if not audio_started:
                                audio_started = True
                                self.cookie_manager.report_success()
                            yield message
                    return

            except websockets.exceptions.InvalidStatusCode as e:
                logger.warning(f"WebSocket Handshake failed (Attempt {attempt+1}): {e.status_code}")
                self.cookie_manager.report_failure()

            except Exception as e:
                logger.error(f"WebSocket error (Attempt {attempt+1}): {e}")
                if audio_started:
                    # Part of the audio already reached the consumer. A retry
                    # would restart synthesis from the beginning and append a
                    # duplicate, so end the stream here instead.
                    logger.error("Stream interrupted after audio started; not retrying.")
                    return
                self.cookie_manager.report_failure()

        logger.error("All retry attempts failed.")
        raise HTTPException(status_code=500, detail="Failed to generate audio after retries.")
|
|
|
|
|
# Single shared TTS client used by lifespan() and the route handlers below.
# Constructing it also builds its CookieManager, which scans COOKIE_DIR.
engine: DoubaoTTS = DoubaoTTS()
|
|
|
|
|
|
|
|
|
|
|
# Request body for POST /v1/audio/speech.
class OpenAIRequest(BaseModel):
    # Model name; create_speech() does not read it.
    model: Optional[str] = Field(default="tts-1")
    # Text to synthesize (required).
    input: str
    # Voice ID forwarded to the TTS engine (required).
    voice: str
    # Requested audio format; create_speech() uses it only to pick the
    # response media type.
    response_format: Optional[str] = Field(default="aac")
    # Speed multiplier forwarded to the TTS engine.
    speed: Optional[float] = Field(default=1.0)
    # Pitch value forwarded to the TTS engine.
    pitch: Optional[float] = Field(default=0.0)
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Report that the service is running and how many cookies are loaded."""
    loaded = len(engine.cookie_manager.cookies)
    return {
        "status": "running",
        "service": "Doubao TTS OpenAI Server",
        "cookies_loaded": loaded,
    }
|
|
|
@app.get("/v1/models")
async def get_models():
    """Return the engine's current voice list in an OpenAI-style list envelope."""
    voices = engine.voices
    return dict(object="list", data=voices)
|
|
|
@app.get("/v1/audio/speech")
async def check_speech_endpoint():
    """Answer GET requests on the speech route so clients can probe connectivity."""
    reply = {"status": "ok", "message": "Speech endpoint is ready"}
    return reply
|
|
|
@app.post("/v1/audio/speech")
async def create_speech(req: OpenAIRequest, token: str = Depends(verify_token)):
    """OpenAI-compatible speech generation endpoint.

    Args:
        req: Parsed request body.
        token: Bearer token validated by ``verify_token`` (unused here).

    Returns:
        A ``StreamingResponse`` carrying the audio produced by the engine.

    Raises:
        HTTPException: 400 when the input text is empty or whitespace-only.
            Errors raised by the engine before the first audio chunk also
            propagate from here as regular HTTP errors.
    """
    if not req.input or not req.input.strip():
        raise HTTPException(status_code=400, detail="Input text is required")

    # NOTE(review): only the response label changes with response_format;
    # the arguments passed to the engine do not depend on it.
    media_type = "audio/mpeg" if req.response_format == "mp3" else "audio/aac"

    # Both fields are Optional, so an explicit JSON null arrives as None.
    # Fall back to the request model's own defaults.
    speed = 1.0 if req.speed is None else req.speed
    pitch = 0.0 if req.pitch is None else req.pitch

    stream = engine.stream_audio(req.input, req.voice, speed, pitch)

    # Pull the first chunk before building the response. Once a streaming
    # response has started, the status line has already been sent and an
    # exception raised by the generator can no longer become an HTTP error.
    try:
        first_chunk = await stream.__anext__()
    except StopAsyncIteration:
        first_chunk = b""

    async def body():
        try:
            if first_chunk:
                yield first_chunk
            async for chunk in stream:
                yield chunk
        finally:
            # Close the upstream generator even if the client disconnects.
            await stream.aclose()

    return StreamingResponse(body(), media_type=media_type)
|
|
|
if __name__ == "__main__":
    # Print a short startup banner, then serve the app with uvicorn.
    banner = (
        f"Starting server on port {PORT}...",
        f"Please place cookie files in '{COOKIE_DIR}/' directory (e.g. 1.txt, 2.txt)",
    )
    for line in banner:
        print(line)
    uvicorn.run(app, port=PORT, host=HOST)