# F5-TTS-THAI / src/f5_tts/multi_speech_processor.py
# (uploaded by pythonlearnreal via huggingface_hub, commit 106478e verified)
"""
Multi-Speech Processor for F5-TTS Thai
จัดการการประมวลผล multi-speech และ segment editing
"""
import tempfile
import numpy as np
import gradio as gr
import soundfile as sf
from collections import OrderedDict
from f5_tts.infer.infer_gradio import parse_speechtypes_text, infer
from f5_tts.cleantext.number_tha import replace_numbers_with_thai
from f5_tts.cleantext.th_repeat import process_thai_repeat
from f5_tts.config import MAX_SEGMENTS
class MultiSpeechProcessor:
    """Manage multi-speech generation and per-segment editing.

    Synthesizes one audio clip per styled text segment via ``infer``, pads
    each clip with an optional trailing silence, and builds the Gradio
    ``update`` payloads that drive the per-segment editing widgets
    (audio player, text box, silence input, regenerate button).

    NOTE(review): the tuples/lists returned by the public methods are
    order-sensitive — they must match the output-component wiring in the
    Gradio app that calls this class.
    """

    def __init__(self, model_manager):
        # model_manager must expose get_model() and get_vocoder();
        # both are forwarded to infer() below.
        self.model_manager = model_manager

    def generate_multistyle_speech(self,
                                   gen_text,
                                   cross_fade_duration,
                                   nfe_step,
                                   speech_types_data,
                                   remove_silence,
                                   silence_inputs):
        """Generate multi-style speech for ``gen_text``.

        Args:
            gen_text: Text with style markers understood by
                ``parse_speechtypes_text`` (each parsed segment carries a
                ``"style"`` and a ``"text"`` key).
            cross_fade_duration: Cross-fade duration forwarded to ``infer``.
            nfe_step: NFE step count forwarded to ``infer``.
            speech_types_data: Flat list of UI values laid out as three
                equal thirds: [names..., audios..., ref_texts...]
                (see ``_organize_speech_types``).
            remove_silence: Whether ``infer`` should strip silence.
            silence_inputs: Per-segment trailing-silence values in ms.

        Returns:
            ``((sr, final_audio), download_path, *widget_updates,
            segment_infos, sr)`` on success (see ``_combine_segments``),
            or the ``_empty_output`` shape on failure.
        """
        # Organize the flat speech-type inputs into a name-keyed mapping.
        speech_types = self._organize_speech_types(speech_types_data)
        # Split the text into styled segments.
        segments = parse_speechtypes_text(gen_text)
        # Synthesize audio for each segment.
        generated_audio_segments = []
        segment_infos = []
        current_style = "Regular"
        for idx, segment in enumerate(segments):
            style = segment["style"]
            text = segment["text"]
            # Select the style; fall back to "Regular" when unknown.
            if style in speech_types:
                current_style = style
            else:
                gr.Warning(f"ไม่พบสไตล์ {style} จะใช้สไตล์ Regular แทน")
                current_style = "Regular"
            # The fallback style itself may be absent — warn and abort.
            try:
                ref_audio = speech_types[current_style]["audio"]
            except KeyError:
                gr.Warning(f"กรุณาใส่เสียงต้นฉบับสำหรับสไตล์ {current_style}")
                return self._empty_output()
            ref_text = speech_types[current_style].get("ref_text", "")
            # Normalize the Thai text (digits to words, repeat marks expanded).
            ms_cleaned_text = process_thai_repeat(replace_numbers_with_thai(text))
            # Run TTS inference for this segment.
            audio_out, _, ref_text_out = infer(
                ref_audio,
                ref_text,
                ms_cleaned_text,
                self.model_manager.get_model(),
                self.model_manager.get_vocoder(),
                remove_silence,
                cross_fade_duration=cross_fade_duration,
                nfe_step=nfe_step,
                show_info=print
            )
            sr, audio_data = audio_out
            # Append the configured trailing silence for this segment.
            audio_data = self._add_silence(audio_data, sr, silence_inputs, idx)
            generated_audio_segments.append(audio_data)
            segment_infos.append({
                "index": idx,
                "style": style,
                "text": text,
                "ref_audio": ref_audio,
                "ref_text": ref_text,
                "audio_data": audio_data,
                "sr": sr,
                "silence_ms": self._get_silence_value(silence_inputs, idx)
            })
            # Cache the ref_text returned by infer (possibly auto-transcribed)
            # so later segments of the same style reuse it.
            speech_types[current_style]["ref_text"] = ref_text_out
        if generated_audio_segments:
            # NOTE(review): sr here is the sample rate of the LAST segment;
            # all segments are assumed to share it — confirm upstream.
            return self._combine_segments(generated_audio_segments, segment_infos, sr)
        else:
            gr.Warning("ไม่สามารถสร้างเสียงได้")
            return self._empty_output()

    def update_silence_all(self, silence_inputs, segments, sr):
        """Re-apply the trailing silence of every segment and rebuild the mix.

        For each segment: strip the previously appended silence from the
        tail (length derived from the stored ``silence_ms``), then append
        the new amount from ``silence_inputs``.

        Returns:
            ``widget_updates + [(sr, final_audio), download_path,
            segments, sr]`` — note this list layout differs from the tuple
            returned by ``_combine_segments``.
        """
        if not segments or len(segments) == 0:
            return self._empty_segment_output() + [None, None, segments, sr]
        # Update the trailing silence of each segment in place.
        for idx, seg in enumerate(segments):
            audio_data = seg["audio_data"]
            old_silence_ms = seg.get("silence_ms", 0)
            old_silence_samples = int((old_silence_ms / 1000.0) * seg["sr"])
            # Strip the previously appended silence from the tail.
            if old_silence_samples > 0 and len(audio_data) > old_silence_samples:
                audio_data = audio_data[:-old_silence_samples]
            # Append the newly requested silence.
            silence_ms = self._get_silence_value(silence_inputs, idx)
            seg["silence_ms"] = silence_ms
            silence_samples = int((silence_ms / 1000.0) * seg["sr"])
            if silence_samples > 0:
                seg["audio_data"] = np.concatenate([audio_data, np.zeros(silence_samples, dtype=audio_data.dtype)])
            else:
                seg["audio_data"] = audio_data
        # Re-concatenate all segments into the final track.
        final_audio_data = np.concatenate([s["audio_data"] for s in segments])
        download_path = self._save_audio(final_audio_data, sr)
        return self._prepare_segment_outputs(segments) + [(sr, final_audio_data), download_path, segments, sr]

    def regenerate_segment(self, idx, new_text, silence_ms, segments, cross_fade_duration, nfe_step):
        """Re-synthesize one segment with new text and silence, then rebuild.

        Args:
            idx: Index of the segment to regenerate.
            new_text: Replacement text for the segment.
            silence_ms: New trailing silence (ms); non-numeric falls back to 0.
            segments: The segment-info dicts built by
                ``generate_multistyle_speech``.
            cross_fade_duration: Forwarded to ``infer``.
            nfe_step: Forwarded to ``infer``.

        Returns:
            Same list layout as ``update_silence_all``.
        """
        if not segments or idx >= len(segments):
            return self._empty_segment_output() + [None, None, segments, 24000]
        seg = segments[idx]
        # Normalize the replacement text.
        ms_cleaned_text = process_thai_repeat(replace_numbers_with_thai(new_text))
        # Re-run inference; silence removal is always enabled here
        # (positional True), unlike the initial generation pass.
        audio_out, _, _ = infer(
            seg["ref_audio"],
            seg["ref_text"],
            ms_cleaned_text,
            self.model_manager.get_model(),
            self.model_manager.get_vocoder(),
            True,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            show_info=print
        )
        sr, audio_data = audio_out
        # Append the requested trailing silence; tolerate bad input.
        try:
            silence_ms = float(silence_ms)
        except Exception:
            silence_ms = 0
        silence_samples = int((silence_ms / 1000.0) * sr)
        if silence_samples > 0:
            audio_data = np.concatenate([audio_data, np.zeros(silence_samples, dtype=audio_data.dtype)])
        # Store the regenerated audio and metadata back on the segment.
        segments[idx]["audio_data"] = audio_data
        segments[idx]["sr"] = sr
        segments[idx]["text"] = new_text
        segments[idx]["silence_ms"] = silence_ms
        # Re-concatenate all segments into the final track.
        final_audio_data = np.concatenate([s["audio_data"] for s in segments])
        download_path = self._save_audio(final_audio_data, sr)
        return self._prepare_segment_outputs(segments) + [(sr, final_audio_data), download_path, segments, sr]

    def validate_speech_types(self, gen_text, speech_type_names):
        """Return a gr.update enabling the caller's control only when every
        style referenced in ``gen_text`` has a named speech type."""
        speech_types_available = set(name for name in speech_type_names if name)
        segments = parse_speechtypes_text(gen_text)
        speech_types_in_text = set(segment["style"] for segment in segments)
        missing_speech_types = speech_types_in_text - speech_types_available
        return gr.update(interactive=len(missing_speech_types) == 0)

    def _organize_speech_types(self, speech_types_data):
        """Fold the flat [names..., audios..., ref_texts...] UI list into an
        OrderedDict keyed by style name.

        Rows missing a name or an audio get a placeholder key ``"@i@"`` so
        positional alignment with the UI rows is preserved.
        """
        max_speech_types = len(speech_types_data) // 3
        speech_type_names_list = speech_types_data[:max_speech_types]
        speech_type_audios_list = speech_types_data[max_speech_types:2 * max_speech_types]
        speech_type_ref_texts_list = speech_types_data[2 * max_speech_types:3 * max_speech_types]
        speech_types = OrderedDict()
        ref_text_idx = 0
        for name_input, audio_input, ref_text_input in zip(
            speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
        ):
            if name_input and audio_input:
                speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
            else:
                speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
            ref_text_idx += 1
        return speech_types

    def _add_silence(self, audio_data, sr: int, silence_inputs, idx: int):
        """Append the configured trailing silence (ms) for segment ``idx``."""
        silence_ms = self._get_silence_value(silence_inputs, idx)
        silence_samples = int((silence_ms / 1000.0) * sr)
        if silence_samples > 0:
            return np.concatenate([audio_data, np.zeros(silence_samples, dtype=audio_data.dtype)])
        return audio_data

    def _get_silence_value(self, silence_inputs, idx: int) -> float:
        """Return the silence value (ms) for ``idx``; 0 when missing/invalid."""
        if idx < len(silence_inputs) and silence_inputs[idx] is not None:
            try:
                return float(silence_inputs[idx])
            except Exception:
                return 0
        return 0

    def _save_audio(self, audio_data, sr: int) -> str:
        """Write ``audio_data`` to a temp WAV file and return its path.

        delete=False is deliberate: Gradio serves the file for download
        after this method returns.
        """
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio:
            sf.write(tmp_audio.name, audio_data, sr)
            return tmp_audio.name

    def _combine_segments(self, generated_audio_segments, segment_infos, sr: int):
        """Concatenate all segment audio and build the full output tuple:
        ``((sr, final_audio), download_path, *widget_updates,
        segment_infos, sr)``."""
        final_audio_data = np.concatenate(generated_audio_segments)
        download_path = self._save_audio(final_audio_data, sr)
        return (
            (sr, final_audio_data),
            download_path,
            *self._prepare_segment_outputs(segment_infos),
            segment_infos,
            sr
        )

    def _prepare_segment_outputs(self, segments):
        """Build the per-segment widget updates.

        Returns a flat list of length 4 * MAX_SEGMENTS, grouped as
        audio players + text boxes + silence inputs + button visibilities;
        slots beyond ``len(segments)`` stay hidden.
        """
        segment_outputs = [gr.update(visible=False, value=None) for _ in range(MAX_SEGMENTS)]
        segment_texts = [gr.update(visible=False, value="") for _ in range(MAX_SEGMENTS)]
        segment_silences = [gr.update(visible=False, value=0) for _ in range(MAX_SEGMENTS)]
        segment_btn_vis = [gr.update(visible=False) for _ in range(MAX_SEGMENTS)]
        for i, seg in enumerate(segments):
            if i < MAX_SEGMENTS:
                segment_outputs[i] = gr.update(value=(seg["sr"], seg["audio_data"]), visible=True)
                segment_texts[i] = gr.update(value=seg["text"], visible=True)
                segment_silences[i] = gr.update(value=seg["silence_ms"], visible=True)
                segment_btn_vis[i] = gr.update(visible=True)
        return segment_outputs + segment_texts + segment_silences + segment_btn_vis

    def _empty_output(self):
        """Return the empty-result tuple matching ``_combine_segments``:
        no audio, no download path, all widgets hidden, no segments,
        default 24000 Hz sample rate."""
        empty_segments = [gr.update(visible=False, value=None) for _ in range(MAX_SEGMENTS)]
        empty_texts = [gr.update(visible=False, value="") for _ in range(MAX_SEGMENTS)]
        empty_silences = [gr.update(visible=False, value=0) for _ in range(MAX_SEGMENTS)]
        empty_btns = [gr.update(visible=False) for _ in range(MAX_SEGMENTS)]
        return (
            None, None,
            *empty_segments, *empty_texts, *empty_silences, *empty_btns,
            [], 24000
        )

    def _empty_segment_output(self):
        """Return only the hidden widget updates (no audio/path/sr tail);
        callers append their own trailing values."""
        empty_segments = [gr.update(visible=False, value=None) for _ in range(MAX_SEGMENTS)]
        empty_texts = [gr.update(visible=False, value="") for _ in range(MAX_SEGMENTS)]
        empty_silences = [gr.update(visible=False, value=0) for _ in range(MAX_SEGMENTS)]
        empty_btns = [gr.update(visible=False) for _ in range(MAX_SEGMENTS)]
        return empty_segments + empty_texts + empty_silences + empty_btns