# NOTE: web-viewer residue (file-size banner, commit-blame hashes, line-number gutter) removed from the top of this file.
# Run with `conda activate llava`
import warnings
import copy
import torch
import numpy as np
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from typing import Optional, Dict, Any, Union, List
from decord import VideoReader, cpu

# Handle both relative and absolute imports:
# the relative form is tried first (module imported as part of its package);
# if that raises ImportError, fall back to a plain top-level `base` module
# (e.g. when this file is executed directly from its own directory).
try:
    from .base import BaseVideoModel
except ImportError:
    from base import BaseVideoModel

# Silence every Python warning process-wide. NOTE(review): this is a global
# side effect at import time and also hides warnings raised by other modules.
warnings.filterwarnings("ignore")

class LLaVAVideoModel(BaseVideoModel):
    """Chat wrapper around a Hugging Face LLaVA-Video checkpoint.

    The model is loaded with ``AutoModelForImageTextToText`` and paired with
    the ``AutoProcessor`` of the same checkpoint. ``chat`` answers a text
    prompt about a single video file.
    """

    def __init__(
        self,
        model_name: str = "Isotr0py/LLaVA-Video-7B-Qwen2-hf",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
        load_8bit: Optional[bool] = False,
        load_4bit: Optional[bool] = False,
    ):
        """Load the model weights and the processor.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``.
            dtype: Weight dtype passed to ``from_pretrained``; also stored on
                ``self.dtype``.
            device_map: Device placement passed to ``from_pretrained``.
            attn_implementation: Attention backend passed to
                ``from_pretrained``.
            load_8bit: Load with 8-bit bitsandbytes quantization.
            load_4bit: Load with 4-bit (NF4) bitsandbytes quantization.

        Raises:
            ValueError: If both ``load_8bit`` and ``load_4bit`` are requested.
        """
        super().__init__(model_name)
        self.dtype = dtype

        # Fail early with a clear message instead of deep inside model loading.
        if load_8bit and load_4bit:
            raise ValueError(
                "load_8bit and load_4bit are mutually exclusive; enable only one."
            )

        # For quantized models (8-bit or 4-bit), device_map must be "auto" or a dict, not a device string
        quantization_config = None
        if load_8bit or load_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=load_8bit,
                load_in_4bit=load_4bit,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            attn_implementation=attn_implementation,
            dtype=dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def load_video(
        self,
        video_path: str,
        fps: float = 1.0,
        max_frames_num: int = -1,
        force_sample: bool = False,
    ):
        """Decode frames from a video file with decord.

        Frames are first taken at roughly ``fps`` frames per second. If that
        yields more than ``max_frames_num`` frames (when positive), or if
        ``force_sample`` is set, exactly ``max_frames_num`` frames are taken
        at uniformly spaced indices over the whole video instead.

        Args:
            video_path: Path of the video to read.
            fps: Target sampling rate in frames per second; must be positive.
            max_frames_num: Upper bound on returned frames; a negative value
                means no bound, ``0`` returns a placeholder (see Returns).
            force_sample: Always resample uniformly to ``max_frames_num``
                frames; requires a positive ``max_frames_num``.

        Returns:
            A tuple ``(frames, frame_time, video_time)`` where ``frames`` is
            the array produced by ``get_batch(...).asnumpy()``, ``frame_time``
            is a comma-separated string of per-frame timestamps such as
            ``"0.00s,1.00s"``, and ``video_time`` is the duration in seconds.

            NOTE: when ``max_frames_num == 0`` a bare zero array of shape
            ``(1, 336, 336, 3)`` is returned instead of a tuple. This
            asymmetry is kept for backward compatibility.

        Raises:
            ValueError: If ``fps`` is not positive, or ``force_sample`` is set
                without a positive ``max_frames_num``.
        """
        if max_frames_num == 0:
            return np.zeros((1, 336, 336, 3))
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")
        if force_sample and max_frames_num < 0:
            raise ValueError(
                "force_sample=True requires a positive max_frames_num, "
                f"got {max_frames_num}"
            )

        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        avg_fps = vr.get_avg_fps()
        video_time = total_frame_num / avg_fps

        # Index step between sampled frames. Clamp to 1 so that asking for a
        # rate above the native frame rate cannot produce a zero range step.
        stride = max(1, round(avg_fps / fps))
        frame_idx = list(range(0, total_frame_num, stride))

        too_many = max_frames_num > 0 and len(frame_idx) > max_frames_num
        if too_many or force_sample:
            frame_idx = np.linspace(
                0, total_frame_num - 1, max_frames_num, dtype=int
            ).tolist()

        # Timestamps are frame index / native fps in both sampling modes
        # (dividing by the stride would only be correct for fps == 1).
        frame_time = ",".join(f"{i / avg_fps:.2f}s" for i in frame_idx)
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        return spare_frames, frame_time, video_time

    def chat(
        self,
        prompt: str,
        video_path: str,
        max_new_tokens: int = 512,
        do_sample: Optional[
            bool
        ] = True,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
        temperature: float = 0.7,
        video_mode: Optional[str] = "video",
        fps: Optional[float] = 1.0,
        num_frames: Optional[int] = 10,
        **kwargs: Any,
    ) -> str:
        """Generate a text answer to ``prompt`` about the video at ``video_path``.

        Args:
            prompt: User question placed after the video in the conversation.
            video_path: Video location handed to the processor's chat template.
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Whether to sample; ``False`` selects greedy decoding.
            temperature: Sampling temperature; only forwarded to ``generate``
                when sampling is not explicitly disabled.
            video_mode: ``"video"`` keeps ``fps`` and drops ``num_frames``;
                ``"frames"`` keeps ``num_frames`` and drops ``fps``. For any
                other value both are forwarded as given.
            fps: Frame-sampling rate forwarded to the processor.
            num_frames: Fixed frame count forwarded to the processor.
            **kwargs: Extra keyword arguments forwarded to ``generate``.

        Returns:
            The decoded model reply with surrounding whitespace stripped.

        Raises:
            ValueError: If both ``fps`` and ``num_frames`` remain set after
                ``video_mode`` has been applied.
        """
        # Ensure only one of fps or num_frames is provided
        if video_mode == "frames":
            fps = None
        elif video_mode == "video":
            num_frames = None
        if fps is not None and num_frames is not None:
            raise ValueError(
                f"Unsupported video_mode {video_mode!r} with both fps and "
                "num_frames set; use 'video' or 'frames', or pass only one."
            )

        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            do_sample_frames=True,
            fps=fps,
            num_frames=num_frames,
        ).to(self.model.device)

        generation_kwargs: Dict[str, Any] = {
            "do_sample": do_sample,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }
        # Temperature has no effect under greedy decoding, so only pass it
        # when sampling has not been explicitly turned off.
        if do_sample is not False:
            generation_kwargs["temperature"] = temperature

        with torch.no_grad():
            out = self.model.generate(**inputs, **generation_kwargs)

        # `generate` returns prompt tokens followed by the new tokens. Decode
        # only the new ones rather than splitting the full text on the word
        # "assistant", which breaks when the prompt or the answer contains it.
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = out[:, prompt_length:]
        response = self.processor.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]
        return response.strip()

#     def chat_with_confidence(
#         self,
#         prompt: str,
#         video_path: str,
#         fps: float = 1.0,
#         max_new_tokens: int = 512,
#         temperature: float = 0.7,
#         do_sample: Optional[
#             bool
#         ] = True,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
#         token_choices: Optional[List[str]] = ["Yes", "No"],
#         logits_temperature: Optional[float] = 1.0,
#         return_confidence: Optional[bool] = False,
#         top_k_tokens: Optional[int] = 10,
#         debug: Optional[bool] = False,
#     ) -> Dict[str, Any]:
#         video, _, _ = self.load_video(video_path, fps)
#         video = self.image_processor.preprocess(video, return_tensors="pt")[
#             "pixel_values"
#         ].to(device=self.model.device, dtype=self.dtype)
#         video = [video]
#         conv_template = (
#             "qwen_1_5"  # Make sure you use correct chat template for different models
#         )
#         question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
#         conv = copy.deepcopy(conv_templates[conv_template])
#         conv.append_message(conv.roles[0], question)
#         conv.append_message(conv.roles[1], None)
#         prompt_question = conv.get_prompt()
#         input_ids = (
#             tokenizer_image_token(
#                 prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
#             )
#             .unsqueeze(0)
#             .to(self.model.device)
#         )
#         with torch.no_grad():
#             outputs = self.model.generate(
#                 input_ids,
#                 images=video,
#                 modalities=["video"],
#                 do_sample=do_sample,  # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
#                 temperature=temperature,
#                 max_new_tokens=max_new_tokens,
#                 output_scores=True,
#                 return_dict_in_generate=True,
#             )
#         generated_ids = outputs.sequences
#         scores = outputs.scores  # Tuple of tensors, one per generated token

#         print(f"Number of generated tokens: {len(scores)}")
#         print(f"Vocabulary size: {scores[0].shape[1]}")
#         # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
#         if debug:
#             print("****Running inference in debug mode****")
#             # Print first token scores shape and max/min scores in debug mode
#             print(f"Single token scores shape: {scores[0].shape}")
#             print(
#                 f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
#             )

#             # Print details about top 10 tokens based on logits
#             logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
#             print(f"\n{'─'*80}")
#             print(
#                 f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
#             )
#             print(f"{'─'*80}")
#             top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
#             for i in range(top_k_tokens):
#                 score = top_k_tokens_scores.values[0, i].item()
#                 score_index = top_k_tokens_scores.indices[0, i].item()
#                 token = self.tokenizer.decode(score_index)
#                 print(f"#{i+1}th Token: {token}")
#                 print(f"#{i+1}th Token index: {score_index}")
#                 print(f"#{i+1}th Token score: {score}")
#                 print("--------------------------------")

#         # Decode the text
#         output_response = self.tokenizer.batch_decode(
#             generated_ids,
#             skip_special_tokens=True,
#             clean_up_tokenization_spaces=False,
#         )[0]

#         # Convert scores to probabilities
#         # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
#         selected_token_probs = []
#         selected_token_logits = []
#         first_token_probs = torch.softmax(scores[0], dim=-1)

#         # Now, find indices of tokens in token_choices and get their probabilities
#         for token_choice in token_choices:
#             # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
#             token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
#                 0
#             ]
#             selected_token_probs.append(first_token_probs[0, token_index].item())
#             selected_token_logits.append(scores[0][0, token_index].item())

#         # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
#         if return_confidence:
#             first_token_id = generated_ids[0][
#                 0
#             ].item()  # First token of the first sequence
#             confidence = (
#                 first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
#                 if sum(selected_token_probs) > 0
#                 else 0.0
#             )
#             return {
#                 "response": output_response,
#                 "confidence": confidence,
#             }

#         # Return token logits
#         else:
#             token_logits = dict(zip(token_choices, selected_token_logits))
#             top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
#             top_k_tokens_list: List[Tuple[str, int, float]] = []
#             for i in range(top_k_tokens):
#                 logit_index = top_k_logits_indices.indices[0, i].item()
#                 token = self.tokenizer.decode(logit_index)
#                 logit = top_k_logits_indices.values[0, i].item()
#                 top_k_tokens_list.append((token, logit_index, logit))
#             return {
#                 "response": output_response,
#                 "top_k_tokens": top_k_tokens_list,
#                 "token_logits": token_logits,
#             }


# if __name__ == "__main__":
#     model_path = "lmms-lab/LLaVA-Video-7B-Qwen2"  # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
#     device_map = "cuda:0"
#     model = LLaVAVideoModel(model_path, device_map=device_map)
#     prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
#     token_choices = ["Yes", "No"]
#     video_path = (
#         "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
#     )

#     generation_config = {
#         "max_new_tokens": 128,
#         "do_sample": False,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
#         "temperature": 0.7,
#         "logits_temperature": 1.0,
#         "fps": 1.0,
#         "return_confidence": False,
#         "top_k_tokens": 10,
#         "debug": False,
#     }
#     output = model.chat_with_confidence(
#         prompt, video_path, token_choices=token_choices, **generation_config
#     )
#     response = output["response"]
#     print(f"Response: {response}")

#     if generation_config["return_confidence"]:
#         confidence = output["confidence"]
#         print(f"Confidence: {confidence}")
#     else:
#         # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
#         # otherwise, the raw logits are used, which are not filtered.
#         logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
#         print(f"\n{'─'*80}")
#         print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
#         print(f"{'─'*80}")
#         top_k_tokens = output["top_k_tokens"]
#         for i in range(len(top_k_tokens)):
#             print(f"Top {i+1} token: {top_k_tokens[i][0]}")
#             print(f"Top {i+1} token index: {top_k_tokens[i][1]}")
#             print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
#             print("--------------------------------")