maltose1 commited on
Commit
3a009bb
·
verified ·
1 Parent(s): d13bc24

Upload 4 files

Browse files
Files changed (1) hide show
  1. tts-server.py +10 -3
tts-server.py CHANGED
@@ -244,7 +244,7 @@ class DoubaoTTS:
244
  except Exception as e:
245
  logger.error(f"Failed to save models.json: {e}")
246
 
247
- async def stream_audio(self, text: str, voice: str, speed: float = 1.0, pitch: float = 1.0) -> AsyncGenerator[bytes, None]:
248
  """Connect to WebSocket and yield audio chunks with retry logic."""
249
 
250
  # Map OpenAI speed (0.25 - 4.0) to Doubao rate (-100 to 100)
@@ -262,7 +262,7 @@ class DoubaoTTS:
262
  return
263
 
264
  params = self._get_common_params()
265
- ws_url = f"{self.ws_url}?format=aac&speaker={voice}&speech_rate={doubao_rate}&pitch={doubao_pitch}{params}"
266
 
267
  headers = {
268
  "Cookie": cookie,
@@ -359,8 +359,15 @@ async def create_speech(req: OpenAIRequest, token: str = Depends(verify_token)):
359
  if req.response_format == "mp3":
360
  media_type = "audio/mpeg"
361
 
 
 
 
 
 
 
 
362
  return StreamingResponse(
363
- engine.stream_audio(req.input, req.voice, req.speed, req.pitch),
364
  media_type=media_type
365
  )
366
 
 
244
  except Exception as e:
245
  logger.error(f"Failed to save models.json: {e}")
246
 
247
+ async def stream_audio(self, text: str, voice: str, format: str = "aac", speed: float = 1.0, pitch: float = 1.0) -> AsyncGenerator[bytes, None]:
248
  """Connect to WebSocket and yield audio chunks with retry logic."""
249
 
250
  # Map OpenAI speed (0.25 - 4.0) to Doubao rate (-100 to 100)
 
262
  return
263
 
264
  params = self._get_common_params()
265
+ ws_url = f"{self.ws_url}?format={format}&speaker={voice}&speech_rate={doubao_rate}&pitch={doubao_pitch}{params}"
266
 
267
  headers = {
268
  "Cookie": cookie,
 
359
  if req.response_format == "mp3":
360
  media_type = "audio/mpeg"
361
 
362
+ # Determine format to request from Doubao
363
+ target_format = "aac"
364
+ if req.response_format == "mp3":
365
+ target_format = "mp3"
366
+ elif req.response_format == "pcm":
367
+ target_format = "pcm"
368
+
369
  return StreamingResponse(
370
+ engine.stream_audio(req.input, req.voice, target_format, req.speed, req.pitch),
371
  media_type=media_type
372
  )
373