pdf_explainer / src /processors /generate_simple_tts_audio.py
spagestic's picture
feat: implement simplified audio processing with enhanced TTS API integration
8f3c067
"""Simplified TTS audio generation that uses the enhanced API endpoints."""
import os
import requests
import tempfile
import soundfile as sf
import numpy as np
import gradio as gr
def _download_and_decode_audio(response, progress, fraction, download_label, loading_desc=None):
    """Stream a response body into a temporary WAV file and decode it.

    Args:
        response: A streaming ``requests`` response whose body is audio data.
        progress: Optional progress callback, called as ``progress(fraction, desc=...)``.
        fraction: Progress fraction reported while downloading.
        download_label: Text prefix for the per-chunk download progress message.
        loading_desc: Optional message reported (at 0.9) just before decoding.

    Returns:
        Tuple of (sample_rate, audio_data) as read by ``soundfile``.
    """
    bytes_downloaded = 0
    temp_path = None
    try:
        # delete=False so the file can be reopened by soundfile after it is
        # closed; the finally clause below is responsible for removing it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:
                    continue
                temp_file.write(chunk)
                bytes_downloaded += len(chunk)
                if progress:
                    progress(fraction, desc=f"{download_label} ({bytes_downloaded // 1024}KB)")
        if progress and loading_desc:
            progress(0.9, desc=loading_desc)
        audio_data, sample_rate = sf.read(temp_path)
    finally:
        # Remove the temp file even when the download or decode fails.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
    return sample_rate, audio_data


def generate_simple_tts_audio(text_input: str, audio_prompt_input=None, progress=None):
    """
    Generate TTS audio using the enhanced API that handles chunking and concatenation server-side.

    Without an audio prompt the text is posted as JSON to the full-text
    endpoint; with an audio prompt the text and the prompt file are posted as
    multipart form data to the voice-cloning endpoint. In both cases the
    response body is streamed to a temporary WAV file and decoded.

    Args:
        text_input: The text to convert to speech (any length)
        audio_prompt_input: Optional path to an audio prompt file for voice cloning
        progress: Optional progress callback, called as ``progress(fraction, desc=...)``

    Returns:
        Tuple of (sample_rate, audio_data)

    Raises:
        gr.Error: If the text is empty, the API returns a non-200 status, the
            request times out or cannot connect, or any other failure occurs.
    """
    # Endpoints are read from the environment on every call; the defaults are placeholders.
    FULL_TEXT_ENDPOINT = os.getenv("FULL_TEXT_TTS_ENDPOINT", "YOUR-MODAL-ENDPOINT-URL/generate_full_text_audio")
    GENERATE_WITH_FILE_ENDPOINT = os.getenv("GENERATE_WITH_FILE_ENDPOINT", "YOUR-MODAL-ENDPOINT-URL/generate_with_file")

    if not text_input or not text_input.strip():
        raise gr.Error("Please enter some text to synthesize.")

    if progress:
        progress(0.1, desc="Preparing request for full-text processing...")

    try:
        if audio_prompt_input is None:
            # Use the full-text endpoint that chunks and concatenates server-side
            if progress:
                progress(0.3, desc="Sending full text to enhanced TTS API...")

            payload = {
                "text": text_input,
                "max_chunk_size": 800,
                "silence_duration": 0.5,
                "fade_duration": 0.1,
                "overlap_sentences": 0,
            }

            # The context manager releases the streamed connection on every path.
            with requests.post(
                FULL_TEXT_ENDPOINT,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=300,  # Longer timeout for full-text processing
                stream=True,
            ) as response:
                if response.status_code != 200:
                    raise gr.Error(f"API Error: {response.status_code} - {response.text}")

                if progress:
                    progress(0.6, desc="Server processing text chunks in parallel...")

                # Optional response headers used only for progress messages
                chunks_processed = response.headers.get('X-Chunks-Processed', 'unknown')
                total_chars = response.headers.get('X-Total-Characters', len(text_input))

                if progress:
                    progress(0.7, desc=f"Processing {chunks_processed} chunks ({total_chars} characters)...")

                sample_rate, audio_data = _download_and_decode_audio(
                    response,
                    progress,
                    0.7,
                    "Downloading processed audio...",
                    loading_desc="Loading final audio...",
                )

            if progress:
                progress(1.0, desc=f"Complete! Processed {chunks_processed} chunks into final audio.")
            return (sample_rate, audio_data)

        # For voice cloning, still use the original endpoint
        if progress:
            progress(0.3, desc="Preparing voice cloning request...")

        with open(audio_prompt_input, 'rb') as f:
            audio_content = f.read()
        files = {
            'text': (None, text_input),
            'voice_prompt': ('voice_prompt.wav', audio_content, 'audio/wav'),
        }

        if progress:
            progress(0.5, desc="Sending request with voice cloning...")

        with requests.post(
            GENERATE_WITH_FILE_ENDPOINT,
            files=files,
            timeout=180,
            stream=True,
        ) as response:
            if response.status_code != 200:
                raise gr.Error(f"API Error: {response.status_code} - {response.text}")

            if progress:
                progress(0.8, desc="Processing cloned voice...")

            sample_rate, audio_data = _download_and_decode_audio(
                response,
                progress,
                0.8,
                "Downloading cloned audio...",
            )

        if progress:
            progress(1.0, desc="Voice cloning complete!")
        return (sample_rate, audio_data)

    except gr.Error:
        # Already a user-facing error (e.g. the API status message); do not
        # let the generic handler below re-wrap it.
        raise
    except requests.exceptions.Timeout as exc:
        raise gr.Error("Request timed out. The API might be processing a large text. Please try again.") from exc
    except requests.exceptions.ConnectionError as exc:
        raise gr.Error("Unable to connect to the API. Please check if the endpoint URL is correct.") from exc
    except Exception as e:
        raise gr.Error(f"Error generating audio: {str(e)}") from e
def get_api_processing_info(text: str) -> dict:
    """
    Get processing information for a text without generating audio.

    The figures are computed locally; no API call is made.

    Args:
        text: The text to analyze (``None`` or empty counts as length 0)

    Returns:
        Dictionary with "text_length" (length after stripping whitespace),
        "estimated_chunks" (at least 1), "processing_mode" and "benefits".
        If the input cannot be analyzed, a dictionary with a single "error"
        key is returned instead.
    """
    # Same value as the "max_chunk_size" sent by generate_simple_tts_audio.
    chunk_size = 800
    try:
        # This could be enhanced to call an API info endpoint
        text_length = len(text.strip()) if text else 0
        # Ceiling division: a partially filled final chunk still counts as a
        # chunk (e.g. 801 characters need 2 chunks, not 1).
        estimated_chunks = max(1, -(-text_length // chunk_size))
        return {
            "text_length": text_length,
            "estimated_chunks": estimated_chunks,
            "processing_mode": "server_side_parallel_gpu",
            "benefits": [
                "Server-side GPU acceleration",
                "Parallel chunk processing",
                "Automatic audio concatenation",
                "Optimized for large texts",
                "No client-side resource usage",
            ],
        }
    except Exception as e:
        # Deliberate best-effort: report the failure instead of raising
        # (e.g. when a non-string value is passed in).
        return {"error": f"Failed to analyze text: {str(e)}"}