File size: 4,460 Bytes
eec1ab5
17d4493
eec1ab5
21fd64b
 
 
 
 
 
c47030c
 
 
 
621c8c9
17d4493
 
 
 
 
 
 
 
 
0e49d67
 
17d4493
0e49d67
17d4493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21fd64b
17d4493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e49d67
 
 
 
 
 
 
 
 
 
 
17d4493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e49d67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import spaces
import gradio as gr

import os
import subprocess
import sys

# Install flash-attn at startup, before the model code further down is imported.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is merged into a copy of the current
# environment: a bare dict passed as `env` replaces the child's entire
# environment (PATH, HOME, proxy settings, ...) instead of extending it.
# `sys.executable -m pip` targets the interpreter running this script, and the
# argument list avoids going through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best effort, as before: a failed install does not abort startup
)
# subprocess.run(
#     "pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git",
#     shell=True,
# )

import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
import copy
import warnings
from decord import VideoReader, cpu
import numpy as np
import tempfile
import os

#warnings.filterwarnings("ignore")

def load_video(video_path, max_frames_num, fps=1, force_sample=False):
    """Decode a video with decord and sample frames from it.

    Args:
        video_path: Path passed to ``VideoReader``.
        max_frames_num: Maximum number of frames to return. ``0`` skips
            decoding entirely and yields a single all-zero placeholder frame.
        fps: Desired sampling rate, in frames per second of video, for the
            first sampling pass.
        force_sample: When true, always resample uniformly to exactly
            ``max_frames_num`` frames, even if the first pass produced fewer.

    Returns:
        A tuple ``(frames, frame_time, video_time)``:
        ``frames`` is the array of sampled frames, ``frame_time`` is a
        comma-separated string of timestamps such as ``"0.00s,1.00s"`` (one
        per frame), and ``video_time`` is the video duration in seconds.
    """
    if max_frames_num == 0:
        # Same 3-tuple shape as the normal path so callers can always unpack;
        # the lone placeholder frame is reported at t=0 of a zero-length video.
        return np.zeros((1, 336, 336, 3)), "0.00s", 0.0

    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    avg_fps = vr.get_avg_fps()
    video_time = total_frame_num / avg_fps

    # Stride, in source frames, between two samples. Clamp to 1 so a video
    # whose frame rate is well below the requested `fps` does not produce a
    # zero step (range() would raise ValueError).
    step = max(1, round(avg_fps / fps))
    frame_idx = list(range(0, total_frame_num, step))

    if len(frame_idx) > max_frames_num or force_sample:
        # Too many frames (or the caller insists): spread exactly
        # max_frames_num indices evenly over the whole video.
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()

    # Timestamps come from the source frame rate, not from the stride, so
    # they are real seconds on both sampling paths.
    frame_time = ",".join(f"{i / avg_fps:.2f}s" for i in frame_idx)
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    return spare_frames, frame_time, video_time

# Runtime placement: let the loader distribute weights ("auto") and pick the
# device that input tensors are moved to later on.
device_map = "auto"
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Checkpoint to load and the model name handed to the loader.
model_name = "llava_qwen"
pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"

print("Loading model...")
# The loader returns four objects; all of them are bound at module level.
tokenizer, model, image_processor, max_length = load_pretrained_model(
    pretrained,
    None,
    model_name,
    torch_dtype="bfloat16",
    device_map=device_map,
)
model.eval()  # inference only
print("Model loaded successfully!")

@spaces.GPU
def process_video(video_path, question):
    """Ask the loaded model ``question`` about the video at ``video_path``.

    Samples 64 frames from the video, builds a single-turn prompt that
    describes the sampling (duration, frame count, timestamps) followed by the
    question, and returns the model's decoded answer as a stripped string.
    """
    frame_budget = 64
    frames, frame_time, video_time = load_video(video_path, frame_budget, 1, force_sample=True)

    # Preprocess the frames, then move them to the target device in bfloat16.
    pixel_values = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
    video = [pixel_values.to(device).bfloat16()]

    time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
    full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"

    # Work on a copy of the template so it is not mutated between requests.
    conv = copy.deepcopy(conv_templates["qwen_1_5"])
    conv.append_message(conv.roles[0], full_question)
    # Second role gets an empty slot for the model's reply.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    input_ids = token_ids.unsqueeze(0).to(device)

    generation_kwargs = {
        "images": video,
        "modalities": ["video"],
        "do_sample": False,
        "temperature": 0,
        "max_new_tokens": 4096,
    }
    with torch.no_grad():
        output = model.generate(input_ids, **generation_kwargs)

    decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
    return decoded[0].strip()

def gradio_interface(video_file, question):
    """Gradio callback: run the model on the uploaded video and return its answer.

    Args:
        video_file: The value delivered by the video input. ``gr.Video`` passes
            a file path string; a file-like object with ``read()`` is also
            accepted. ``None`` or an empty path means nothing was uploaded.
        question: The user's question about the video.

    Returns:
        The model's answer, or a prompt to upload a video when none was given.
    """
    if video_file is None or video_file == "":
        return "Please upload a video file."

    # A path can be handed to the video loader directly; calling .read() on it
    # would fail, and the file belongs to Gradio, so it is not deleted here.
    if isinstance(video_file, (str, os.PathLike)):
        return process_video(os.fspath(video_file), question)

    # File-like input: the loader needs a path, so spill the bytes to a
    # temporary file first.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
        temp_video.write(video_file.read())
        temp_video_path = temp_video.name

    try:
        return process_video(temp_video_path, question)
    finally:
        # Created with delete=False, so remove it ourselves even on failure.
        os.unlink(temp_video_path)

# Page layout: a title and a hint, a row holding the video upload next to the
# question box, then a submit button and the response box below.
with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-Video-7B-Qwen2 Demo")
    gr.Markdown("Upload a video and ask a question about it.")

    with gr.Row():
        video_input = gr.Video()
        question_input = gr.Textbox(
            placeholder="Ask a question about the video...",
            label="Question",
        )

    submit_button = gr.Button("Submit")
    output = gr.Textbox(label="Response")

    # Clicking the button feeds (video, question) to the callback and shows
    # its return value in the response box.
    submit_button.click(gradio_interface, [video_input, question_input], output)

# Only start the server when run as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)