| |
| |
| import os |
| import json |
| from dotenv import load_dotenv |
| import fal_client |
| import requests |
| import time |
| import io |
| from pyht import Client as PyhtClient |
| from pyht.client import TTSOptions |
| import base64 |
| import tempfile |
| import random |
|
|
| load_dotenv() |
|
|
# Pool of ZeroGPU access tokens, comma-separated in the ZEROGPU_TOKENS env var.
# NOTE: if the variable is unset this evaluates to [""] ("".split(",")).
ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")


def get_zerogpu_token():
    """Return a randomly chosen ZeroGPU token to spread request load.

    Returns:
        A non-empty token string picked uniformly at random.

    Raises:
        RuntimeError: if no non-empty token is configured. The previous
            behavior silently returned "" (because "".split(",") == [""]),
            which produced an invalid "Authorization: Bearer" header
            downstream in predict_dia.
    """
    tokens = [token.strip() for token in ZEROGPU_TOKENS if token.strip()]
    if not tokens:
        raise RuntimeError(
            "No ZeroGPU tokens configured; set the ZEROGPU_TOKENS environment variable"
        )
    return random.choice(tokens)
|
|
|
|
# Registry of TTS models served through the shared tts-router endpoint below.
# Maps a public model id -> the "provider"/"model" fields of the JSON payload
# that predict_tts POSTs to the router URL.
# NOTE(review): "spark-tts" and "index-tts" are dispatched to dedicated
# gradio_client helpers in predict_tts *before* this mapping is consulted, so
# these two entries appear unused in the visible code — confirm with callers.
model_mapping = {
















    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },




































    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
}
# Shared endpoint for router-backed models (a Hugging Face Space).
url = "https://tts-agi-tts-router-v2.hf.space/tts"
# Headers sent with every router request; authenticated with the caller's
# Hugging Face token taken from the environment at import time.
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Template illustrating the router's request shape; appears unused by the
# visible code (predict_tts builds its own payload) — kept for reference.
data = {"text": "string", "provider": "string", "model": "string"}
|
|
|
|
def predict_csm(script):
    """Generate conversation audio with the fal.ai CSM-1B model.

    Args:
        script: Scene/script payload forwarded verbatim as the "scene"
            argument of the fal-ai/csm-1b endpoint.

    Returns:
        Raw audio bytes downloaded from the URL fal.ai returns.

    Raises:
        requests.HTTPError: if downloading the generated audio fails.
    """
    result = fal_client.subscribe(
        "fal-ai/csm-1b",
        arguments={
            "scene": script
        },
        with_logs=True,
    )
    # Fail loudly on a bad download instead of returning an error page's
    # bytes as if they were audio; bound the request with a timeout.
    audio_response = requests.get(result["audio"]["url"], timeout=60)
    audio_response.raise_for_status()
    return audio_response.content
|
|
|
|
def predict_playdialog(script):
    """Generate two-host dialogue audio with Play.ht's PlayDialog engine.

    Args:
        script: Either a list of {"text": ..., "speaker_id": 0|1} turns, or a
            pre-formatted string already carrying "Host 1:"/"Host 2:" tags.

    Returns:
        Raw audio bytes concatenated from the PlayDialog TTS stream.
    """
    # Credentials come from the environment; a fresh client per call.
    pyht_client = PyhtClient(
        user_id=os.getenv("PLAY_USERID"),
        api_key=os.getenv("PLAY_SECRETKEY"),
    )

    # Pre-built Play.ht voice manifests for the two hosts.
    voice_1 = "s3://voice-cloning-zero-shot/baf1ef41-36b6-428c-9bdf-50ba54682bd8/original/manifest.json"
    voice_2 = "s3://voice-cloning-zero-shot/e040bd1b-f190-4bdb-83f0-75ef85b18f84/original/manifest.json"

    if isinstance(script, list):
        # Flatten structured turns into the "Host N:"-prefixed transcript
        # PlayDialog expects. Use .get with defaults for both keys, matching
        # predict_dia (the original indexed turn['text'] and could KeyError).
        text = ""
        for turn in script:
            speaker_id = turn.get("speaker_id", 0)
            prefix = "Host 1:" if speaker_id == 0 else "Host 2:"
            text += f"{prefix} {turn.get('text', '')}\n"
    else:
        text = script

    options = TTSOptions(
        voice=voice_1, voice_2=voice_2, turn_prefix="Host 1:", turn_prefix_2="Host 2:"
    )

    # The stream yields audio chunks; join them into a single bytes payload.
    return b"".join(pyht_client.tts(text, options, voice_engine="PlayDialog"))
|
|
|
|
def predict_dia(script):
    """Generate dialogue audio with Dia-1.6B via its gradio Space's SSE API.

    Args:
        script: Either a list of {"text": ..., "speaker_id": 0|1} turns or a
            pre-formatted string using Dia's [S1]/[S2] speaker tags.

    Returns:
        Raw audio bytes downloaded from the first audio URL emitted on the
        event stream. NOTE(review): if the stream ends without such an event
        this falls through and implicitly returns None — confirm callers
        handle that.
    """
    if isinstance(script, list):
        formatted_text = ""
        for turn in script:
            speaker_id = turn.get("speaker_id", 0)
            speaker_tag = "[S1]" if speaker_id == 0 else "[S2]"
            # Strip any speaker tags embedded in the turn text so they cannot
            # be injected twice into the formatted transcript.
            text = turn.get("text", "").strip().replace("[S1]", "").replace("[S2]", "")
            formatted_text += f"{speaker_tag} {text} "
        text = formatted_text.strip()
    else:
        text = script
    print(text)
    # Per-request auth header using a randomly selected ZeroGPU token.
    headers = {
        "Authorization": f"Bearer {get_zerogpu_token()}"
    }

    # gradio's two-step call API: POST the job, then stream its events.
    response = requests.post(
        "https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue",
        headers=headers,
        json={"data": [text]},
    )

    event_id = response.json()["event_id"]

    stream_url = f"https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue/{event_id}"

    # Scan the server-sent events; skip keep-alives and "data: null" frames.
    with requests.get(stream_url, headers=headers, stream=True) as stream_response:
        for line in stream_response.iter_lines():
            if line:
                if line.startswith(b"data: ") and not line.startswith(b"data: null"):
                    # Drop the 6-byte b"data: " prefix, parse the JSON payload,
                    # and download the generated audio file it points at.
                    audio_data = line[6:]
                    return requests.get(json.loads(audio_data)[0]["url"]).content
|
|
|
|
def predict_index_tts(text, reference_audio_path=None):
    """Synthesize speech with IndexTTS via its hosted gradio Space.

    Args:
        text: Text to synthesize.
        reference_audio_path: Path to the reference (prompt) audio file;
            required by this model.

    Returns:
        The /gen_single prediction result from the Space.

    Raises:
        ValueError: if no reference audio path is provided.
    """
    # Fail fast: the original validated only after constructing the gradio
    # Client, paying for a pointless network handshake before raising.
    if not reference_audio_path:
        raise ValueError("index-tts 需要 reference_audio_path")
    from gradio_client import Client, handle_file

    client = Client("IndexTeam/IndexTTS")
    prompt = handle_file(reference_audio_path)
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single",
    )
    return result
|
|
|
|
def predict_spark_tts(text, reference_audio_path=None):
    """Voice-clone *text* through the Spark-TTS-Zero gradio Space.

    Args:
        text: Text to synthesize.
        reference_audio_path: Optional reference audio used for cloning; when
            absent the Space receives None for both prompt inputs.

    Returns:
        The /voice_clone prediction result from the Space.

    NOTE(review): prompt_text is sent as the target text itself — Spark-TTS
    prompt_text usually means the transcript of the reference audio; confirm
    against the Space's API before relying on cloning quality.
    """
    from gradio_client import Client, handle_file

    space = Client("amortalize/Spark-TTS-Zero")
    wav = handle_file(reference_audio_path) if reference_audio_path else None
    return space.predict(
        text=text,
        prompt_text=text,
        prompt_wav_upload=wav,
        prompt_wav_record=wav,
        api_name="/voice_clone",
    )
|
|
|
|
def predict_tts(text, model, reference_audio_path=None):
    """Route a TTS request to the right backend and return its audio.

    Args:
        text: Text (or structured script, for the dialogue models) to speak.
        model: Public model id. Dedicated integrations are dispatched first;
            anything else must appear in model_mapping and goes through the
            shared router endpoint.
        reference_audio_path: Optional reference audio for voice cloning;
            forwarded only to models that support it.

    Returns:
        Raw audio bytes for the dedicated helpers, or the path of a temporary
        audio file for router-backed models.

    Raises:
        ValueError: if the model id is unknown.
        requests.HTTPError: if the router call fails.
    """
    # Dropped the original's dead `global client` declaration: no global
    # named `client` is assigned or read in this function.
    print(f"Predicting TTS for {model}")

    # Models with dedicated integrations bypass the shared router entirely.
    if model == "csm-1b":
        return predict_csm(text)
    elif model == "playdialog-1.0":
        return predict_playdialog(text)
    elif model == "dia-1.6b":
        return predict_dia(text)
    elif model == "index-tts":
        return predict_index_tts(text, reference_audio_path)
    elif model == "spark-tts":
        return predict_spark_tts(text, reference_audio_path)

    if model not in model_mapping:
        raise ValueError(f"Model {model} not found")

    payload = {
        "text": text,
        "provider": model_mapping[model]["provider"],
        "model": model_mapping[model]["model"],
    }

    # Only these router models accept reference audio for cloning.
    supports_reference = model in [
        "styletts2", "eleven-multilingual-v2", "eleven-turbo-v2.5", "eleven-flash-v2.5"
    ]
    if reference_audio_path and supports_reference:
        with open(reference_audio_path, "rb") as f:
            audio_b64 = base64.b64encode(f.read()).decode("utf-8")
        # styletts2 uses a different field name for the reference clip.
        if model == "styletts2":
            payload["reference_speaker"] = audio_b64
        else:
            payload["reference_audio"] = audio_b64

    # json= serializes the payload (the original hand-rolled json.dumps with
    # data=); a timeout bounds the call, and raise_for_status surfaces router
    # errors instead of failing later on missing JSON keys.
    result = requests.post(url, headers=headers, json=payload, timeout=120)
    result.raise_for_status()

    response_json = result.json()
    audio_bytes = base64.b64decode(response_json["audio_data"])
    extension = response_json["extension"]

    # Persist to a temp file: router-backed models return a playable file
    # path to the caller (delete=False — the caller owns cleanup).
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as temp_file:
        temp_file.write(audio_bytes)
        temp_path = temp_file.name

    return temp_path
|
|
|
|
if __name__ == "__main__":
    # Smoke test: render a short two-speaker exchange through Dia and dump
    # the returned bytes to stdout.
    demo_script = [
        {"text": "Hello, how are you?", "speaker_id": 0},
        {"text": "I'm great, thank you!", "speaker_id": 1},
    ]
    print(predict_dia(demo_script))
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|