| import os |
| import torch |
| import gradio as gr |
| import soundfile as sf |
| import numpy as np |
| import torchaudio |
| from omnivoice import OmniVoice |
|
|
| |
def safe_load(path, *args, **kwargs):
    """Read an audio file with ``soundfile`` and return ``(waveform, sample_rate)``.

    Any extra positional or keyword arguments are accepted but ignored, so
    the function can be called with a wider signature than it actually uses.

    Args:
        path: Location of the audio file, passed straight to ``sf.read``.

    Returns:
        A ``(tensor, sample_rate)`` pair. The tensor is float32; data with
        more than one dimension is averaged over axis 1, and a leading
        dimension of size 1 is then added.
    """
    samples, sample_rate = sf.read(path)
    audio = torch.tensor(samples, dtype=torch.float32)

    # Presumably axis 1 is soundfile's channel axis, so this is a mono
    # down-mix — TODO confirm against the soundfile documentation.
    mono = audio if audio.ndim <= 1 else audio.mean(dim=1)

    return torch.unsqueeze(mono, 0), sample_rate
|
|
# Route every torchaudio.load call through the soundfile-backed safe_load.
torchaudio.load = safe_load

# Point the Hugging Face caches at /tmp.
# NOTE(review): these are assigned after `omnivoice` has already been imported
# at the top of the file. Presumably huggingface_hub resolves its cache paths
# at import time — confirm these values take effect, and if not, move the
# assignments above the third-party imports.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading model...")
model = OmniVoice.from_pretrained(
    "k2-fsa/OmniVoice",
    device_map=device,
    # Half precision on the GPU, full precision on the CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
print("✅ Model loaded")
|
|
| |
def clone_voice(audio_file, text, lang, ref_text):
    """Generate speech for ``text`` using the uploaded clip as the reference voice.

    The reference clip is loaded as a single-channel tensor, resampled to
    24 kHz when needed, and written to a temporary WAV file whose path is
    handed to ``model.generate``. The generated audio is written to a second
    temporary WAV file and that path is returned.

    Args:
        audio_file: Path of the uploaded reference recording; falsy when
            nothing was uploaded.
        text: Text to synthesise.
        lang: Language code. It is used both as a ``[lang]`` prefix on the
            text and as the ``language`` argument of ``model.generate``.
        ref_text: Optional transcript of the reference clip; an empty value
            is passed to the model as ``None``.

    Returns:
        Path of a WAV file (24 kHz) containing the generated speech.

    Raises:
        gr.Error: If no reference audio or no text was supplied.
    """
    import tempfile  # only this function needs it

    # Rate the reference clip is resampled to and the output is written at.
    target_sr = 24000

    # Fail early with a message Gradio can show instead of an opaque error
    # from the audio reader or the model.
    if not audio_file:
        raise gr.Error("Please upload a voice sample.")
    if not text or not text.strip():
        raise gr.Error("Please enter the text to synthesise.")

    # safe_load returns a float tensor with a leading dimension of size 1.
    waveform, sr = safe_load(audio_file)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)

    # A unique temp file per call: a fixed name in the working directory
    # would be overwritten by concurrent requests.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as ref_tmp:
        ref_path = ref_tmp.name
    try:
        sf.write(ref_path, waveform.squeeze(0).cpu().numpy(), target_sr)
        audio = model.generate(
            text=f"[{lang}] {text}",
            ref_audio=ref_path,
            ref_text=ref_text or None,
            language=lang,
        )
    finally:
        # The reference copy is only needed while the model runs.
        os.remove(ref_path)

    if isinstance(audio, list):
        audio = audio[0]
    if not isinstance(audio, torch.Tensor):
        audio = torch.tensor(audio)

    # The model is loaded in float16 on CUDA; cast to float32 because
    # soundfile may not accept half-precision arrays (assumes the output
    # can be float16 — TODO confirm).
    samples = audio.detach().float().squeeze().cpu().numpy()

    # The output file is not deleted here: its path is returned to the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out_tmp:
        output_file = out_tmp.name
    sf.write(output_file, samples, target_sr)

    return output_file
|
|
|
|
| |
# Gradio UI: the four inputs are passed positionally to clone_voice
# (audio_file, text, lang, ref_text) and its returned path feeds the output.
demo = gr.Interface(
    fn=clone_voice,
    inputs=[
        gr.Audio(type="filepath", label="Upload Voice Sample"),
        gr.Textbox(label="Enter Text"),
        gr.Textbox(label="Language Code (en, hi, etc)"),
        gr.Textbox(label="Reference Text (optional)"),
    ],
    outputs=gr.Audio(label="Generated Voice"),
    title="🎤 OmniVoice Cloner",
    description="Upload voice → enter text → language → generate cloned speech",
)

demo.launch()