|
|
from argparse import ArgumentParser |
|
|
import copy |
|
|
import gradio as gr |
|
|
from gradio.themes.utils import colors, fonts, sizes |
|
|
|
|
|
from utils.easydict import EasyDict |
|
|
from tasks.eval.model_utils import load_pllava |
|
|
from tasks.eval.eval_utils import ( |
|
|
ChatPllava, |
|
|
conv_plain_v1, |
|
|
Conversation, |
|
|
conv_templates |
|
|
) |
|
|
from tasks.eval.demo import pllava_theme |
|
|
|
|
|
# Default system prompt; shown in (and editable via) the demo's "system" textbox.
SYSTEM="""You are Pllava, a large vision-language assistant.
You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
Follow the instructions carefully and explain your answers in detail based on the provided video.
"""
# Conversation template used to (re)initialise per-session chat state.
# NOTE: rebound at startup to conv_templates[args.conv_mode] inside the
# gr.Blocks wiring below.
INIT_CONVERSATION: Conversation = conv_plain_v1.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_model(args):
    """Build the PLLaVA model/processor pair and wrap them in a ChatPllava helper.

    Loads weights according to the CLI options; when multi-GPU dispatch is not
    requested, the model is moved onto a single CUDA device here (otherwise
    placement is left to load_pllava).

    Returns:
        ChatPllava: chat wrapper holding the loaded model and processor.
    """
    print('Initializing PLLaVA')
    model, processor = load_pllava(
        args.pretrained_model_name_or_path,
        args.num_frames,
        use_lora=args.use_lora,
        weight_dir=args.weight_dir,
        lora_alpha=args.lora_alpha,
        use_multi_gpus=args.use_multi_gpus,
    )
    if not args.use_multi_gpus:
        model = model.to('cuda')
    return ChatPllava(model, processor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_reset(chat_state, img_list):
    """Reset the demo to its initial state.

    Wipes the chatbot display, clears and re-enables the upload widgets,
    disables the text box until new media is uploaded, and resets the
    conversation state and frame list (only if they were ever initialised).

    Returns the 7-tuple of updates expected by
    [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list].
    """
    chat_state = INIT_CONVERSATION.copy() if chat_state is not None else chat_state
    img_list = [] if img_list is not None else img_list
    return (
        None,  # chatbot: clear the conversation display
        gr.update(value=None, interactive=True),  # up_image: clear + re-enable
        gr.update(value=None, interactive=True),  # up_video: clear + re-enable
        gr.update(placeholder='Please upload your video first', interactive=False),
        gr.update(value="Upload & Start Chat", interactive=True),
        chat_state,
        img_list,
    )
|
|
|
|
|
|
|
|
def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None):
    """Ingest an uploaded image or video and unlock the chat textbox.

    Parameters mirror the Gradio components wired to this handler:
    gr_img (PIL image or None), gr_video (video file path or None), plus the
    per-session state (conversation and frame list). Returns the six updates
    expected by [up_image, up_video, text_input, upload_button, chat_state,
    img_list].

    Fix vs. the original: in the "nothing uploaded" branch the placeholder
    update was sent to the upload button (which has no placeholder) while the
    bare interactive update went to the textbox — the two were swapped so the
    hint now appears in the textbox.
    """
    print(gr_img, gr_video)
    chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state
    img_list = [] if img_list is None else img_list

    def _chat_ready():
        # Shared UI transition once media has been ingested successfully.
        # (Closure reads the post-upload chat_state/img_list values.)
        return (
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Chatting", interactive=False),
            chat_state,
            img_list,
        )

    if gr_img is None and gr_video is None:
        return (
            None,
            None,
            gr.update(interactive=True, placeholder='Please upload video/image first!'),
            gr.update(interactive=True),
            chat_state,
            None,
        )
    if gr_video:
        # NOTE(review): num_segments may be None when only three inputs are
        # wired to this handler — confirm chat.upload_video handles that.
        llm_message, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments)
        return _chat_ready()
    if gr_img:
        llm_message, img_list, chat_state = chat.upload_img(gr_img, chat_state, img_list)
        return _chat_ready()
|
|
|
|
|
|
|
|
def gradio_ask(user_message, chatbot, chat_state, system):
    """Record a user turn: push it into the conversation state and the display.

    Rejects empty input by surfacing a hint in the textbox placeholder;
    otherwise clears the textbox and appends a pending (answer-less) row to
    the chatbot history.
    """
    if len(user_message) == 0:
        warn = gr.update(interactive=True, placeholder='Input should not be empty!')
        return warn, chatbot, chat_state
    updated_state = chat.ask(user_message, chat_state, system)
    history = chatbot + [[user_message, None]]
    return '', history, updated_state
|
|
|
|
|
|
|
|
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the assistant reply for the last user turn and show it.

    Calls the chat wrapper with the current conversation and frames, strips
    the "<s>" BOS marker from the decoded text, and fills it into the pending
    last chatbot row.
    """
    llm_message, _llm_message_token, chat_state = chat.answer(
        conv=chat_state,
        img_list=img_list,
        max_new_tokens=200,
        num_beams=num_beams,
        temperature=temperature,
    )
    cleaned = llm_message.replace("<s>", "")
    chatbot[-1][1] = cleaned
    print(chat_state)
    print(f"Answer: {cleaned}")
    return chatbot, chat_state, img_list
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the PLLaVA Gradio demo.

    Args:
        argv: optional argument vector; defaults to sys.argv[1:] when None
            (added for testability — backward-compatible with parse_args()).

    Returns:
        argparse.Namespace with the demo's configuration.

    Fix vs. the original: --pretrained_model_name_or_path and --num_frames
    set both required=True and a default; with required=True the default is
    dead code, so the misleading defaults were dropped (CLI behaviour is
    unchanged: both arguments are still mandatory).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="HF repo id or local path of the checkpoint, e.g. llava-hf/llava-1.5-7b-hf",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        required=True,
        help="Number of frames sampled from each video.",
    )
    parser.add_argument("--use_lora", action='store_true')
    parser.add_argument("--use_multi_gpus", action='store_true')
    parser.add_argument("--weight_dir", type=str, default=None)
    # NOTE(review): a None conv_mode makes conv_templates[args.conv_mode]
    # raise KeyError at startup — confirm the intended default template name.
    parser.add_argument("--conv_mode", type=str, default=None)
    parser.add_argument("--lora_alpha", type=int, default=None)
    parser.add_argument("--server_port", type=int, default=7868)
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
# Static page header: the PLLaVA logo linked to the project repository.
title = """<h1 align="center"><a href="https://github.com/magic-research/PLLaVA"><img src="https://raw.githubusercontent.com/magic-research/PLLaVA/main/assert/logo.png" alt="PLLAVA" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
# Short usage instructions rendered under the title.
description = (
    """<br><p><a href='https://github.com/magic-research/PLLaVA'>
# PLLAVA!
<img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
- Upload A Video
- Press Upload
- Start Chatting
"""
)

# Parsed at import time: this module is intended to be run as a script.
args = parse_args()

# Run-time model configuration echoed into the page so users can see which
# checkpoint / LoRA weights are being served.
model_description = f"""
# MODEL INFO
- pretrained_model_name_or_path:{args.pretrained_model_name_or_path}
- use_lora:{args.use_lora}
- weight_dir:{args.weight_dir}
"""
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI definition and event wiring (module-level script: runs on import).
# ---------------------------------------------------------------------------
with gr.Blocks(title="PLLaVA",
               theme=pllava_theme,
               css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(model_description)
    with gr.Row():
        # Left column: media upload (video or image tab) + action buttons.
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Tab("Video", elem_id='video_tab'):
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
            with gr.Tab("Image", elem_id='image_tab'):
                up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            # NOTE(review): this "Restart" button is never wired to a handler —
            # the name `clear` is rebound to the "🔄Clear" button further down,
            # so clicking "Restart" does nothing. Looks like a shadowing bug.
            clear = gr.Button("Restart")

        # Right column: generation settings + conversation widgets.
        with gr.Column(visible=True) as input_raws:
            # Editable system prompt, passed to gradio_ask on every turn.
            system_string = gr.Textbox(SYSTEM, interactive=True, label='system')
            num_beams = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            # Per-session state: the Conversation object and the list of
            # sampled frames/images fed to the model.
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot",label='Conversation')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    # Rebinds `clear`: from here on it names this button.
                    clear = gr.Button("🔄Clear")

            with gr.Row():
                examples = gr.Examples(
                    examples=[
                        ['example/jesse_dance.mp4', 'What is the man doing?'],
                        ['example/yoga.mp4', 'What is the woman doing?'],
                        ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'],
                        ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'],
                        ['example/1917.mp4', 'Describe the background, characters and the actions in the provided video.'],
                    ],
                    inputs=[up_video, text_input],
                    cache_examples=False
                )

    # Heavy model load happens here, at import time, inside the Blocks context.
    chat = init_model(args)
    # NOTE(review): args.conv_mode defaults to None, which would raise KeyError
    # here — confirm the expected --conv_mode values (keys of conv_templates).
    INIT_CONVERSATION = conv_templates[args.conv_mode]
    # NOTE(review): only three inputs are wired, so upload_img receives
    # num_segments=None — verify chat.upload_video accepts a None segment count.
    upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list])

    # Enter in the textbox, or the Send button: ask, then stream the answer.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    # NOTE(review): a second, independent click handler that blanks the textbox;
    # its ordering relative to the gradio_ask handler above is not guaranteed —
    # if it fires first the question could be lost. Confirm intended behaviour.
    run.click(lambda: "", None, text_input)
    # Wired to the "🔄Clear" button (the later binding of `clear`).
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list], queue=False)
|
|
|
|
|
# Start the demo: bounded request queue, then serve the app.
demo.queue(max_size=5)
# Fix: honour --server_port — it was parsed but never used, so the demo
# always started on Gradio's default port regardless of the flag.
demo.launch(server_port=args.server_port)
|
|
|
|
|
|