# Run with `conda activate llava`
import warnings
import copy
import torch
import numpy as np
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from typing import Optional, Dict, Any, Union, List
from decord import VideoReader, cpu

# Handle both relative and absolute imports
try:
    from .base import BaseVideoModel
except ImportError:
    from base import BaseVideoModel

warnings.filterwarnings("ignore")


class LLaVAVideoModel(BaseVideoModel):
    """LLaVA-Video wrapper built on HF transformers' image-text-to-text API.

    Loads the checkpoint and its processor once in ``__init__`` and exposes
    ``chat`` for single-turn video question answering.
    """

    def __init__(
        self,
        model_name: str = "Isotr0py/LLaVA-Video-7B-Qwen2-hf",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
        load_8bit: Optional[bool] = False,
        load_4bit: Optional[bool] = False,
    ):
        """Load model weights and the matching processor.

        Args:
            model_name: HF hub id of the LLaVA-Video checkpoint.
            dtype: Torch dtype (or "auto") for non-quantized weights.
            device_map: Accelerate device map; must be "auto" or a dict when
                quantization is enabled.
            attn_implementation: Attention backend, e.g. "flash_attention_2".
            load_8bit: Quantize weights to 8-bit via bitsandbytes.
            load_4bit: Quantize weights to 4-bit (nf4) via bitsandbytes.
        """
        super().__init__(model_name)
        self.dtype = dtype

        # For quantized models (8-bit or 4-bit), device_map must be "auto" or a
        # dict, not a plain device string.
        quantization_config = None
        if load_8bit or load_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=load_8bit,
                load_in_4bit=load_4bit,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )

        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            attn_implementation=attn_implementation,
            dtype=dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def load_video(
        self,
        video_path: str,
        fps: float = 1.0,
        max_frames_num: int = -1,
        force_sample: bool = False,
    ):
        """Decode frames from ``video_path`` at roughly ``fps`` frames/second.

        Args:
            video_path: Path to the video file.
            fps: Target sampling rate in frames per second.
            max_frames_num: Cap on the number of returned frames (-1 = no cap).
            force_sample: Uniformly sample exactly ``max_frames_num`` frames;
                only honored when ``max_frames_num`` > 0.

        Returns:
            Tuple of (frames, frame_time, video_time): decoded frames as a
            (T, H, W, 3) array, the sampled timestamps as a comma-joined
            string of seconds, and the total video duration in seconds.
        """
        if max_frames_num == 0:
            # BUGFIX: return the same 3-tuple shape as every other path so
            # `frames, t, dur = load_video(...)` never crashes.
            return np.zeros((1, 336, 336, 3)), "0.00s", 0.0

        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        avg_fps = vr.get_avg_fps()
        video_time = total_frame_num / avg_fps

        # Step between sampled frames. max(1, ...) prevents a zero step (and a
        # range() ValueError) when the requested fps exceeds the native fps.
        step = max(1, round(avg_fps / fps))
        frame_idx = list(range(0, total_frame_num, step))
        # BUGFIX: timestamps are frame_index / native_fps (seconds). The old
        # code divided by the step, which is only correct when fps == 1.0 and
        # disagreed with the uniform-sampling branch below.
        frame_time = [i / avg_fps for i in frame_idx]

        # Cap / uniform re-sampling. Requiring max_frames_num > 0 here avoids
        # np.linspace(..., num=-1) raising when force_sample is set while the
        # cap is left at its -1 default.
        if max_frames_num > 0 and (len(frame_idx) > max_frames_num or force_sample):
            uniform_sampled_frames = np.linspace(
                0, total_frame_num - 1, max_frames_num, dtype=int
            )
            frame_idx = uniform_sampled_frames.tolist()
            frame_time = [i / avg_fps for i in frame_idx]

        frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        return spare_frames, frame_time, video_time

    def chat(
        self,
        prompt: str,
        video_path: str,
        max_new_tokens: int = 512,
        do_sample: Optional[
            bool
        ] = True,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
        temperature: float = 0.7,
        video_mode: Optional[str] = "video",
        fps: Optional[float] = 1.0,
        num_frames: Optional[int] = 10,
        **kwargs: Any,
    ) -> str:
        """Answer ``prompt`` about the video at ``video_path``.

        Args:
            prompt: User question/instruction.
            video_path: Path to the video file fed to the chat template.
            max_new_tokens: Generation length cap.
            do_sample: Sampling vs. greedy decoding.
            temperature: Sampling temperature (ignored when do_sample=False).
            video_mode: "video" samples by fps; "frames" samples num_frames.
            fps: Frame-sampling rate (used in "video" mode).
            num_frames: Fixed frame count (used in "frames" mode).
            **kwargs: Extra arguments forwarded to ``model.generate``.

        Returns:
            The assistant's decoded text response.
        """
        # Ensure only one of fps or num_frames reaches the processor.
        if video_mode == "frames":
            fps = None
        elif video_mode == "video":
            num_frames = None

        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            do_sample_frames=True,
            fps=fps,
            num_frames=num_frames,
        ).to(self.model.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )

        raw_response = self.processor.batch_decode(
            out, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        # BUGFIX: split with maxsplit=1 so a literal "assistant" inside the
        # generated answer does not truncate it (unbounded split()[1] stops at
        # the next occurrence). NOTE(review): this still assumes the word
        # "assistant" does not appear in the user prompt — confirm upstream.
        response = raw_response.split("assistant", 1)[1].strip()
        return response

    # def chat_with_confidence(
    #     self,
    #     prompt: str,
    #     video_path: str,
    #     fps: float = 1.0,
    #     max_new_tokens: int = 512,
    #     temperature: float = 0.7,
    #     do_sample: Optional[
    #         bool
    #     ] = True,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
# token_choices: Optional[List[str]] = ["Yes", "No"], # logits_temperature: Optional[float] = 1.0, # return_confidence: Optional[bool] = False, # top_k_tokens: Optional[int] = 10, # debug: Optional[bool] = False, # ) -> Dict[str, Any]: # video, _, _ = self.load_video(video_path, fps) # video = self.image_processor.preprocess(video, return_tensors="pt")[ # "pixel_values" # ].to(device=self.model.device, dtype=self.dtype) # video = [video] # conv_template = ( # "qwen_1_5" # Make sure you use correct chat template for different models # ) # question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}" # conv = copy.deepcopy(conv_templates[conv_template]) # conv.append_message(conv.roles[0], question) # conv.append_message(conv.roles[1], None) # prompt_question = conv.get_prompt() # input_ids = ( # tokenizer_image_token( # prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" # ) # .unsqueeze(0) # .to(self.model.device) # ) # with torch.no_grad(): # outputs = self.model.generate( # input_ids, # images=video, # modalities=["video"], # do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P! 
# temperature=temperature, # max_new_tokens=max_new_tokens, # output_scores=True, # return_dict_in_generate=True, # ) # generated_ids = outputs.sequences # scores = outputs.scores # Tuple of tensors, one per generated token # print(f"Number of generated tokens: {len(scores)}") # print(f"Vocabulary size: {scores[0].shape[1]}") # # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode # if debug: # print("****Running inference in debug mode****") # # Print first token scores shape and max/min scores in debug mode # print(f"Single token scores shape: {scores[0].shape}") # print( # f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}" # ) # # Print details about top 10 tokens based on logits # logits_type = "POST-PROCESSED" if do_sample is True else "RAW" # print(f"\n{'─'*80}") # print( # f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):" # ) # print(f"{'─'*80}") # top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1) # for i in range(top_k_tokens): # score = top_k_tokens_scores.values[0, i].item() # score_index = top_k_tokens_scores.indices[0, i].item() # token = self.tokenizer.decode(score_index) # print(f"#{i+1}th Token: {token}") # print(f"#{i+1}th Token index: {score_index}") # print(f"#{i+1}th Token score: {score}") # print("--------------------------------") # # Decode the text # output_response = self.tokenizer.batch_decode( # generated_ids, # skip_special_tokens=True, # clean_up_tokenization_spaces=False, # )[0] # # Convert scores to probabilities # # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token # selected_token_probs = [] # selected_token_logits = [] # first_token_probs = torch.softmax(scores[0], dim=-1) # # Now, find indices of tokens in token_choices and get their probabilities # for token_choice in token_choices: # # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens) 
# token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[ # 0 # ] # selected_token_probs.append(first_token_probs[0, token_index].item()) # selected_token_logits.append(scores[0][0, token_index].item()) # # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs # if return_confidence: # first_token_id = generated_ids[0][ # 0 # ].item() # First token of the first sequence # confidence = ( # first_token_probs[0, first_token_id].item() / sum(selected_token_probs) # if sum(selected_token_probs) > 0 # else 0.0 # ) # return { # "response": output_response, # "confidence": confidence, # } # # Return token logits # else: # token_logits = dict(zip(token_choices, selected_token_logits)) # top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1) # top_k_tokens_list: List[Tuple[str, int, float]] = [] # for i in range(top_k_tokens): # logit_index = top_k_logits_indices.indices[0, i].item() # token = self.tokenizer.decode(logit_index) # logit = top_k_logits_indices.values[0, i].item() # top_k_tokens_list.append((token, logit_index, logit)) # return { # "response": output_response, # "top_k_tokens": top_k_tokens_list, # "token_logits": token_logits, # } # if __name__ == "__main__": # model_path = "lmms-lab/LLaVA-Video-7B-Qwen2" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct" # device_map = "cuda:0" # model = LLaVAVideoModel(model_path, device_map=device_map) # prompt = 'Does the following action accurately describe the one shown in the video? 
\nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:' # token_choices = ["Yes", "No"] # video_path = ( # "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4" # ) # generation_config = { # "max_new_tokens": 128, # "do_sample": False, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits # "temperature": 0.7, # "logits_temperature": 1.0, # "fps": 1.0, # "return_confidence": False, # "top_k_tokens": 10, # "debug": False, # } # output = model.chat_with_confidence( # prompt, video_path, token_choices=token_choices, **generation_config # ) # response = output["response"] # print(f"Response: {response}") # if generation_config["return_confidence"]: # confidence = output["confidence"] # print(f"Confidence: {confidence}") # else: # # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf, # # otherwise, the raw logits are used, which are not filtered. # logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW" # print(f"\n{'─'*80}") # print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):") # print(f"{'─'*80}") # top_k_tokens = output["top_k_tokens"] # for i in range(len(top_k_tokens)): # print(f"Top {i+1} token: {top_k_tokens[i][0]}") # print(f"Top {i+1} token index: {top_k_tokens[i][1]}") # print(f"Top {i+1} token logit: {top_k_tokens[i][2]}") # print("--------------------------------")