| import shutil |
| import subprocess |
|
|
| import torch |
| import gradio as gr |
| from fastapi import FastAPI |
| import os |
| from PIL import Image |
| import tempfile |
| from decord import VideoReader, cpu |
| from transformers import TextStreamer |
|
|
| from llava.constants import DEFAULT_X_TOKEN, X_TOKEN_INDEX |
| from llava.conversation import conv_templates, SeparatorStyle, Conversation |
| from llava.serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css |
|
|
| import os, re, math, time, tempfile, shutil |
| import requests |
| import numpy as np |
| from PIL import Image |
| from decord import VideoReader |
| import ffmpeg |
|
|
|
|
| |
# Greedy-decoding keyword arguments with repetition controls.
# NOTE(review): not referenced anywhere in the visible code; presumably meant
# to be splatted into a model generate(...) call — confirm before relying on it.
GEN_KW = {
    "do_sample": False,
    "temperature": 0.0,
    "top_p": 1.0,
    "repetition_penalty": 1.15,
    "no_repeat_ngram_size": 3,
    "use_cache": False,
}
|
|
| |
| def _big_gpu(): |
| try: |
| return (torch.cuda.is_available() and |
| torch.cuda.get_device_properties(0).total_memory/1024**3 >= 40) |
| except Exception: |
| return False |
|
|
# Caps on generated tokens for smaller vs. larger GPUs.
# NOTE(review): neither constant is referenced in the visible code; presumably
# chosen via _big_gpu() somewhere else — confirm.
MAX_NEW_TOKENS_SMALL, MAX_NEW_TOKENS_BIG = 128, 256
|
|
|
|
| def _uniform_indices(n_total, n_want): |
| if n_total <= 0 or n_want <= 0: |
| return [] |
| return np.linspace(0, n_total-1, n_want).round().astype(int).tolist() |
|
|
def sample_frames(video_path, n_frames=8):
    """Uniformly sample frames from a video file.

    Args:
        video_path: Path readable by decord's ``VideoReader``.
        n_frames: Number of frames to sample (spread across the whole clip).

    Returns:
        Tuple ``(frames, timestamps)``. ``frames`` is the numpy array from
        ``get_batch(...).asnumpy()`` (documented as ``[N, H, W, 3]``), and
        ``timestamps`` holds each frame's time in seconds, computed with the
        reader's average FPS.
    """
    reader = VideoReader(video_path)
    frame_ids = _uniform_indices(len(reader), n_frames)
    batch = reader.get_batch(frame_ids).asnumpy()
    rate = float(reader.get_avg_fps())
    return batch, [k / rate for k in frame_ids]
|
|
def mmss(s):
    """Format a duration in seconds as zero-padded ``MM:SS``.

    The total is rounded to the nearest whole second (Python ``round``, ties to
    even) *before* it is split into minutes and seconds. Rounding only the
    seconds field, as before, could yield an invalid ``"00:60"`` for inputs
    such as 59.6. Minutes are not wrapped into hours (3600 -> ``"60:00"``).

    Args:
        s: Duration in seconds (int or float).

    Returns:
        str: ``"MM:SS"``.
    """
    minutes, seconds = divmod(int(round(s)), 60)
    return f"{minutes:02d}:{seconds:02d}"
|
|
def fetch_video_from_url(url, out_dir=None, max_seconds=None):
    """Download a video URL to ``<out_dir>/input.mp4`` and optionally trim it.

    Args:
        url: HTTP(S) URL. It is streamed in 1 MiB chunks with a 30 s timeout.
        out_dir: Destination directory. A fresh ``tempfile.mkdtemp()``
            directory is used when None; it is not cleaned up here.
        max_seconds: When a positive number, ffmpeg writes
            ``input_trimmed.mp4`` with ``-t max_seconds`` using stream copy
            (no re-encode), and that path is returned instead.

    Returns:
        str: Path of the downloaded (or trimmed) file.

    Raises:
        requests.HTTPError: On a non-2xx response.
    """
    target_dir = tempfile.mkdtemp() if out_dir is None else out_dir
    downloaded = os.path.join(target_dir, "input.mp4")
    with requests.get(url, stream=True, timeout=30) as resp:
        resp.raise_for_status()
        with open(downloaded, "wb") as sink:
            for piece in resp.iter_content(chunk_size=1 << 20):
                if piece:  # skip keep-alive empty chunks
                    sink.write(piece)

    if max_seconds is None or not max_seconds > 0:
        return downloaded

    clipped = os.path.join(target_dir, "input_trimmed.mp4")
    job = ffmpeg.input(downloaded).output(clipped, t=max_seconds, c='copy', loglevel="error")
    job.overwrite_output().run()
    return clipped
|
|
|
|
def keep_frame_lines(text, T):
    """Normalize model output to exactly ``T`` lines of the form ``Frame i: ...``.

    A line counts if, after stripping, it matches ``Frame <i>: <body>`` with
    ``1 <= i <= T``. Its body is whitespace-collapsed and cut to its first
    10 words. If the model repeats a frame number, only the first occurrence
    is kept, so the output never has duplicate frames. Frames that never
    appear get a ``(no description)`` placeholder. All other lines are dropped.

    Args:
        text: Raw model output.
        T: Number of frames expected.

    Returns:
        str: The ``T`` normalized lines joined by ``"\\n"``, ordered by frame
        number (empty string when ``T <= 0``).
    """
    by_frame = {}
    for raw in text.splitlines():
        m = re.match(r"^Frame\s+(\d+)\s*:\s*(.+)$", raw.strip())
        if not m:
            continue
        i = int(m.group(1))
        # First description wins; previously every repeat was emitted as an
        # extra "Frame i:" line, breaking the one-line-per-frame contract.
        if 1 <= i <= T and i not in by_frame:
            by_frame[i] = " ".join(m.group(2).split()[:10])
    return "\n".join(
        f"Frame {i}: {by_frame.get(i, '(no description)')}" for i in range(1, T + 1)
    )
|
|
|
|
def build_framewise_prompt(T):
    """Build the instruction asking the model for one short line per frame.

    Args:
        T: Number of frames (interpolated into the first sentence).

    Returns:
        str: A six-line prompt with no trailing newline. It shows the
        ``Frame i: <<=10 words>`` format and forbids brackets, JSON and code
        blocks.
    """
    header = f"You will output exactly {T} plain lines, one per frame.\n"
    format_spec = (
        "Format strictly:\n"
        "Frame 1: <<=10 words>\n"
        "Frame 2: <<=10 words>\n"
        "...\n"
    )
    footer = "No brackets [], no JSON, no code blocks, no numbered list other than 'Frame i:'."
    return header + format_spec + footer
|
|
|
|
|
|
|
|
def save_image_to_local(image):
    """Re-encode the image file at ``image`` as a JPEG copy under ``temp/``.

    The copy is what the chat HTML references via ``./file=<path>``.

    Args:
        image: Path of an image file readable by PIL.

    Returns:
        str: Relative path ``temp/<random hex>.jpg`` of the written copy.
    """
    import uuid  # public replacement for the private tempfile._get_candidate_names

    os.makedirs('temp', exist_ok=True)
    filename = os.path.join('temp', uuid.uuid4().hex + '.jpg')
    # Context manager closes the source file handle (previously leaked).
    with Image.open(image) as img:
        # JPEG cannot store alpha/palette modes; saving e.g. an RGBA PNG
        # as .jpg raises, so convert those to RGB first.
        jpeg_ready = img if img.mode in ('RGB', 'L') else img.convert('RGB')
        jpeg_ready.save(filename)
    return filename
|
|
|
|
def save_video_to_local(video_path):
    """Copy a video file into ``temp/`` under a random name.

    The copy is what the chat HTML references via ``./file=<path>``. The
    source file's extension is kept so the served file is labelled with its
    real container type. When the source has no extension, ``.mp4`` is used
    (the previous fixed suffix).

    Args:
        video_path: Path of the video file to copy.

    Returns:
        str: Relative path ``temp/<random hex><ext>`` of the copy.
    """
    import uuid  # public replacement for the private tempfile._get_candidate_names

    os.makedirs('temp', exist_ok=True)
    ext = os.path.splitext(video_path)[1] or '.mp4'
    filename = os.path.join('temp', uuid.uuid4().hex + ext)
    shutil.copyfile(video_path, filename)
    return filename
|
|
|
|
def _append_visual(images_tensor, tensor, modality):
    """Move ``tensor`` to the model's device/dtype and record it with its modality tag."""
    images_tensor[0] = images_tensor[0] + [tensor.to(handler.model.device, dtype=dtype)]
    images_tensor[1] = images_tensor[1] + [modality]


def generate(image1, video, textbox_in, first_run, state, state_, images_tensor):
    """Run one chat turn (or a regeneration) through the Video-LLaVA handler.

    Args:
        image1: Filepath of the uploaded image, or a falsy value if none.
        video: Filepath of the uploaded video, or a falsy value if none.
        textbox_in: User instruction. When empty, the last message stored in
            ``state_`` is popped and re-submitted (the Regenerate flow).
        first_run: Ignored on input; recomputed from ``state``.
        state: Display-side ``Conversation`` (rendered in the chatbot), or a
            non-Conversation value (e.g. None) before the first turn.
        state_: Model-side ``Conversation`` passed to ``handler.generate``.
        images_tensor: ``[tensors, modality_tags]`` accumulated across turns.

    Returns:
        Tuple matching the Gradio outputs ``[state, state_, chatbot,
        first_run, textbox, images_tensor, image1, video]``.

    Raises:
        gr.Error: When there is no instruction and nothing to regenerate.
            Returning a bare string (as before) does not match the 8 declared
            outputs and made Gradio fail with an unrelated error.
    """
    # flag == 1: fresh user turn. flag == 0: re-submitting a stored prompt whose
    # user message is already in ``state``.
    flag = 1
    if not textbox_in:
        # isinstance guard: state_ is None before the first turn.
        if isinstance(state_, Conversation) and len(state_.messages) > 0:
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            raise gr.Error("Please enter instruction")

    has_image = bool(image1) and os.path.exists(image1)
    has_video = bool(video) and os.path.exists(video)

    if type(state) is not Conversation:
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = [[], []]

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    # Encode inputs in the same order as their placeholder tokens in the
    # prompt below: video before image when both are present.
    # NOTE(review): on regeneration the stored prompt already carries these
    # tokens and the same media are encoded again; confirm handler.generate
    # tolerates that.
    if has_video:
        tensor = handler.video_processor(video, return_tensors='pt')['pixel_values'][0]
        _append_visual(images_tensor, tensor, 'video')
    if has_image:
        tensor = handler.image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
        _append_visual(images_tensor, tensor, 'image')

    if has_video and has_image:
        text_en_in = DEFAULT_X_TOKEN['VIDEO'] + '\n' + text_en_in + '\n' + DEFAULT_X_TOKEN['IMAGE']
    elif has_video:
        text_en_in = DEFAULT_X_TOKEN['VIDEO'] + '\n' + text_en_in
    elif has_image:
        text_en_in = DEFAULT_X_TOKEN['IMAGE'] + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown (presumably to cut a leaked
    # "###" separator — confirm against the conversation template).
    textbox_out = text_en_out.split('#')[0]

    show_images = ""
    if has_image:
        filename = save_image_to_local(image1)
        show_images += f'<img src="./file={filename}" style="display: inline-block;width: 250px;max-height: 400px;">'
    if has_video:
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if flag:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    # The answer is always recorded. On regeneration the user turn is already
    # in ``state`` and only the reply was removed.
    state.append_message(state.roles[1], textbox_out)
    torch.cuda.empty_cache()
    return (
        state,
        state_,
        state.to_gradio_chatbot(),
        False,
        gr.update(value=None, interactive=True),
        images_tensor,
        gr.update(value=image1 if has_image else None, interactive=True),
        gr.update(value=video if has_video else None, interactive=True),
    )
|
|
def regenerate(state, state_):
    """Drop the last assistant reply so a follow-up ``generate`` call can redo it.

    If there is no history yet (``state`` is None or empty), nothing is
    popped. Previously this raised AttributeError/IndexError when Regenerate
    was clicked before any turn.

    Args:
        state: Display-side conversation (or None before the first turn).
        state_: Model-side conversation (or None before the first turn).

    Returns:
        ``(state, state_, chatbot_history, first_run)``. ``first_run`` is True
        when ``state`` has no messages left.
    """
    if state is None or not state.messages:
        chatbot = state.to_gradio_chatbot() if state is not None else []
        return state, state_, chatbot, True
    state.messages.pop(-1)
    if state_ is not None and state_.messages:
        state_.messages.pop(-1)
    return state, state_, state.to_gradio_chatbot(), len(state.messages) == 0
|
|
|
|
def clear_history(state, state_):
    """Start a fresh conversation and blank the image, video and text inputs.

    The incoming ``state``/``state_`` are discarded and replaced by new
    copies of the ``conv_mode`` template.

    Returns:
        ``(image1, video, textbox, first_run, state, state_, chatbot,
        images_tensor)`` updates, in the order of the Clear button's outputs.
    """
    new_state = conv_templates[conv_mode].copy()
    new_state_ = conv_templates[conv_mode].copy()
    blank_inputs = tuple(gr.update(value=None, interactive=True) for _ in range(3))
    return (*blank_inputs, True, new_state, new_state_, new_state.to_gradio_chatbot(), [[], []])
|
|
|
|
|
|
# Model / runtime configuration, read as globals by generate() and clear_history().
conv_mode = "llava_v1"
model_path = 'LanguageBind/Video-LLaVA-7B'
device = 'cuda'
load_8bit = False
load_4bit = True
dtype = torch.float16
# Bug fix: this previously passed `load_4bit=load_8bit`, so the load_4bit
# setting above never reached Chat. NOTE(review): 4-bit loading presumably
# needs its quantization backend installed — confirm in the deployment env.
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device)

# Directory that uploaded media are copied into for display in the chat HTML.
os.makedirs("temp", exist_ok=True)

# Report GPU memory once the model is loaded.
print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())
|
|
# NOTE(review): not used below; presumably a leftover for mounting the Gradio
# app on FastAPI — kept so any external reference to `app` keeps working.
app = FastAPI()

# Created outside the Blocks context and rendered inside it, so the example
# tables can list it as an input before the chat column is laid out.
textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    # Per-session state: display conversation, model conversation,
    # first-turn flag, and accumulated [tensors, modality tags].
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            image1 = gr.Image(label="Input Image", type="filepath")
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/extreme_ironing.jpg",
                        "What is unusual about this image?",
                    ],
                    [
                        f"{cur_dir}/examples/waterview.jpg",
                        "What are the things I should be cautious about when I visit here?",
                    ],
                    [
                        f"{cur_dir}/examples/desert.jpg",
                        "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?",
                    ],
                ],
                inputs=[image1, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )
            with gr.Row(elem_id="buttons") as button_row:
                # NOTE(review): upvote/downvote/flag have no handlers attached
                # in this file, so clicking them currently does nothing.
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    with gr.Row():
        gr.Examples(
            examples=[
                [
                    f"{cur_dir}/examples/sample_img_8.png",
                    f"{cur_dir}/examples/sample_demo_8.mp4",
                    "Are the image and the video depicting the same place?",
                ],
                [
                    f"{cur_dir}/examples/sample_img_22.png",
                    f"{cur_dir}/examples/sample_demo_22.mp4",
                    "Are the instruments in the pictures used in the video?",
                ],
                [
                    f"{cur_dir}/examples/sample_img_13.png",
                    f"{cur_dir}/examples/sample_demo_13.mp4",
                    "Does the flag in the image appear in the video?",
                ],
            ],
            inputs=[image1, video, textbox],
        )
        gr.Examples(
            examples=[
                [
                    f"{cur_dir}/examples/sample_demo_1.mp4",
                    "Why is this video funny?",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_7.mp4",
                    "Create a short fairy tale with a moral lesson inspired by the video.",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_8.mp4",
                    "Where is this video taken from? What place/landmark is shown in the video?",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_12.mp4",
                    "What does the woman use to split the logs and how does she do it?",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_18.mp4",
                    "Describe the video in detail.",
                ],
                [
                    f"{cur_dir}/examples/sample_demo_22.mp4",
                    "Describe the activity in the video.",
                ],
            ],
            inputs=[video, textbox],
        )
    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)

    # Shared wiring for every path that runs a chat turn.
    generate_inputs = [image1, video, textbox, first_run, state, state_, images_tensor]
    generate_outputs = [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]

    submit_btn.click(generate, generate_inputs, generate_outputs)
    # The placeholder says "press ENTER", but only the Send button was wired;
    # route textbox submission through the same handler.
    textbox.submit(generate, generate_inputs, generate_outputs)

    regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate, generate_inputs, generate_outputs)

    clear_btn.click(clear_history, [state, state_],
                    [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])

demo.launch()
|
|
|
|
| |
|
|