# NOTE: removed page-scrape artifacts ("Spaces:" / "Runtime error" banner text)
# that preceded the source — they are not part of the program.
import os
import re
import io
import torch
import librosa
import zipfile
import requests
import torchaudio
import numpy as np
import gradio as gr
from uroman import uroman
import concurrent.futures
from pydub import AudioSegment
from datasets import load_dataset
from IPython.display import Audio  # NOTE(review): unused in this script — likely a notebook leftover
from scipy.signal import butter, lfilter
from speechbrain.pretrained import EncoderClassifier  # NOTE(review): renamed to speechbrain.inference in SpeechBrain >= 1.0 — confirm pinned version
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
# --- Configuration ---
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"  # x-vector speaker-encoder checkpoint
dataset_name = "truong-xuan-linh/vi-xvector-speechbrain"  # dataset of precomputed speaker embeddings
cache_dir = "temp/"  # working directory for the dataset cache and generated files
default_model_name = "truong-xuan-linh/speecht5-vietnamese-voiceclone-lsvsc"  # default TTS checkpoint
speaker_id = "speech_dataset_denoised"  # default speaker; later rebound to a Gradio Dropdown in the UI section
# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Load the speaker encoder, the embedding dataset, and the vocoder ---
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    run_opts={"device": device},
    savedir=os.path.join("/tmp", spk_model_name),
)

dataset = load_dataset(
    dataset_name,
    download_mode="force_redownload",
    verification_mode="no_checks",
    cache_dir=cache_dir,
    revision="5ea5e4345258333cbc6d1dd2544f6c658e66a634",
)
dataset = dataset["train"].to_list()

# Index the precomputed embeddings by speaker id for O(1) lookup at inference.
dataset_dict = {rc["speaker_id"]: rc["embedding"] for rc in dataset}

# HiFi-GAN vocoder that turns SpeechT5 spectrograms into waveforms.
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
| # Model utility functions | |
| def remove_special_characters(sentence): | |
| # Use regular expression to keep only letters, periods, and commas | |
| sentence_after_removal = re.sub(r'[^a-zA-Z\s,.\u00C0-\u1EF9]', ' ,', sentence) | |
| return sentence_after_removal | |
| def create_speaker_embedding(waveform): | |
| with torch.no_grad(): | |
| speaker_embeddings = speaker_model.encode_batch(waveform) | |
| speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=-1) | |
| return speaker_embeddings | |
| def butter_bandpass(lowcut, highcut, fs, order=5): | |
| nyq = 0.5 * fs | |
| low = lowcut / nyq | |
| high = highcut / nyq | |
| b, a = butter(order, [low, high], btype='band') | |
| return b, a | |
| def butter_bandpass_filter(data, lowcut, highcut, fs, order=5): | |
| b, a = butter_bandpass(lowcut, highcut, fs, order=order) | |
| y = lfilter(b, a, data) | |
| return y | |
| def korean_splitter(string): | |
| pattern = re.compile('[가-힣]+') | |
| matches = pattern.findall(string) | |
| return matches | |
| def uroman_normalization(string): | |
| korean_inputs = korean_splitter(string) | |
| for korean_input in korean_inputs: | |
| korean_roman = uroman(korean_input) | |
| string = string.replace(korean_input, korean_roman) | |
| return string | |
| # Model class | |
class Model:
    """SpeechT5 text-to-speech wrapper.

    Loads a SpeechT5 checkpoint plus its processor and resolves the speaker
    embedding either from a reference WAV (`speaker_url`) or — for the
    "voiceclone" checkpoints — from the module-level `dataset_dict` at
    inference time.
    """

    def __init__(self, model_name, speaker_url=""):
        """Load the checkpoint and prepare the speaker embedding.

        Parameters:
        - model_name: Hugging Face model id of the SpeechT5 checkpoint.
        - speaker_url: optional URL of a reference WAV used to compute a
          speaker x-vector; empty string means "no fixed reference".
        """
        self.model_name = model_name
        self.processor = SpeechT5Processor.from_pretrained(model_name)
        self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
        self.model.eval()
        self.speaker_url = speaker_url
        if speaker_url:
            # Build the speaker embedding from a reference recording:
            # download, force mono / 16 kHz / 16-bit PCM, then encode.
            print("download speaker_url")
            response = requests.get(speaker_url)
            audio_stream = io.BytesIO(response.content)
            audio_segment = AudioSegment.from_file(audio_stream, format="wav")
            audio_segment = audio_segment.set_channels(1)
            audio_segment = audio_segment.set_frame_rate(16000)
            audio_segment = audio_segment.set_sample_width(2)
            wavform, _ = torchaudio.load(audio_segment.export())
            self.speaker_embeddings = create_speaker_embedding(wavform)[0]
        else:
            self.speaker_embeddings = None
        if model_name in (
            "truong-xuan-linh/speecht5-vietnamese-commonvoice",
            "truong-xuan-linh/speecht5-irmvivoice",
        ):
            # These checkpoints take a fixed all-zero x-vector instead of a
            # real speaker embedding.
            self.speaker_embeddings = torch.zeros((1, 512))

    def inference(self, text, speaker_id=None):
        """Synthesize `text` and return the waveform as a 1-D numpy array.

        For "voiceclone" checkpoints without a fixed reference speaker, the
        embedding is looked up in the module-level `dataset_dict` by
        `speaker_id` on every call (raises KeyError for unknown ids).
        """
        if "voiceclone" in self.model_name:
            if not self.speaker_url:
                self.speaker_embeddings = torch.tensor(dataset_dict[speaker_id])
        with torch.no_grad():
            full_speech = []
            separators = r";|\.|!|\?|\n"
            text = uroman_normalization(text)
            text = remove_special_characters(text)
            # The tokenizer uses "▁" as the word separator.
            text = text.replace(" ", "▁")
            for split_text in re.split(separators, text):
                # FIX: previously only the literal "▁" was skipped, so empty
                # segments (from "...", "!?", trailing punctuation, etc.)
                # were synthesized as audio for a lone separator token.
                if not split_text.strip("▁"):
                    continue
                split_text = split_text.lower() + "▁"
                print(split_text)
                inputs = self.processor.tokenizer(text=split_text, return_tensors="pt")
                speech = self.model.generate_speech(
                    inputs["input_ids"],
                    threshold=0.5,
                    speaker_embeddings=self.speaker_embeddings,
                    vocoder=vocoder,
                )
                full_speech.append(speech.numpy())
        # FIX: np.concatenate raises on an empty list; return empty audio
        # instead when no segment produced speech.
        if not full_speech:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(full_speech)
| def moving_average(data, window_size): | |
| return np.convolve(data, np.ones(window_size)/window_size, mode='same') | |
# Initialize the TTS model once at import time (downloads the checkpoint).
# NOTE(review): an identical `model = Model(...)` appears again near the
# bottom of this script, loading the checkpoint a second time.
model = Model(
    model_name=default_model_name,
    speaker_url=""
)
| # Audio processing functions | |
| def read_srt(file_path): | |
| subtitles = [] | |
| with open(file_path, 'r', encoding='utf-8') as file: | |
| lines = file.readlines() | |
| for i in range(0, len(lines), 4): | |
| if i+2 < len(lines): | |
| start_time, end_time = lines[i+1].strip().split('-->') | |
| start_time = start_time.strip() | |
| end_time = end_time.strip() | |
| text = lines[i+2].strip() | |
| # Delete trailing dots | |
| while text.endswith('.'): | |
| text = text[:-1] | |
| subtitles.append((start_time, end_time, text)) | |
| return subtitles | |
| def is_valid_srt(file_path): | |
| try: | |
| read_srt(file_path) | |
| return True | |
| except: | |
| return False | |
| def time_to_seconds(time_str): | |
| h, m, s = time_str.split(':') | |
| seconds = int(h) * 3600 + int(m) * 60 + float(s.replace(',', '.')) | |
| return seconds | |
| def closest_speedup_factor(factor, allowed_factors): | |
| return min(allowed_factors, key=lambda x: abs(x - factor)) + 0.1 | |
| def lowpass_filter(audio_data, cutoff=4000, fs=16000, order=4): | |
| """ | |
| Áp dụng bộ lọc thông thấp cho dữ liệu âm thanh. | |
| Parameters: | |
| - audio_data: numpy array chứa dữ liệu âm thanh. | |
| - cutoff: Tần số cắt (Hz). | |
| - fs: Tần số lấy mẫu (Hz). | |
| - order: Bậc của bộ lọc. | |
| Returns: | |
| - filtered_audio: numpy array của âm thanh đã được lọc. | |
| """ | |
| # Tạo bộ lọc butterworth | |
| nyq = 0.5 * fs | |
| normal_cutoff = cutoff / nyq | |
| b, a = butter(order, normal_cutoff, btype='low', analog=False) | |
| # Áp dụng bộ lọc | |
| filtered_audio = lfilter(b, a, audio_data) | |
| return filtered_audio | |
| def generate_audio_with_pause(srt_file_path, speaker_id, speed_of_non_edit_speech): | |
| subtitles = read_srt(srt_file_path) | |
| audio_clips = [] | |
| # allowed_factors = [1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] | |
| for i, (start_time, end_time, text) in enumerate(subtitles): | |
| # print("=====================================") | |
| # print("Text number:", i) | |
| # print(f"Start: {start_time}, End: {end_time}, Text: {text}") | |
| # Generate initial audio | |
| audio_data = model.inference(text=text, speaker_id=speaker_id) | |
| audio_data = audio_data / np.max(np.abs(audio_data)) | |
| # Calculate required duration | |
| desired_duration = time_to_seconds(end_time) - time_to_seconds(start_time) | |
| current_duration = len(audio_data) / 16000 | |
| # print(f"Time to seconds: {time_to_seconds(start_time)}, {time_to_seconds(end_time)}") | |
| # print(f"Desired duration: {desired_duration}, Current duration: {current_duration}") | |
| # Adjust audio speed by speedup | |
| if current_duration > desired_duration: | |
| raw_speedup_factor = current_duration / desired_duration | |
| # speedup_factor = closest_speedup_factor(raw_speedup_factor, allowed_factors) | |
| speedup_factor = raw_speedup_factor | |
| audio_data = librosa.effects.time_stretch( | |
| y=audio_data, | |
| rate=speedup_factor, | |
| n_fft=1024, | |
| hop_length=256 | |
| ) | |
| audio_data = audio_data / np.max(np.abs(audio_data)) | |
| audio_data = audio_data * 1.2 | |
| if current_duration < desired_duration: | |
| if speed_of_non_edit_speech != 1: | |
| audio_data = librosa.effects.time_stretch( | |
| y=audio_data, | |
| rate=speed_of_non_edit_speech, | |
| n_fft=1024, | |
| hop_length=256 | |
| ) | |
| audio_data = audio_data / np.max(np.abs(audio_data)) | |
| audio_data = audio_data * 1.2 | |
| current_duration = len(audio_data) / 16000 | |
| padding = int((desired_duration - current_duration) * 16000) | |
| audio_data = np.concatenate([np.zeros(padding), audio_data]) | |
| # print(f"Final audio duration: {len(audio_data) / 16000}") | |
| # print("=====================================") | |
| audio_clips.append(lowpass_filter(audio_data)) | |
| # Add pause | |
| if i < len(subtitles) - 1: | |
| next_start_time = subtitles[i + 1][0] | |
| pause_duration = time_to_seconds(next_start_time) - time_to_seconds(end_time) | |
| if pause_duration: | |
| pause_samples = int(pause_duration * 16000) | |
| audio_clips.append(np.zeros(pause_samples)) | |
| final_audio = np.concatenate(audio_clips) | |
| return final_audio | |
| def check_input_files(srt_files): | |
| if not srt_files: | |
| return None | |
| invalid_files = [] | |
| for srt_file in srt_files: | |
| if not is_valid_srt(srt_file.name): | |
| invalid_files.append(srt_file.name) | |
| if invalid_files: | |
| raise gr.Warning(f"Invalid SRT files: {', '.join(invalid_files)}") | |
| def srt_to_audio_multi(srt_files, speaker_id, speed_of_non_edit_speech): | |
| output_paths = [] | |
| invalid_files = [] | |
| def process_file(srt_file): | |
| if not is_valid_srt(srt_file.name): | |
| invalid_files.append(srt_file.name) | |
| return None | |
| audio_data = generate_audio_with_pause(srt_file.name, speaker_id, speed_of_non_edit_speech) | |
| output_path = os.path.join(cache_dir, f'output_{os.path.basename(srt_file.name)}.wav') | |
| torchaudio.save(output_path, torch.tensor(audio_data).unsqueeze(0), 16000) | |
| return output_path | |
| with concurrent.futures.ThreadPoolExecutor() as executor: | |
| futures = [executor.submit(process_file, srt_file) for srt_file in srt_files] | |
| for future in concurrent.futures.as_completed(futures): | |
| result = future.result() | |
| if result: | |
| output_paths.append(result) | |
| if invalid_files: | |
| raise gr.Warning(f"Invalid SRT files: {', '.join(invalid_files)}") | |
| return output_paths | |
| def download_all(outputs): | |
| # If no outputs, return None | |
| if not outputs: | |
| raise gr.Warning("No files available for download.") | |
| zip_path = os.path.join(cache_dir, "all_outputs.zip") | |
| with zipfile.ZipFile(zip_path, 'w') as zipf: | |
| for file_path in outputs: | |
| zipf.write(file_path, os.path.basename(file_path)) | |
| return zip_path | |
# NOTE(review): `model` was already instantiated with identical arguments
# earlier in this script; re-creating it here loaded the checkpoint a second
# time for no benefit. Guard so the expensive constructor runs at most once.
if "model" not in globals():
    model = Model(
        model_name=default_model_name,
        speaker_url=""
    )
# --- Gradio UI ---
css = '''
#title{text-align: center}
#container{display: flex; justify-content: space-between; align-items: center;}
#setting-box{padding: 10px; border: 1px solid #ccc; border-radius: 5px;}
#setting-heading{margin-bottom: 10px; text-align: center;}
'''
with gr.Blocks(css=css) as demo:
    # Page title.
    title = gr.HTML(
        """<h1>SRT to Audio Tool</h1>""",
        elem_id="title",
    )
    # Settings: speaker selection and base speed for clips shorter than their slot.
    with gr.Column(elem_id="setting-box"):
        heading = gr.HTML("<h2>Settings</h2>", elem_id="setting-heading")
        with gr.Row():
            # NOTE(review): this rebinds the module-level `speaker_id` string
            # to a Dropdown component — the string is only used as the
            # initial value here.
            speaker_id = gr.Dropdown(
                label="Speaker ID",
                choices=list(dataset_dict.keys()),
                value=speaker_id
            )
            speed_of_non_edit_speech = gr.Slider(
                label="Speed of non-edit speech",
                minimum=1,
                maximum=2.0,
                step=0.1,
                value=1.2
            )
    # Input SRT files on the left, generated audio files on the right.
    with gr.Row(elem_id="container"):
        inp_srt = gr.File(
            label="Upload SRT files",
            file_count="multiple",
            type="filepath",
            file_types=["srt"],
            height=600
        )
        out = gr.File(
            label="Generated Audio Files",
            file_count="multiple",
            type="filepath",
            height=600,
            interactive=False
        )
    btn = gr.Button("Generate")
    download_btn = gr.Button("Download All")
    download_out = gr.File(
        label="Download ZIP",
        interactive=False,
        height=100
    )
    # Wire events: validate on upload, synthesize on click, zip on demand.
    inp_srt.change(check_input_files, inputs=inp_srt)
    btn.click(
        fn=srt_to_audio_multi,
        inputs=[inp_srt, speaker_id, speed_of_non_edit_speech],
        outputs=out
    )
    download_btn.click(fn=download_all, inputs=out, outputs=download_out)
# Launch the app only when run as a script (Spaces entry point).
if __name__ == "__main__":
    demo.launch()