Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,10 +13,13 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
| 13 |
|
| 14 |
|
| 15 |
repo_id = "parler-tts/parler-tts-mini-v1"
|
| 16 |
-
repo_id_large = "
|
|
|
|
| 17 |
|
| 18 |
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
|
| 19 |
model_large = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_large).to(device)
|
|
|
|
|
|
|
| 20 |
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
| 21 |
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
| 22 |
|
|
@@ -76,19 +79,23 @@ def preprocess(text):
|
|
| 76 |
return text
|
| 77 |
|
| 78 |
@spaces.GPU
|
| 79 |
-
def gen_tts(text, description,
|
| 80 |
inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
|
| 81 |
prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
|
| 82 |
|
| 83 |
set_seed(SEED)
|
| 84 |
-
if
|
| 85 |
generation = model_large.generate(
|
| 86 |
input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
|
| 87 |
)
|
| 88 |
-
|
| 89 |
generation = model.generate(
|
| 90 |
input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
|
| 91 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
audio_arr = generation.cpu().numpy().squeeze()
|
| 93 |
|
| 94 |
return SAMPLE_RATE, audio_arr
|
|
@@ -163,12 +170,12 @@ with gr.Blocks(css=css) as block:
|
|
| 163 |
with gr.Column():
|
| 164 |
input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
|
| 165 |
description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
|
| 166 |
-
|
| 167 |
run_button = gr.Button("Generate Audio", variant="primary")
|
| 168 |
with gr.Column():
|
| 169 |
audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
|
| 170 |
|
| 171 |
-
inputs = [input_text, description,
|
| 172 |
outputs = [audio_out]
|
| 173 |
run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
|
| 174 |
gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
repo_id = "parler-tts/parler-tts-mini-v1"
|
| 16 |
+
repo_id_large = "parler-tts/parler-tts-large-v1"
|
| 17 |
+
repo_id_tiny = "parler-tts/parler-tts-tiny-v1"
|
| 18 |
|
| 19 |
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
|
| 20 |
model_large = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_large).to(device)
|
| 21 |
+
model_tiny = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_tiny).to(device)
|
| 22 |
+
|
| 23 |
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
| 24 |
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
| 25 |
|
|
|
|
| 79 |
return text
|
| 80 |
|
| 81 |
@spaces.GPU
|
| 82 |
+
def gen_tts(text, description, version_to_use=False):
|
| 83 |
inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
|
| 84 |
prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
|
| 85 |
|
| 86 |
set_seed(SEED)
|
| 87 |
+
if version_to_use=="Large":
|
| 88 |
generation = model_large.generate(
|
| 89 |
input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
|
| 90 |
)
|
| 91 |
+
elif version_to_use=="Mini":
|
| 92 |
generation = model.generate(
|
| 93 |
input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
|
| 94 |
)
|
| 95 |
+
else:
|
| 96 |
+
generation = model_tiny.generate(
|
| 97 |
+
input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
|
| 98 |
+
)
|
| 99 |
audio_arr = generation.cpu().numpy().squeeze()
|
| 100 |
|
| 101 |
return SAMPLE_RATE, audio_arr
|
|
|
|
| 170 |
with gr.Column():
|
| 171 |
input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
|
| 172 |
description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
|
| 173 |
+
version_to_use = gr.Radio(["Tiny", "Mini", "Large"], value="Mini", label="Checkpoint to use", info="The larger the model, the better it is, at the cost of speed.")
|
| 174 |
run_button = gr.Button("Generate Audio", variant="primary")
|
| 175 |
with gr.Column():
|
| 176 |
audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
|
| 177 |
|
| 178 |
+
inputs = [input_text, description, version_to_use]
|
| 179 |
outputs = [audio_out]
|
| 180 |
run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
|
| 181 |
gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)
|