| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import _thread as thread |
| import base64 |
| import datetime |
| import hashlib |
| import hmac |
| import json |
| import queue |
| import re |
| import ssl |
| import time |
| from abc import ABC |
| from datetime import datetime |
| from time import mktime |
| from typing import Annotated, Literal |
| from urllib.parse import urlencode |
| from wsgiref.handlers import format_date_time |
|
|
| import httpx |
| import ormsgpack |
| import requests |
| import websocket |
| from pydantic import BaseModel, conint |
|
|
| from rag.utils import num_tokens_from_string |
|
|
|
|
class ServeReferenceAudio(BaseModel):
    """A single reference sample: an audio clip together with a text string."""

    # Raw bytes of the reference clip.
    audio: bytes
    # Text paired with the clip — presumably its transcript; TODO confirm
    # against the TTS service's API description.
    text: str
|
|
|
|
class ServeTTSRequest(BaseModel):
    """Request payload for a TTS call.

    Within this file it is built by ``FishAudioTTS.tts`` and serialized with
    ``ormsgpack``. Only ``text`` is required; every other field has a default.
    """

    # Text to be synthesized.
    text: str
    # Strict integer limited to the inclusive range 100..300.
    # NOTE(review): the unit (characters? tokens?) is defined by the service,
    # not by this file — confirm before relying on it.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200

    # Requested audio output format.
    format: Literal["wav", "pcm", "mp3"] = "mp3"
    # Allowed MP3 bitrates — presumably in kbps; TODO confirm.
    mp3_bitrate: Literal[64, 128, 192] = 128

    # Inline reference samples sent along with the request (none by default).
    references: list[ServeReferenceAudio] = []

    # Optional identifier of a reference stored on the service side;
    # FishAudioTTS fills it from the `fish_audio_refid` entry of its key.
    reference_id: str | None = None

    # Normalization switch; what is normalized is decided by the service
    # (not visible in this file).
    normalize: bool = True

    # Latency mode selector understood by the service.
    latency: Literal["normal", "balanced"] = "normal"
|
|
|
|
class Base(ABC):
    """Shared parent of the text-to-speech backends defined in this module."""

    def __init__(self, key, model_name, base_url):
        """Accept the common constructor arguments; nothing is stored here."""

    def tts(self, audio):
        """Placeholder synthesis hook that subclasses override; returns None."""
        return None

    def normalize_text(self, text):
        """Return *text* with markup markers removed.

        Every ``**``, every ``##<digits>$$`` marker and every remaining ``#``
        character is deleted; all other characters are kept unchanged.
        """
        marker_pattern = re.compile(r'(\*\*|##\d+\$\$|#)')
        return marker_pattern.sub('', text)
|
|
|
|
class FishAudioTTS(Base):
    """TTS backend that streams audio from the Fish Audio HTTP endpoint."""

    def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
        """Parse credentials and prepare the request headers.

        Args:
            key: JSON string; the entries ``fish_audio_ak`` (API key) and
                ``fish_audio_refid`` (reference id) are read from it.
            model_name: Accepted for interface compatibility; not used here.
            base_url: Endpoint URL; an empty/falsy value selects the default.
        """
        if not base_url:
            base_url = "https://api.fish.audio/v1/tts"
        key = json.loads(key)
        self.headers = {
            "api-key": key.get("fish_audio_ak"),
            "content-type": "application/msgpack",
        }
        self.ref_id = key.get("fish_audio_refid")
        self.base_url = base_url

    def tts(self, text):
        """Synthesize *text* and stream the result.

        Yields:
            ``bytes`` chunks of the response body, followed by one final
            ``int``: the token count of the normalized text.

        Raises:
            RuntimeError: when the endpoint answers with an error status
                (the original ``httpx.HTTPStatusError`` is chained as cause).
        """
        from http import HTTPStatus

        text = self.normalize_text(text)
        request = ServeTTSRequest(text=text, reference_id=self.ref_id)

        with httpx.Client() as client:
            try:
                with client.stream(
                    method="POST",
                    url=self.base_url,
                    content=ormsgpack.packb(
                        request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
                    ),
                    headers=self.headers,
                    # No client-side timeout: synthesis duration is unbounded.
                    timeout=None,
                ) as response:
                    if response.status_code == HTTPStatus.OK:
                        for chunk in response.iter_bytes():
                            yield chunk
                    else:
                        response.raise_for_status()

                # Trailing usage figure, emitted after all audio chunks.
                yield num_tokens_from_string(text)

            except httpx.HTTPStatusError as e:
                # Chain explicitly so the HTTP error stays attached as __cause__.
                raise RuntimeError(f"**ERROR**: {e}") from e
|
|
|
|
class QwenTTS(Base):
    """TTS backend built on DashScope's ``SpeechSynthesizer``."""

    def __init__(self, key, model_name, base_url=""):
        """Store the model name and register *key* with the DashScope SDK.

        ``base_url`` is accepted for interface compatibility and ignored.
        """
        import dashscope

        self.model_name = model_name
        # The SDK reads its API key from this module-level attribute.
        dashscope.api_key = key

    def tts(self, text):
        """Synthesize *text* as MP3 and stream the audio frames.

        Yields:
            Non-empty audio frames, followed by one final ``int``: the token
            count of the normalized text.

        Raises:
            RuntimeError: on any failure while draining the frames, including
                an error reported through the SDK's ``on_error`` callback.
        """
        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult

        class Callback(ResultCallback):
            """Collects audio frames from the SDK; ``None`` marks end of stream."""

            def __init__(self) -> None:
                self.frames = queue.Queue()
                self.error = None

            def _run(self):
                """Yield queued frames until the end-of-stream marker arrives."""
                while True:
                    # Blocking get instead of spinning on an empty container.
                    frame = self.frames.get()
                    if frame is None:
                        break
                    # An empty frame carries no audio but is not end-of-stream.
                    if frame:
                        yield frame

            def on_open(self):
                pass

            def on_complete(self):
                self.frames.put(None)

            def on_error(self, response: SpeechSynthesisResponse):
                # Record the failure and terminate the stream so the consumer
                # cannot wait forever if the SDK swallows this exception.
                self.error = RuntimeError(str(response))
                self.frames.put(None)
                raise self.error

            def on_close(self):
                pass

            def on_event(self, result: SpeechSynthesisResult):
                frame = result.get_audio_frame()
                if frame is not None:
                    self.frames.put(frame)

        text = self.normalize_text(text)
        callback = Callback()
        SpeechSynthesizer.call(model=self.model_name,
                               text=text,
                               callback=callback,
                               format="mp3")
        try:
            for data in callback._run():
                yield data
            if callback.error is not None:
                raise callback.error
            yield num_tokens_from_string(text)

        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}") from e
|
|
|
|
class OpenAITTS(Base):
    """TTS backend for an OpenAI-compatible ``/audio/speech`` endpoint."""

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
        """Store credentials and prepare the request headers.

        An empty/falsy ``base_url`` selects the default OpenAI URL.
        """
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def tts(self, text, voice="alloy"):
        """Synthesize *text* with the given *voice* and stream the audio.

        Yields:
            Non-empty ``bytes`` chunks of the response body.

        Raises:
            Exception: when the endpoint does not answer with status 200.
        """
        text = self.normalize_text(text)
        payload = {
            "model": self.model_name,
            "voice": voice,
            "input": text
        }

        # The context manager releases the streamed connection even when the
        # caller stops iterating early or an error is raised.
        with requests.post(f"{self.base_url}/audio/speech", headers=self.headers,
                           json=payload, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"**Error**: {response.status_code}, {response.text}")
            # iter_content() defaults to 1-byte chunks; request sensible blocks.
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
|
|
|
|
class SparkTTS:
    """TTS backend that talks to the iFlytek (xfyun) TTS WebSocket endpoint."""

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Parse credentials from *key*.

        Args:
            key: JSON string; ``spark_app_id``, ``spark_api_secret`` and
                ``spark_api_key`` are read from it (placeholder strings are
                used for missing entries).
            model_name: Sent to the service as the ``vcn`` business argument.
            base_url: Accepted for interface compatibility; not used.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Return the WebSocket URL carrying the HMAC-SHA256 signed query string."""
        url = 'wss://tts-api.xfyun.cn/v2/tts'
        # NOTE(review): the signed host differs from the host in `url`; this
        # is kept exactly as the code had it — confirm against the service's
        # authentication documentation before changing either.
        host = "ws-api.xfyun.cn"
        date = format_date_time(mktime(datetime.now().timetuple()))

        signature_origin = f"host: {host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
        digest = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                          digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": host,
        }
        return url + '?' + urlencode(query)

    def tts(self, text):
        """Synthesize *text* and yield the decoded audio chunks.

        The WebSocket session runs to completion first; the chunks collected
        during the session are yielded afterwards.

        Raises:
            Exception: when the session ends without producing any audio
                (reported as a credentials failure), or when a callback error
                propagates out of the WebSocket client.
        """
        business_args = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000",
                         "vcn": self.model_name, "tte": "utf8"}
        data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
        common_args = {"app_id": self.APPID}
        model_name = self.model_name

        # Fresh queue per call: leftovers from an earlier call (extra end
        # markers, unread chunks) must not leak into this one.
        audio_queue = self.audio_queue = queue.Queue()
        errors = []

        class Callback:
            """WebSocket event handlers feeding `audio_queue`; None ends the stream."""

            def on_message(self, ws, message):
                message = json.loads(message)
                code = message["code"]
                sid = message["sid"]
                if code != 0:
                    # Check the status code before touching `data`, so a
                    # missing payload cannot mask the service's own error.
                    errMsg = message["message"]
                    ws.close()
                    raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
                payload = message["data"]
                status = payload["status"]
                # Queue the chunk before closing so it always precedes the
                # end marker written by on_close.
                audio_queue.put(base64.b64decode(payload["audio"]))
                if status == 2:
                    ws.close()

            def on_error(self, ws, error):
                errors.append(error)
                raise Exception(error)

            def on_close(self, ws, close_status_code, close_msg):
                audio_queue.put(None)

            def on_open(self, ws):
                def run(*args):
                    d = {"common": common_args,
                         "business": business_args,
                         "data": data}
                    ws.send(json.dumps(d))

                thread.start_new_thread(run, ())

        wsUrl = self.create_url()
        websocket.enableTrace(False)
        handlers = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=handlers.on_open, on_error=handlers.on_error,
                                    on_close=handlers.on_close, on_message=handlers.on_message)
        # NOTE(review): certificate verification is disabled here (CERT_NONE),
        # which exposes the credentials and audio to interception. Left
        # unchanged to avoid breaking existing deployments — consider enabling
        # verification.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        # Guarantee termination even if on_close never ran.
        audio_queue.put(None)

        received_audio = False
        for audio_chunk in iter(audio_queue.get, None):
            received_audio = True
            yield audio_chunk

        if not received_audio:
            # Attach the first recorded transport error, if any, as the cause.
            cause = next((e for e in errors if isinstance(e, BaseException)), None)
            raise Exception(
                f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.") from cause
|
|