wsntxxn committed
Commit d64def4 · verified · 1 Parent(s): e7b2677

Update app.py

Files changed (1): app.py (+198 -196)
app.py CHANGED
@@ -1,197 +1,199 @@
- import gradio as gr
- from pathlib import Path
-
- import soundfile as sf
-
- # forcing torch.load to CPU
- import torch
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- _old_load = torch.load
-
- def safe_torch_load(*args, **kwargs):
-     args = list(args)
-     if len(args) >= 2:
-         args[1] = device
-     else:
-         kwargs['map_location'] = device
-     return _old_load(*args, **kwargs)
-
- torch.load = safe_torch_load
-
- import torchaudio
- import hydra
- from omegaconf import OmegaConf
- import diffusers.schedulers as noise_schedulers
-
- from utils.config import register_omegaconf_resolvers
- from models.common import LoadPretrainedBase
-
- from huggingface_hub import hf_hub_download
- import fairseq
-
- register_omegaconf_resolvers()
- config = OmegaConf.load("configs/infer.yaml")
-
- ckpt_path = hf_hub_download(
-     repo_id="assasinatee/STAR",
-     filename="model.safetensors",
-     repo_type="model",
-     force_download=False
- )
-
- exp_config = OmegaConf.load("configs/config.yaml")
- if "pretrained_ckpt" in exp_config["model"]:
-     exp_config["model"]["pretrained_ckpt"] = ckpt_path
- model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
-
- model = model.to(device)
-
- ckpt_path = hf_hub_download(
-     repo_id="assasinatee/STAR",
-     filename="hubert_large_ll60k.pt",
-     repo_type="model",
-     force_download=False
- )
- hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
- hubert_model = hubert_models[0].eval().to(device)
-
- scheduler = getattr(
-     noise_schedulers,
-     config["noise_scheduler"]["type"],
- ).from_pretrained(
-     config["noise_scheduler"]["name"],
-     subfolder="scheduler",
- )
-
- @torch.no_grad()
- def infer(audio_path: str) -> str:
-     waveform_tts, sample_rate = torchaudio.load(audio_path)
-     if sample_rate != 16000:
-         waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts)
-     if waveform_tts.shape[0] > 1:
-         waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
-     with torch.no_grad():
-         features, _ = hubert_model.extract_features(waveform_tts.to(device))
-
-     kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
-     kwargs['content'] = [features]
-     kwargs['condition'] = None
-     kwargs['task'] = ["speech_to_audio"]
-
-     model.eval()
-     waveform = model.inference(
-         scheduler=scheduler,
-         **kwargs,
-     )
-
-     output_file = "output_audio.wav"
-     sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])
-
-     return output_file
-
- with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
-
-     gr.Markdown("""
-     <div style="text-align: left; padding: 10px;">
-
-     ## 📚️ Introduction
-
-     STAR is the first end-to-end speech-to-audio generation framework, designed to enhance efficiency and address error propagation inherent in cascaded systems.
-
-     Within this space, you have the opportunity to directly control our model through voice input, thereby generating the corresponding audio output.
-
-     ## 🗣️ Input
-
-     A brief input speech utterance for the overall audio scene.
-
-     > Example:A cat meowing and young female speaking
-
-     ### 🎙️ Input Speech Example
-     """)
-
-     speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
-
-     gr.Markdown("""
-     <div style="text-align: left; padding: 10px;">
-
-     ## 🎧️ Output
-
-     Capture both auditory events and scene cues and generate corresponding audio
-
-     ### 🔊 Output Audio Example
-     """)
-
-     audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
-
-     gr.Markdown("""
-     <div style="text-align: left; padding: 10px;">
-
-     </div>
-
-     ---
-
-     </div>
-
-     ## 🛠️ Online Inference
-
-     You can upload your own samples, or try the quick examples provided below.
-     """)
-
-     with gr.Column():
-         input_audio = gr.Audio(label="🗣️ Speech Input", type="filepath")
-         btn = gr.Button("🎵Generate Audio!", variant="primary")
-         output_audio = gr.Audio(label="🎧️ Generated Audio", type="filepath")
-         btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
-
-     gr.Markdown("""
-     <div style="text-align: left; padding: 10px;">
-
-     ## 🎯 Quick Examples
-     """)
-
-     display_caption = gr.Textbox(label="📝 Caption" ,visible=False)
-
-     with gr.Tabs():
-         with gr.Tab("VITS Generated Speech"):
-             gr.Examples(
-                 examples=[
-                     ["wav/vits/1.wav", "A cat meowing and young female speaking"],
-                     ["wav/vits/2.wav", "Sustained industrial engine noise"],
-                     ["wav/vits/3.wav", "A woman talks and a baby whispers"],
-                     ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
-                     ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
-                     ["wav/vits/6.wav", "A man speaking as birds are chirping"],
-                     ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
-                     ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
-                     ["wav/vits/9.wav", "Several church bells ringing"],
-                     ["wav/vits/10.wav", "A telephone rings with bell sounds"]
-                 ],
-                 inputs=[input_audio, display_caption],
-                 label="Click examples below to try!",
-                 cache_examples = False,
-                 examples_per_page = 10,
-             )
-
-         with gr.Tab("Real Human Speech"):
-             gr.Examples(
-                 examples=[
-                     ["wav/human/1.wav", "A cat meowing and young female speaking"],
-                     ["wav/human/2.wav", "Sustained industrial engine noise"],
-                     ["wav/human/3.wav", "A woman talks and a baby whispers"],
-                     ["wav/human/4.wav", "A man speaks followed by a toilet flush"],
-                     ["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
-                     ["wav/human/6.wav", "A man speaking as birds are chirping"],
-                     ["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
-                     ["wav/human/8.wav", "Birds chirping and a horse neighing"],
-                     ["wav/human/9.wav", "Several church bells ringing"],
-                     ["wav/human/10.wav", "A telephone rings with bell sounds"]
-                 ],
-                 inputs=[input_audio, display_caption],
-                 label="Click examples below to try!",
-                 cache_examples = False,
-                 examples_per_page = 10,
-             )
-
-
+ import gradio as gr
+ from pathlib import Path
+
+ import soundfile as sf
+
+ import spaces
+ # forcing torch.load to CPU
+ import torch
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ _old_load = torch.load
+
+ def safe_torch_load(*args, **kwargs):
+     args = list(args)
+     if len(args) >= 2:
+         args[1] = device
+     else:
+         kwargs['map_location'] = device
+     return _old_load(*args, **kwargs)
+
+ torch.load = safe_torch_load
+
+ import torchaudio
+ import hydra
+ from omegaconf import OmegaConf
+ import diffusers.schedulers as noise_schedulers
+
+ from utils.config import register_omegaconf_resolvers
+ from models.common import LoadPretrainedBase
+
+ from huggingface_hub import hf_hub_download
+ import fairseq
+
+ register_omegaconf_resolvers()
+ config = OmegaConf.load("configs/infer.yaml")
+
+ ckpt_path = hf_hub_download(
+     repo_id="assasinatee/STAR",
+     filename="model.safetensors",
+     repo_type="model",
+     force_download=False
+ )
+
+ exp_config = OmegaConf.load("configs/config.yaml")
+ if "pretrained_ckpt" in exp_config["model"]:
+     exp_config["model"]["pretrained_ckpt"] = ckpt_path
+ model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
+
+ model = model.to(device)
+
+ ckpt_path = hf_hub_download(
+     repo_id="assasinatee/STAR",
+     filename="hubert_large_ll60k.pt",
+     repo_type="model",
+     force_download=False
+ )
+ hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
+ hubert_model = hubert_models[0].eval().to(device)
+
+ scheduler = getattr(
+     noise_schedulers,
+     config["noise_scheduler"]["type"],
+ ).from_pretrained(
+     config["noise_scheduler"]["name"],
+     subfolder="scheduler",
+ )
+
+ @torch.no_grad()
+ @spaces.GPU(duration=60)
+ def infer(audio_path: str) -> str:
+     waveform_tts, sample_rate = torchaudio.load(audio_path)
+     if sample_rate != 16000:
+         waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts)
+     if waveform_tts.shape[0] > 1:
+         waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
+     with torch.no_grad():
+         features, _ = hubert_model.extract_features(waveform_tts.to(device))
+
+     kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
+     kwargs['content'] = [features]
+     kwargs['condition'] = None
+     kwargs['task'] = ["speech_to_audio"]
+
+     model.eval()
+     waveform = model.inference(
+         scheduler=scheduler,
+         **kwargs,
+     )
+
+     output_file = "output_audio.wav"
+     sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])
+
+     return output_file
+
+ with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
+
+     gr.Markdown("""
+     <div style="text-align: left; padding: 10px;">
+
+     ## 📚️ Introduction
+
+     STAR is the first end-to-end speech-to-audio generation framework, designed to enhance efficiency and address error propagation inherent in cascaded systems.
+
+     Within this space, you have the opportunity to directly control our model through voice input, thereby generating the corresponding audio output.
+
+     ## 🗣️ Input
+
+     A brief input speech utterance for the overall audio scene.
+
+     > Example:A cat meowing and young female speaking
+
+     ### 🎙️ Input Speech Example
+     """)
+
+     speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
+
+     gr.Markdown("""
+     <div style="text-align: left; padding: 10px;">
+
+     ## 🎧️ Output
+
+     Capture both auditory events and scene cues and generate corresponding audio
+
+     ### 🔊 Output Audio Example
+     """)
+
+     audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
+
+     gr.Markdown("""
+     <div style="text-align: left; padding: 10px;">
+
+     </div>
+
+     ---
+
+     </div>
+
+     ## 🛠️ Online Inference
+
+     You can upload your own samples, or try the quick examples provided below.
+     """)
+
+     with gr.Column():
+         input_audio = gr.Audio(label="🗣️ Speech Input", type="filepath")
+         btn = gr.Button("🎵Generate Audio!", variant="primary")
+         output_audio = gr.Audio(label="🎧️ Generated Audio", type="filepath")
+         btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
+
+     gr.Markdown("""
+     <div style="text-align: left; padding: 10px;">
+
+     ## 🎯 Quick Examples
+     """)
+
+     display_caption = gr.Textbox(label="📝 Caption" ,visible=False)
+
+     with gr.Tabs():
+         with gr.Tab("VITS Generated Speech"):
+             gr.Examples(
+                 examples=[
+                     ["wav/vits/1.wav", "A cat meowing and young female speaking"],
+                     ["wav/vits/2.wav", "Sustained industrial engine noise"],
+                     ["wav/vits/3.wav", "A woman talks and a baby whispers"],
+                     ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
+                     ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
+                     ["wav/vits/6.wav", "A man speaking as birds are chirping"],
+                     ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
+                     ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
+                     ["wav/vits/9.wav", "Several church bells ringing"],
+                     ["wav/vits/10.wav", "A telephone rings with bell sounds"]
+                 ],
+                 inputs=[input_audio, display_caption],
+                 label="Click examples below to try!",
+                 cache_examples = False,
+                 examples_per_page = 10,
+             )
+
+         with gr.Tab("Real Human Speech"):
+             gr.Examples(
+                 examples=[
+                     ["wav/human/1.wav", "A cat meowing and young female speaking"],
+                     ["wav/human/2.wav", "Sustained industrial engine noise"],
+                     ["wav/human/3.wav", "A woman talks and a baby whispers"],
+                     ["wav/human/4.wav", "A man speaks followed by a toilet flush"],
+                     ["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
+                     ["wav/human/6.wav", "A man speaking as birds are chirping"],
+                     ["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
+                     ["wav/human/8.wav", "Birds chirping and a horse neighing"],
+                     ["wav/human/9.wav", "Several church bells ringing"],
+                     ["wav/human/10.wav", "A telephone rings with bell sounds"]
+                 ],
+                 inputs=[input_audio, display_caption],
+                 label="Click examples below to try!",
+                 cache_examples = False,
+                 examples_per_page = 10,
+             )
+
+
  demo.launch()
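
Note: the only functional additions in this commit are `import spaces` and the `@spaces.GPU(duration=60)` decorator on `infer`; every other -/+ pair in the diff re-adds an identical line, which is why the file shows as fully rewritten (+198 -196). A minimal, self-contained sketch of the ZeroGPU pattern those two lines use (the `generate` function and its body below are illustrative, not from this repo; only `spaces.GPU` and its `duration` argument come from the diff):

    import spaces
    import torch

    @torch.no_grad()              # disable autograd during inference, as in the diff
    @spaces.GPU(duration=60)      # request ZeroGPU hardware for up to 60 s per call
    def generate(waveform: torch.Tensor) -> torch.Tensor:
        # CUDA is attached only while a decorated call runs on a ZeroGPU Space;
        # the fallback keeps this sketch runnable on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return waveform.to(device) * 0.5  # stand-in for the real model call

On ordinary (non-ZeroGPU) hardware the decorator is effectively a no-op, so the same code runs unchanged locally.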