Spaces:
Sleeping
Sleeping
Niki Zhang
committed on
Update app.py
Browse files
Combine with TTS module
app.py
CHANGED
|
@@ -18,6 +18,16 @@ from caption_anything.segmenter import build_segmenter
|
|
| 18 |
from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
|
| 19 |
from segment_anything import sam_model_registry
|
| 20 |
import easyocr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
args = parse_augment()
|
| 23 |
args.segmenter = "huge"
|
|
@@ -102,12 +112,12 @@ def init_openai_api_key(api_key=""):
|
|
| 102 |
print(text_refiner)
|
| 103 |
openai_available = text_refiner is not None
|
| 104 |
if openai_available:
|
| 105 |
-
return [gr.update(visible=True)]*
|
| 106 |
else:
|
| 107 |
-
return [gr.update(visible=False)]*
|
| 108 |
|
| 109 |
def init_wo_openai_api_key():
|
| 110 |
-
return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*
|
| 111 |
|
| 112 |
def get_click_prompt(chat_input, click_state, click_mode):
|
| 113 |
inputs = json.loads(chat_input)
|
|
@@ -256,7 +266,8 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
|
|
| 256 |
|
| 257 |
|
| 258 |
def submit_caption(image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
|
| 259 |
-
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state
|
|
|
|
| 260 |
print("state",state)
|
| 261 |
|
| 262 |
click_index = click_index_state
|
|
@@ -291,13 +302,23 @@ def submit_caption(image_input, state, generated_caption, text_refiner, visual_c
|
|
| 291 |
print("new_cap",new_cap)
|
| 292 |
refined_image_input = create_bubble_frame(np.array(origin_image_input), new_cap, click_index, input_mask,
|
| 293 |
input_points=input_points, input_labels=input_labels)
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
|
| 303 |
|
|
@@ -531,6 +552,7 @@ def create_ui():
|
|
| 531 |
interactive=True,
|
| 532 |
label="Generated Caption Length",
|
| 533 |
)
|
|
|
|
| 534 |
enable_wiki = gr.Radio(
|
| 535 |
choices=["Yes", "No"],
|
| 536 |
value="No",
|
|
@@ -541,6 +563,7 @@ def create_ui():
|
|
| 541 |
examples=examples,
|
| 542 |
inputs=[example_image],
|
| 543 |
)
|
|
|
|
| 544 |
with gr.Column(scale=0.5):
|
| 545 |
with gr.Column(visible=True) as module_key_input:
|
| 546 |
openai_api_key = gr.Textbox(
|
|
@@ -567,18 +590,52 @@ def create_ui():
|
|
| 567 |
with gr.Row():
|
| 568 |
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
| 569 |
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
|
|
|
|
| 571 |
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
|
| 572 |
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
|
| 573 |
-
modules_not_need_gpt2, module_key_input,
|
| 574 |
enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
|
| 575 |
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
|
| 576 |
modules_not_need_gpt,
|
| 577 |
-
modules_not_need_gpt2, module_key_input,
|
| 578 |
disable_chatGPT_button.click(init_wo_openai_api_key,
|
| 579 |
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
|
| 580 |
modules_not_need_gpt,
|
| 581 |
-
modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
| 582 |
|
| 583 |
enable_chatGPT_button.click(
|
| 584 |
lambda: (None, [], [], [[], [], []], "", "", ""),
|
|
@@ -663,13 +720,19 @@ def create_ui():
|
|
| 663 |
|
| 664 |
|
| 665 |
submit_button_click.click(
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
|
| 674 |
|
| 675 |
|
|
@@ -683,6 +746,9 @@ def create_ui():
|
|
| 683 |
show_progress=False, queue=True
|
| 684 |
)
|
| 685 |
|
|
|
|
|
|
|
|
|
|
| 686 |
return iface
|
| 687 |
|
| 688 |
|
|
@@ -690,4 +756,3 @@ if __name__ == '__main__':
|
|
| 690 |
iface = create_ui()
|
| 691 |
iface.queue(concurrency_count=5, api_open=False, max_size=10)
|
| 692 |
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
| 693 |
-
|
|
|
|
| 18 |
from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
|
| 19 |
from segment_anything import sam_model_registry
|
| 20 |
import easyocr
|
| 21 |
+
import tts
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
article = """
|
| 27 |
+
<div style='margin:20px auto;'>
|
| 28 |
+
<p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
|
| 29 |
+
</div>
|
| 30 |
+
"""
|
| 31 |
|
| 32 |
args = parse_augment()
|
| 33 |
args.segmenter = "huge"
|
|
|
|
| 112 |
print(text_refiner)
|
| 113 |
openai_available = text_refiner is not None
|
| 114 |
if openai_available:
|
| 115 |
+
return [gr.update(visible=True)]*7 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
|
| 116 |
else:
|
| 117 |
+
return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
|
| 118 |
|
| 119 |
# Build the component updates used when the demo runs without an OpenAI API key.
# This function is registered on disable_chatGPT_button.click, so the 12 returned
# values line up positionally with that handler's outputs list:
#   4 x hidden  -> modules_need_gpt0 .. modules_need_gpt3
#   3 x visible -> modules_not_need_gpt, modules_not_need_gpt2, tts_interface
#   2 x hidden  -> module_key_input, module_notification_box
#   3 x None    -> text_refiner, visual_chatgpt, notification_box
# Keep the multipliers in sync with the outputs list whenever a component is added.
def init_wo_openai_api_key():
|
| 120 |
+
return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*3 + [gr.update(visible=False)]*2 + [None, None, None]
|
| 121 |
|
| 122 |
def get_click_prompt(chat_input, click_state, click_mode):
|
| 123 |
inputs = json.loads(chat_input)
|
|
|
|
| 266 |
|
| 267 |
|
| 268 |
def submit_caption(image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
|
| 269 |
+
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
|
| 270 |
+
input_text, input_language, input_audio, input_mic, use_mic, agree):
|
| 271 |
print("state",state)
|
| 272 |
|
| 273 |
click_index = click_index_state
|
|
|
|
| 302 |
print("new_cap",new_cap)
|
| 303 |
refined_image_input = create_bubble_frame(np.array(origin_image_input), new_cap, click_index, input_mask,
|
| 304 |
input_points=input_points, input_labels=input_labels)
|
| 305 |
+
try:
|
| 306 |
+
waveform_visual, audio_output = tts.predict(new_cap, input_language, input_audio, input_mic, use_mic, agree)
|
| 307 |
+
print("error tts")
|
| 308 |
+
yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
|
| 309 |
+
except Exception as e:
|
| 310 |
+
state = state + [(None, f"Error during TTS prediction: {str(e)}")]
|
| 311 |
+
print(f"Error during TTS prediction: {str(e)}")
|
| 312 |
+
yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
|
| 313 |
|
| 314 |
+
else:
|
| 315 |
+
try:
|
| 316 |
+
waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
|
| 317 |
+
yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
|
| 318 |
+
except Exception as e:
|
| 319 |
+
state = state + [(None, f"Error during TTS prediction: {str(e)}")]
|
| 320 |
+
print(f"Error during TTS prediction: {str(e)}")
|
| 321 |
+
yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
|
| 322 |
|
| 323 |
|
| 324 |
|
|
|
|
| 552 |
interactive=True,
|
| 553 |
label="Generated Caption Length",
|
| 554 |
)
|
| 555 |
+
# Whether to integrate wiki content into the caption
|
| 556 |
enable_wiki = gr.Radio(
|
| 557 |
choices=["Yes", "No"],
|
| 558 |
value="No",
|
|
|
|
| 563 |
examples=examples,
|
| 564 |
inputs=[example_image],
|
| 565 |
)
|
| 566 |
+
|
| 567 |
with gr.Column(scale=0.5):
|
| 568 |
with gr.Column(visible=True) as module_key_input:
|
| 569 |
openai_api_key = gr.Textbox(
|
|
|
|
| 590 |
with gr.Row():
|
| 591 |
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
| 592 |
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
| 593 |
+
|
| 594 |
+
# TTS interface hidden initially
|
| 595 |
+
with gr.Column(visible=False) as tts_interface:
|
| 596 |
+
input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality")
|
| 597 |
+
input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en")
|
| 598 |
+
input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav")
|
| 599 |
+
input_mic = gr.Audio(source="microphone", type="filepath", label="Use Microphone for Reference")
|
| 600 |
+
use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False)
|
| 601 |
+
agree = gr.Checkbox(label="Agree", value=True)
|
| 602 |
+
output_waveform = gr.Video(label="Waveform Visual")
|
| 603 |
+
output_audio = gr.Audio(label="Synthesised Audio")
|
| 604 |
+
|
| 605 |
+
with gr.Row():
|
| 606 |
+
submit_tts = gr.Button(value="Submit", interactive=True)
|
| 607 |
+
clear_tts = gr.Button(value="Clear", interactive=True)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
# Reset the TTS panel. Registered on clear_tts.click; the 8 returned values map
# positionally onto that handler's outputs list:
#   input_text, input_language, input_audio, input_mic, use_mic, agree,
#   output_waveform, output_audio
# Text and language are set to "", both audio inputs and both outputs are cleared
# with None, use_mic is unchecked and agree is re-checked.
# NOTE(review): "" is not one of the input_language dropdown choices (its initial
# value is "en") - confirm that an empty language is acceptable to tts.predict.
def clear_tts_fields():
|
| 611 |
+
return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
|
| 612 |
+
|
| 613 |
+
submit_tts.click(
|
| 614 |
+
tts.predict,
|
| 615 |
+
inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree],
|
| 616 |
+
outputs=[output_waveform, output_audio],
|
| 617 |
+
queue=True
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
clear_tts.click(
|
| 621 |
+
clear_tts_fields,
|
| 622 |
+
inputs=None,
|
| 623 |
+
outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio],
|
| 624 |
+
queue=False
|
| 625 |
+
)
|
| 626 |
|
| 627 |
+
|
| 628 |
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
|
| 629 |
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
|
| 630 |
+
modules_not_need_gpt2, tts_interface,module_key_input ,module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
| 631 |
enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
|
| 632 |
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
|
| 633 |
modules_not_need_gpt,
|
| 634 |
+
modules_not_need_gpt2, tts_interface,module_key_input,module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
| 635 |
disable_chatGPT_button.click(init_wo_openai_api_key,
|
| 636 |
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
|
| 637 |
modules_not_need_gpt,
|
| 638 |
+
modules_not_need_gpt2, tts_interface,module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
| 639 |
|
| 640 |
enable_chatGPT_button.click(
|
| 641 |
lambda: (None, [], [], [[], [], []], "", "", ""),
|
|
|
|
| 720 |
|
| 721 |
|
| 722 |
submit_button_click.click(
|
| 723 |
+
submit_caption,
|
| 724 |
+
inputs=[
|
| 725 |
+
image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
|
| 726 |
+
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
|
| 727 |
+
input_text, input_language, input_audio, input_mic, use_mic, agree
|
| 728 |
+
],
|
| 729 |
+
outputs=[
|
| 730 |
+
chatbot, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,
|
| 731 |
+
output_waveform, output_audio
|
| 732 |
+
],
|
| 733 |
+
show_progress=True,
|
| 734 |
+
queue=True
|
| 735 |
+
)
|
| 736 |
|
| 737 |
|
| 738 |
|
|
|
|
| 746 |
show_progress=False, queue=True
|
| 747 |
)
|
| 748 |
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
|
| 752 |
return iface
|
| 753 |
|
| 754 |
|
|
|
|
| 756 |
iface = create_ui()
|
| 757 |
iface.queue(concurrency_count=5, api_open=False, max_size=10)
|
| 758 |
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
|
|