import gradio as gr
from pathlib import Path
import soundfile as sf
import torch
import torchaudio
import hydra
from omegaconf import OmegaConf
import diffusers.schedulers as noise_schedulers
from utils.config import register_omegaconf_resolvers
from models.common import LoadPretrainedBase
from huggingface_hub import hf_hub_download
import fairseq
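
# Register the custom OmegaConf resolvers referenced by the config files,
# then pick the compute device.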
register_omegaconf_resolvers()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = OmegaConf.load("configs/infer.yaml")
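
# Fetch the pretrained STAR weights from the Hugging Face Hub
# (cached locally after the first download).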
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False,
)
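
# Instantiate the model from the experiment config, pointing it at the downloaded weights.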
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)
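
# Download the HuBERT checkpoint used as the speech content encoder and load it with fairseq.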
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False,
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)
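
# Build the diffusion noise scheduler class named in the inference config.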
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)
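
# Full inference pipeline: encode the input speech with HuBERT,
# then sample an audio scene from the diffusion model.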
@torch.no_grad()
def infer(audio_path: str) -> str:
    # Load the input speech and convert it to 16 kHz mono, the format HuBERT expects.
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)

    # Extract HuBERT content features from the speech.
    features, _ = hubert_model.extract_features(waveform_tts.to(device))

    # Assemble the sampling arguments from the config and generate the audio.
    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
    kwargs["content"] = [features]
    kwargs["condition"] = None
    kwargs["task"] = ["speech_to_audio"]
    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )

    # Write the generated waveform to disk and hand the path back to Gradio.
    output_file = "output_audio.wav"
    sf.write(
        output_file,
        waveform.squeeze().cpu().numpy(),
        samplerate=exp_config["sample_rate"],
    )
    return output_file
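
# Minimal Gradio front end: example input/output clips plus a one-button inference panel.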
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
    gr.Markdown("""
    <div style="text-align: left; padding: 10px;">

    ## 🗣️ Input
    A brief speech utterance describing the overall audio scene.
    > Example: A cat meowing and a young female speaking

    ### 🎙️ Input Speech Example
    """)
    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
    gr.Markdown("""
    <div style="text-align: left; padding: 10px;">

    ### 🎧️ Output Audio Example
    """)
    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
    gr.Markdown("""
    </div>

    ---
    </div>
    """)
    with gr.Column():
        input_audio = gr.Audio(label="Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")

    btn.click(fn=infer, inputs=input_audio, outputs=output_audio)

demo.launch()