import gradio as gr
import soundfile as sf
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
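
# Force every torch.load call to map tensors onto the active device, so
# CUDA-saved checkpoints also load on CPU-only hosts. Downstream loaders
# (fairseq, hydra-instantiated modules) all go through this wrapper.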
_old_load = torch.load


def safe_torch_load(*args, **kwargs):
    args = list(args)
    if len(args) >= 2:
        # torch.load(f, map_location, ...): map_location is the second
        # positional argument, so override it in place.
        args[1] = device
    else:
        kwargs['map_location'] = device
    return _old_load(*args, **kwargs)


torch.load = safe_torch_load

import torchaudio
import hydra
from omegaconf import OmegaConf
import diffusers.schedulers as noise_schedulers

from utils.config import register_omegaconf_resolvers
from models.common import LoadPretrainedBase

from huggingface_hub import hf_hub_download
import fairseq

register_omegaconf_resolvers()
config = OmegaConf.load("configs/infer.yaml")
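
# Fetch the pretrained STAR checkpoint from the Hugging Face Hub (served
# from the local cache when already downloaded).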
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False,
)

exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
    # Point the experiment config at the checkpoint downloaded above.
    exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)
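
# Fetch the HuBERT speech encoder used to extract content features from the
# input utterance.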
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False,
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)
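
# Build the diffusion noise scheduler class named in the inference config.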
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)


@torch.no_grad()
def infer(audio_path: str) -> str:
    # Normalize the input speech to 16 kHz mono, the format HuBERT expects.
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
    # Gradients are already disabled by the @torch.no_grad() decorator.
    features, _ = hubert_model.extract_features(waveform_tts.to(device))

    # Assemble inference arguments: HuBERT features as content, no extra
    # condition, and the speech-to-audio task tag.
    kwargs = OmegaConf.to_container(config["infer_args"], resolve=True)
    kwargs["content"] = [features]
    kwargs["condition"] = None
    kwargs["task"] = ["speech_to_audio"]

    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )

    output_file = "output_audio.wav"
    sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])
    return output_file
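

# ---- Gradio UI ----
# A minimal Blocks layout: intro text, one upload-and-generate column, and
# clickable example tables.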
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Glass()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 📚️ Introduction

STAR is the first end-to-end speech-to-audio generation framework, designed to improve efficiency and avoid the error propagation inherent in cascaded systems.
In this space, you can drive the model directly with voice input and generate the corresponding audio output.

## 🗣️ Input

A brief speech utterance describing the overall audio scene.

> Example: A cat meowing and young female speaking

### 🎙️ Input Speech Example

</div>
""")

    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 🎧️ Output

The model captures both auditory events and scene cues and generates the corresponding audio.

### 🔊 Output Audio Example

</div>
""")

    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")

    gr.Markdown("""
---

## 🛠️ Online Inference

You can upload your own samples, or try the quick examples below.
""")

    with gr.Column():
        input_audio = gr.Audio(label="Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")
        # Hidden textbox that receives the caption column of the example
        # tables below (gr.Examples needs one input component per column).
        caption = gr.Textbox(label="Caption", visible=False)
        btn.click(fn=infer, inputs=input_audio, outputs=output_audio)

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 🎯 Quick Examples

</div>
""")

    with gr.Tabs():
        with gr.Tab("VITS Generated Speech"):
            gr.Markdown("| 🎧 Audio | 📝 Caption |\n|:--:|:--|")
            gr.Examples(
                examples=[
                    ["wav/vits/1.wav", "A cat meowing and young female speaking"],
                    ["wav/vits/2.wav", "Sustained industrial engine noise"],
                    ["wav/vits/3.wav", "A woman talks and a baby whispers"],
                    ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/vits/6.wav", "A man speaking as birds are chirping"],
                    ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/vits/9.wav", "Several church bells ringing"],
                    ["wav/vits/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=5,
            )

        with gr.Tab("Real Human Speech"):
            gr.Markdown("| 🎧 Audio | 📝 Caption |\n|:--:|:--|")
            gr.Examples(
                examples=[
                    ["wav/human/1.wav", "A cat meowing and young female speaking"],
                    ["wav/human/2.wav", "Sustained industrial engine noise"],
                    ["wav/human/3.wav", "A woman talks and a baby whispers"],
                    ["wav/human/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/human/6.wav", "A man speaking as birds are chirping"],
                    ["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/human/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/human/9.wav", "Several church bells ringing"],
                    ["wav/human/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=5,
            )


demo.launch()