File size: 3,555 Bytes
93c90ef
4853fdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93c90ef
 
 
 
 
 
4853fdc
93c90ef
 
 
 
4853fdc
 
 
 
 
 
 
 
 
93c90ef
4853fdc
 
 
 
93c90ef
4853fdc
93c90ef
 
 
4853fdc
8193575
4853fdc
 
 
 
 
 
 
93c90ef
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import tempfile
from pathlib import Path

import diffusers.schedulers as noise_schedulers
import fairseq
import gradio as gr
import hydra
import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from models.common import LoadPretrainedBase
from utils.config import register_omegaconf_resolvers

# --- One-time module-level setup: configs, checkpoints, models, scheduler. ---
# Everything below runs at import time and defines the globals used by infer():
# device, config, exp_config, model, hubert_model, scheduler.

# Register custom OmegaConf resolvers before any config is loaded/resolved.
register_omegaconf_resolvers()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Inference-time settings (noise scheduler choice, infer_args for sampling).
config = OmegaConf.load("configs/infer.yaml")

# Fetch the STAR model weights from the Hugging Face Hub (cached locally;
# force_download=False reuses a previous download).
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",   
    filename="model.safetensors",  
    repo_type="model",          
    force_download=False
)

# Experiment config describing the model architecture; point its
# pretrained_ckpt entry (when present) at the checkpoint we just downloaded
# so hydra builds the model with those weights.
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
    exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])

model = model.to(device)

# Fetch the HuBERT-large checkpoint (ckpt_path is reused for this second
# download) and load it through fairseq; only the first ensemble member is
# used, in eval mode, for speech feature extraction.
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR", 
    filename="hubert_large_ll60k.pt",           
    repo_type="model",             
    force_download=False
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)

# Build the diffusion noise scheduler: the class name comes from
# config["noise_scheduler"]["type"] (looked up in diffusers.schedulers) and
# its parameters from the pretrained repo named in ["name"].
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)

@torch.no_grad()
def infer(audio_path: str) -> str:
    """Generate an audio scene from an input speech recording.

    Loads the audio at ``audio_path``, converts it to 16 kHz mono, extracts
    HuBERT features, and runs the STAR diffusion model (with the module-level
    ``scheduler``) to synthesize the output waveform.

    Args:
        audio_path: Path to the input speech file (any format torchaudio reads).

    Returns:
        Path to a newly written WAV file containing the generated audio.
    """
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    # HuBERT expects 16 kHz mono input: resample and downmix if needed.
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
    # No inner no_grad needed: the decorator already disables gradients.
    features, _ = hubert_model.extract_features(waveform_tts.to(device))

    # Resolve the configured sampling arguments into a plain dict, then add
    # the per-request content/condition/task entries the model expects.
    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
    kwargs["content"] = [features]
    kwargs["condition"] = None
    kwargs["task"] = ["speech_to_audio"]

    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )

    # Write to a unique temporary file instead of a fixed "output_audio.wav":
    # with a shared filename, concurrent Gradio requests would overwrite each
    # other's results. delete=False keeps the file for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_file = tmp.name
    sf.write(
        output_file,
        waveform.squeeze().cpu().numpy(),
        samplerate=exp_config["sample_rate"],
    )

    return output_file

# --- Gradio UI: static example section followed by the interactive demo. ---
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
    
    # Static explanation of the expected input, with an example prompt.
    gr.Markdown("""

<div style="text-align: left; padding: 10px;">

                

## 🗣️ Input



A brief input speech utterance for the overall audio scene.  



> Example:A cat meowing and young female speaking



### 🎙️ Input Speech Example

""")
    
    # Bundled example input; type="filepath" hands the player a path on disk.
    speech = gr.Audio(value="wav/speech.wav",  label="Input Speech Example", type="filepath")

    gr.Markdown("""

<div style="text-align: left; padding: 10px;">



### 🎧️ Output Audio Example

""")
    
    # Pre-generated example output matching the example speech above.
    audio = gr.Audio(value="wav/audio.wav",  label="Generated Audio Example", type="filepath")

    gr.Markdown("""

</div>

---

</div>

""")       

    # Interactive section: upload/record speech, run infer(), play the result.
    with gr.Column():
        input_audio = gr.Audio(label="Speech Input", type="filepath")
        btn = gr.Button("🎵Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")
        # infer() takes the input file path and returns the generated WAV path.
        btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
              
demo.launch()