# F5-TTS-pt-br / services/tts_service.py
# Uploaded by fuuuzzy ("Upload folder using huggingface_hub"), commit 7c71fa7 (verified).
import os
import requests
import logging
import shutil
import subprocess
from urllib.parse import urlparse
from typing import List, Dict, Any, Optional
from AgentF5TTSChunk import AgentF5TTS
# Module logger. It uses the fixed name "services.tts" rather than __name__, so
# logging configuration must target that name.
logger = logging.getLogger("services.tts")
def get_audio_duration(file_path: str, timeout: float = 30.0) -> float:
    """Return the duration of an audio file in seconds, as reported by ffprobe.

    This is best-effort. If ffprobe is missing, exits non-zero, times out, or
    prints something that is not a number (e.g. ``N/A``), the error is logged
    and ``0.0`` is returned instead of raising.

    Args:
        file_path: Path of the audio file to probe.
        timeout: Seconds to wait for ffprobe. Without a limit, a hung probe
            would block the caller indefinitely.

    Returns:
        The container duration in seconds, or ``0.0`` if it could not be read.
    """
    cmd = [
        'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1', file_path,
    ]
    try:
        raw = subprocess.check_output(cmd, timeout=timeout)
        return float(raw.decode().strip())
    # OSError: ffprobe not installed / not executable.
    # SubprocessError: non-zero exit (CalledProcessError) or TimeoutExpired.
    # ValueError: undecodable or non-numeric output such as "N/A".
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        logger.error(f"Failed to get duration for {file_path}: {e}")
        return 0.0
class TTSService:
    """Generate per-segment speech audio with an F5-TTS agent from reference voices.

    Settings are read from ``config['tts']``:

    - ``voices_dir``: where reference voice files are cached (required).
    - ``output_dir``: parent directory of per-task output folders (required).
    - ``checkpoint_file``: model checkpoint passed to ``AgentF5TTS`` (required).
    - ``device``: inference device (default ``'cuda'``).
    - ``remove_silence``: forwarded to ``infer`` (default ``True``).
    - ``speed``: forwarded to ``infer`` (default ``1.0``).
    """

    def __init__(self, config: Dict[str, Any]):
        """Create the working directories and load the F5-TTS model.

        Raises:
            KeyError: if ``tts``, ``voices_dir`` or ``output_dir`` is missing.
            Exception: anything raised while building ``AgentF5TTS``. This
                includes a missing ``checkpoint_file``. It is logged with its
                traceback and re-raised unchanged.
        """
        self.config = config['tts']
        self.voices_dir = self.config['voices_dir']
        self.output_dir = self.config['output_dir']

        # Ensure directories exist
        os.makedirs(self.voices_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        logger.info("Loading F5-TTS Model...")
        try:
            self.agent = AgentF5TTS(
                ckpt_file=self.config['checkpoint_file'],
                device=self.config.get('device', 'cuda'),
            )
        except Exception as e:
            # logger.exception keeps the traceback; bare `raise` re-raises as-is.
            logger.exception(f"Failed to load F5-TTS Model: {e}")
            raise
        logger.info("F5-TTS Model Loaded successfully.")

    def _get_extension_from_url(self, url: str) -> str:
        """Return the file extension of the URL's path (query ignored), or ``.wav`` if none."""
        ext = os.path.splitext(urlparse(url).path)[1]
        return ext or ".wav"

    def _download_file(self, url: str, path: str):
        """Download ``url`` to ``path`` without ever leaving a partial file at ``path``.

        The response body is streamed into a sibling ``*.part`` file. It is
        moved onto ``path`` only after the whole body has been written.
        ``prepare_voices`` treats any existing file at ``path`` as a valid
        cached voice, so a truncated download must not land there. The HTTP
        response is always closed.

        Raises:
            requests.RequestException: on network errors or a non-2xx status.
            OSError: if the file cannot be written or moved into place.
        """
        # The PID keeps concurrent worker processes from sharing a temp file.
        tmp_path = f"{path}.{os.getpid()}.part"
        try:
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, path)
        except BaseException:
            # Remove the partial download, then propagate the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def prepare_voices(self, character_voices: List[Dict[str, str]]) -> Dict[str, Dict[str, str]]:
        """Ensure all reference voices are available locally.

        Each entry may carry ``character``, ``id``, ``timbre_url`` and
        ``timbre_text``. Entries without an ``id`` are ignored. A voice is
        cached in ``voices_dir`` as ``<id><ext>``, where ``ext`` comes from
        the URL or defaults to ``.wav``. It is downloaded only when that file
        does not exist yet.

        Returns:
            A map of both ``character`` name (when given) and ``str(id)`` to
            ``{'path': local_file_path, 'text': ref_text}``. Voices that are
            missing and cannot be downloaded are left out (and logged).
        """
        voice_map: Dict[str, Dict[str, str]] = {}
        for cv in character_voices:
            char_name = cv.get('character')
            voice_id = cv.get('id')
            url = cv.get('timbre_url')  # Updated from character_url
            # A null timbre_text is treated the same as a missing one.
            text = cv.get('timbre_text') or ""
            if not voice_id:
                continue

            # The id becomes a filename inside voices_dir. Reject ids that would
            # resolve elsewhere (e.g. "../x", "a/b" or an absolute path).
            file_stem = str(voice_id)
            if os.path.basename(file_stem) != file_stem:
                logger.warning(f"Voice {voice_id} has an unsafe id; skipping.")
                continue

            # Use ID as filename to avoid duplicates
            ext = self._get_extension_from_url(url) if url else ".wav"
            local_path = os.path.join(self.voices_dir, f"{file_stem}{ext}")

            if not os.path.exists(local_path):
                if not url:
                    logger.warning(f"Voice {voice_id} missing locally and no URL.")
                    continue
                try:
                    logger.info(f"Downloading voice {voice_id}")
                    self._download_file(url, local_path)
                except Exception as e:
                    logger.error(f"Failed to download voice {voice_id}: {e}")
                    continue

            voice_data = {'path': local_path, 'text': text}
            if char_name:
                voice_map[char_name] = voice_data
            voice_map[file_stem] = voice_data
        return voice_map

    @staticmethod
    def _segment_duration(start: Any, end: Any) -> float:
        """Return ``max(0, end - start)`` as a float, or ``0.0`` if either bound is not numeric."""
        try:
            return max(0.0, float(end) - float(start))
        except (TypeError, ValueError):
            return 0.0

    def process_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Process a TTS generation task.

        Reads ``task['task_id']`` and ``task['data']``. ``data`` provides:

        - ``character_voice``: voice entries, see ``prepare_voices``.
        - ``content``: segment dicts with ``character``, ``translation``,
          ``start`` and ``end``.
        - ``hook_url`` and ``video_url``: passed through unchanged.

        One ``NNNN.wav`` file (the segment's index) is written per segment
        into ``output_dir/<task_id>``. Segments are skipped in three cases:
        no translation, no matching voice (logged), or failed inference
        (logged).

        Returns:
            A dict with ``task_id``, ``task_dir``, ``hook_url``,
            ``video_url``, ``priority`` (default 3) and ``segments``. Each
            segment entry holds ``index``, ``path``, ``start_time``,
            ``end_time``, ``original_duration`` (``end - start`` clamped at
            0, presumably in seconds) and ``gen_duration`` (ffprobe duration
            of the generated file; 0.0 if unknown).

        Raises:
            KeyError: if ``task`` lacks ``task_id`` or ``data``.
            ValueError: if no content is provided.
            RuntimeError: if no segment produced audio.
        """
        task_id = task['task_id']
        data = task['data']
        character_voices = data.get('character_voice') or []
        content = data.get('content', [])
        if not content:
            raise ValueError("No content provided.")

        # 1. Prepare Voices
        voice_map = self.prepare_voices(character_voices)

        # 2. Create Task Output Directory (task ids may be numeric)
        task_out_dir = os.path.join(self.output_dir, str(task_id))
        os.makedirs(task_out_dir, exist_ok=True)

        segments_metadata = []
        # 3. Inference Loop
        logger.info(f"Starting inference for {len(content)} segments")
        for idx, segment in enumerate(content):
            char_name = segment.get('character')
            text = segment.get('translation')
            if not text:
                continue

            start_time = segment.get('start', 0.0)
            end_time = segment.get('end', 0.0)
            # Original duration for the merger; a malformed bound must not
            # abort the whole task.
            original_duration = self._segment_duration(start_time, end_time)

            voice_data = voice_map.get(char_name)
            if not voice_data:
                logger.warning(f"Segment {idx}: No voice for '{char_name}'. Skipping.")
                continue

            out_path = os.path.join(task_out_dir, f"{idx:04d}.wav")
            try:
                # Re-running a task reuses this directory. Drop any file from an
                # earlier attempt so it cannot be mistaken for fresh output.
                if os.path.exists(out_path):
                    os.remove(out_path)
                self.agent.infer(
                    ref_file=voice_data['path'],
                    ref_text=voice_data['text'],  # Pass the reference text
                    gen_text=text,
                    file_wave=out_path,
                    remove_silence=self.config.get('remove_silence', True),
                    speed=self.config.get('speed', 1.0),
                )
                if not os.path.exists(out_path):
                    logger.warning(f"Segment {idx}: inference produced no file at {out_path}.")
                    continue
                gen_duration = get_audio_duration(out_path)
            except Exception as e:
                logger.exception(f"Inference failed for segment {idx}: {e}")
                continue

            segments_metadata.append({
                'index': idx,
                'path': out_path,
                'start_time': start_time,
                'end_time': end_time,
                'original_duration': original_duration,
                'gen_duration': gen_duration,
            })

        if not segments_metadata:
            # RuntimeError subclasses Exception, so existing handlers still match.
            raise RuntimeError("No audio generated.")

        return {
            'task_id': task_id,
            'segments': segments_metadata,
            'task_dir': task_out_dir,
            'hook_url': data.get('hook_url'),
            'video_url': data.get('video_url'),
            'priority': task.get('priority', 3),
        }