import diffusers.schedulers as noise_schedulers
import fairseq
import gradio as gr
import hydra
import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from models.common import LoadPretrainedBase
from utils.config import register_omegaconf_resolvers
register_omegaconf_resolvers()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference settings (noise scheduler choice, sampling arguments, ...).
config = OmegaConf.load("configs/infer.yaml")
# Fetch the pretrained STAR checkpoint from the Hugging Face Hub.
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False,
)

# Build the model from the experiment config, pointing it at the downloaded
# checkpoint before instantiation.
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
    exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)
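
# For reference: `hydra.utils.instantiate` expects the config node to carry a
# `_target_` key naming the class to construct. The actual layout of
# configs/config.yaml is not shown in this file, so the keys below are an
# illustrative assumption, not the shipped STAR config:
#
#   model:
#     _target_: models.star.STARModel   # hypothetical class path
#     pretrained_ckpt: null             # overwritten above with ckpt_path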
# Fetch the HuBERT checkpoint used to extract speech content features.
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False,
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [ckpt_path]
)
hubert_model = hubert_models[0].eval().to(device)
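
# Background note (general HuBERT-large behavior, not stated in this repo):
# fairseq's extract_features returns (features, padding_mask), where features
# has shape (batch, num_frames, 1024) at a 50 Hz frame rate, i.e. one frame
# per 20 ms of 16 kHz audio.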
# Resolve the diffusion noise scheduler class by name from diffusers and
# load its configuration from the Hub.
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)
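
# The noise_scheduler section of configs/infer.yaml is not shown here; a
# minimal sketch of the shape this code assumes (values are illustrative
# placeholders, not the shipped config):
#
#   noise_scheduler:
#     type: DDIMScheduler        # any class exported by diffusers.schedulers
#     name: some-org/some-repo   # a Hub repo containing a scheduler/ subfolder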
@torch.no_grad()
def infer(audio_path: str) -> str:
    """Generate an audio scene from an input speech utterance."""
    # Load the speech and normalize it to 16 kHz mono, as expected by HuBERT.
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)

    # Extract HuBERT content features from the speech.
    features, _ = hubert_model.extract_features(waveform_tts.to(device))

    # Assemble inference arguments from the config and run the diffusion model.
    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
    kwargs["content"] = [features]
    kwargs["condition"] = None
    kwargs["task"] = ["speech_to_audio"]
    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )

    # Write the generated waveform to disk and hand the path back to Gradio.
    output_file = "output_audio.wav"
    sf.write(
        output_file,
        waveform.squeeze().cpu().numpy(),
        samplerate=exp_config["sample_rate"],
    )
    return output_file
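
# A minimal local smoke test, assuming the example clip wav/speech.wav bundled
# with this Space; uncomment to run inference once without the Gradio UI:
#
#   out_path = infer("wav/speech.wav")
#   print(f"Generated audio written to {out_path}")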
# Gradio UI: an example speech/audio pair plus an interactive inference panel.
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
    gr.Markdown("""
## 🗣️ Input

A brief speech utterance describing the overall audio scene.

> Example: A cat meowing and a young female speaking

### 🎙️ Input Speech Example
""")
    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
    gr.Markdown("""
### 🎧 Output Audio Example
""")
    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
    gr.Markdown("---")
    with gr.Column():
        input_audio = gr.Audio(label="Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")
    btn.click(fn=infer, inputs=input_audio, outputs=output_audio)

demo.launch()