# Hugging Face Space status: Paused
import gradio as gr

from vid2persona import init
from vid2persona.pipeline import vlm, llm

# Start-up calls run once at import time. Their inter-dependencies are not
# visible from this file, so the original call order is kept as-is.
init.init_model("HuggingFaceH4/zephyr-7b-beta")
init.auth_gcp()
init.get_env_vars()

# Prompt-template location handed to both vlm.get_traits and llm.chat.
prompt_tpl_path = "vid2persona/prompts"
async def extract_traits(video_path):
    """Extract character traits from a video and unlock the chat controls.

    Awaits ``vlm.get_traits`` with the GCP project id/location held on
    ``init``, the given ``video_path`` and the prompt-template path.

    Returns a six-item list, in the order the UI wires as outputs
    (traits, chatbot, user_input, clear, regen, stop): the traits payload,
    an empty chat history, an emptied interactive textbox, and three
    interactive button updates.
    """
    extracted = await vlm.get_traits(
        init.gcp_project_id,
        init.gcp_project_location,
        video_path,
        prompt_tpl_path,
    )

    # When the result carries a 'characters' entry, only its first item is
    # kept; otherwise the result is passed through untouched.
    persona = extracted['characters'][0] if 'characters' in extracted else extracted

    return [
        persona,
        [],
        gr.Textbox("", interactive=True),
        *(gr.Button(interactive=True) for _ in range(3)),
    ]
async def conversation(
    message: str, messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    """Stream a reply to ``message`` into the chat history.

    Async generator. Every yielded value is a four-item list
    ``[history, textbox_value, button_update, button_update]``:

    1. first, the history with a new ``[message, ""]`` pair appended, the
       textbox still showing ``message``, and both buttons disabled;
    2. then, for each chunk produced by ``llm.chat``, the history with the
       chunk appended to the last reply, an empty textbox, buttons disabled;
    3. finally, the completed history with both buttons enabled again.
    """
    # Build a new outer list so the caller's list object is not extended.
    history = messages + [[message, ""]]
    yield [history, message, gr.Button(interactive=False), gr.Button(interactive=False)]

    reply_stream = llm.chat(
        message, history, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token,
    )
    async for chunk in reply_stream:
        # Grow the reply half of the most recent [user, reply] pair in place.
        history[-1][1] += chunk
        yield [history, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    yield [history, "", gr.Button(interactive=True), gr.Button(interactive=True)]
async def regen_conversation(
    messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    """Regenerate the reply to the most recent user message.

    Async generator. The last ``[user, reply]`` pair is replaced by
    ``[user, ""]`` and the reply is streamed again from ``llm.chat``.
    Every yielded value is a four-item list
    ``[history, "", button_update, button_update]``; the buttons are
    disabled while streaming and enabled in the final yield.

    When ``messages`` is empty (or ``None``) there is nothing to regenerate,
    so the generator finishes without yielding anything.
    """
    # Guard first: without a previous turn there is no user message to replay,
    # and falling through would reference an unbound `message`.
    if not messages:
        return

    message = messages[-1][0]
    # Drop the old reply and re-open the turn with an empty one.
    messages = messages[:-1] + [[message, ""]]
    yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    async for partial_response in llm.chat(
        message, messages, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token,
    ):
        # Grow the reply half of the re-opened pair in place.
        messages[-1][1] += partial_response
        yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]
# Build the Gradio UI and wire its events.
# NOTE(review): the original indentation was lost, so the container nesting
# below is reconstructed from the order of the `with` statements — confirm the
# layout against the rendered app.
with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
    gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])
    gr.Markdown(
        "This project breathes life into video characters by using AI to describe their personality and then chat with you as them. "
        "[Gemini 1.0 Pro Vision model on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/overview) is used "
        "to grasp traits of video characters, then [HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) model "
        "is used to make conversation with them.",
    )
    gr.Markdown(
        "This space is modified to be working on Hugging Face [ZeroGPU](https://huggingface.co/zero-gpu-explorers). If you wish to run "
        "the same application on your own machine, please check out the [project repository](https://github.com/deep-diver/Vid2Persona). "
        "You can interact with other LLMs to make conversation besides HuggingFaceH4/zephyr-7b-beta by running them locally, or by "
        "connecting them through remotely hosted within Text Generation Inference framework as [Hugging Face PRO](https://huggingface.co/blog/inference-pro) user."
    )

    # --- Video upload and trait extraction ---------------------------------
    with gr.Column(elem_classes=["group"]):
        with gr.Row():
            video = gr.Video(label="upload short video clip", max_length=180)
            traits = gr.Json(label="extracted traits")

        with gr.Row():
            trait_gen = gr.Button("generate traits")

    # --- Chat area; controls start disabled and are enabled by extract_traits
    with gr.Column(elem_classes=["group"]):
        chatbot = gr.Chatbot([], label="chatbot", elem_id="chatbot", elem_classes=["chatbot-no-label"])
        with gr.Row():
            clear = gr.Button("clear conversation", interactive=False)
            regen = gr.Button("regenerate the last", interactive=False)
            stop = gr.Button("stop", interactive=False)
        user_input = gr.Textbox(placeholder="ask anything", interactive=False, elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"])

        with gr.Accordion("parameters' control pane", open=False):
            model_id = gr.Dropdown(choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS, value="HuggingFaceH4/zephyr-7b-beta", label="Model ID", visible=False)

            with gr.Row():
                max_input_token_length = gr.Slider(minimum=1024, maximum=4096, value=4096, label="max-input-tokens")
                max_new_tokens = gr.Slider(minimum=128, maximum=2048, value=256, label="max-new-tokens")

            with gr.Row():
                temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="temperature")
                # top-p is a cumulative probability, so the slider is capped at 1
                # (it previously allowed values up to 2).
                top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.9, label="top-p")
                # top-k is a token count: use whole-number steps and a range that
                # actually contains the default of 50 (it was 0–2 in 0.1 steps).
                top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="top-k")
                repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, value=1.2, label="repetition-penalty")

    with gr.Row():
        # NOTE(review): the link texts below are empty, so the links render
        # with nothing visible — they look like badges whose image markup was
        # lost; restore from the project repository if so.
        gr.Markdown(
            "[](https://github.com/deep-diver/Vid2Persona) "
            "[](https://twitter.com/algo_diver) "
            "[](https://twitter.com/RisingSayak )",
            elem_id="bottom-md"
        )

    # --- Event wiring -------------------------------------------------------
    generation_controls = [
        model_id, max_input_token_length,
        max_new_tokens, temperature,
        top_p, top_k, repetition_penalty,
    ]

    trait_gen.click(
        extract_traits,
        [video],
        [traits, chatbot, user_input, clear, regen, stop],
        concurrency_limit=5,
    )

    conv = user_input.submit(
        conversation,
        [user_input, chatbot, traits, *generation_controls],
        [chatbot, user_input, clear, regen],
        concurrency_limit=5,
    )

    # Clearing empties the chat and disables clear/regen until a new reply exists.
    clear.click(
        lambda: [
            gr.Chatbot([]),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
        ],
        None, [chatbot, clear, regen],
        concurrency_limit=5,
    )

    conv_regen = regen.click(
        regen_conversation,
        [chatbot, traits, *generation_controls],
        [chatbot, user_input, clear, regen],
        concurrency_limit=5,
    )

    # Stop cancels any in-flight generation and re-enables the buttons, which
    # the cancelled generator would otherwise have left disabled.
    stop.click(
        lambda: [
            gr.Button(interactive=True),
            gr.Button(interactive=True),
            gr.Button(interactive=True),
        ], None, [clear, regen, stop],
        cancels=[conv, conv_regen],
        concurrency_limit=5,
    )

    gr.Examples(
        examples=[["assets/sample1.mp4"], ["assets/sample2.mp4"], ["assets/sample3.mp4"], ["assets/sample4.mp4"]],
        inputs=video,
        outputs=[traits, chatbot, user_input, clear, regen, stop],
        fn=extract_traits,
        cache_examples=True
    )
# Enable the request queue (capped at 256 entries) and start the app in debug mode.
demo.queue(max_size=256).launch(debug=True)