# Video-Inference-Demo / models/llava_video.py
# Author: jena-shreyas — "Add transformers v5 integration to models" (commit ef7643d)
# Run with `conda activate llava`
import warnings
import copy
import torch
import numpy as np
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from typing import Optional, Dict, Any, Union, List
from decord import VideoReader, cpu
# Handle both relative and absolute imports
try:
from .base import BaseVideoModel
except ImportError:
from base import BaseVideoModel
warnings.filterwarnings("ignore")
class LLaVAVideoModel(BaseVideoModel):
    """Video chat model backed by the HF `transformers` LLaVA-Video port.

    Loads an image-text-to-text checkpoint (default: the LLaVA-Video-7B-Qwen2
    HF conversion) plus its processor, and exposes:

    * ``load_video`` — decord-based frame sampling helper (legacy path; the
      current ``chat`` lets the processor handle frame sampling itself).
    * ``chat`` — single-turn video Q&A via the processor's chat template.
    """

    def __init__(
        self,
        model_name: str = "Isotr0py/LLaVA-Video-7B-Qwen2-hf",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
        load_8bit: Optional[bool] = False,
        load_4bit: Optional[bool] = False,
    ):
        """Load model weights and processor.

        Args:
            model_name: HF hub id or local path of the checkpoint.
            dtype: torch dtype (or string) for the non-quantized weights.
            device_map: device placement; for 8/4-bit quantized loads this
                must be "auto" or an explicit dict, never a bare device
                string (bitsandbytes requirement).
            attn_implementation: attention backend passed to `from_pretrained`.
            load_8bit / load_4bit: enable bitsandbytes quantization.
        """
        super().__init__(model_name)
        self.dtype = dtype
        quantization_config = None
        if load_8bit or load_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=load_8bit,
                load_in_4bit=load_4bit,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            attn_implementation=attn_implementation,
            dtype=dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def load_video(
        self,
        video_path: str,
        fps: float = 1.0,
        max_frames_num: int = -1,
        force_sample: bool = False,
    ):
        """Sample frames from a video with decord.

        Args:
            video_path: path to the video file.
            fps: target sampling rate in frames per second.
            max_frames_num: cap on the number of frames (-1 = no cap;
                0 = return a single dummy frame).
            force_sample: always resample uniformly to `max_frames_num`
                frames (falls back to the fps-strided count when
                `max_frames_num` is non-positive).

        Returns:
            Tuple of (frames ndarray [T,H,W,3], comma-joined frame-time
            string, total video duration in seconds).
        """
        if max_frames_num == 0:
            # BUGFIX: keep the 3-tuple contract; the original returned a bare
            # ndarray here, which broke callers unpacking `video, _, _`.
            return np.zeros((1, 336, 336, 3)), "", 0.0
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        avg_fps = vr.get_avg_fps()
        video_time = total_frame_num / avg_fps
        # BUGFIX: clamp the stride to >= 1; round(avg_fps / fps) returns 0
        # when the requested fps is high, and range(0, n, 0) raises ValueError.
        stride = max(1, round(avg_fps / fps))
        frame_idx = list(range(0, total_frame_num, stride))
        # NOTE(review): index/stride only equals seconds when fps == 1.0
        # (original behavior preserved) — confirm intended units.
        frame_time = [i / stride for i in frame_idx]
        if (max_frames_num > 0 and len(frame_idx) > max_frames_num) or force_sample:
            # BUGFIX: with force_sample=True and the default max_frames_num=-1
            # the original fed a negative count to np.linspace; fall back to
            # the fps-strided frame count in that case.
            num_samples = max_frames_num if max_frames_num > 0 else len(frame_idx)
            frame_idx = np.linspace(
                0, total_frame_num - 1, num_samples, dtype=int
            ).tolist()
            frame_time = [i / avg_fps for i in frame_idx]
        frame_time_str = ",".join(f"{t:.2f}s" for t in frame_time)
        frames = vr.get_batch(frame_idx).asnumpy()
        return frames, frame_time_str, video_time

    def chat(
        self,
        prompt: str,
        video_path: str,
        max_new_tokens: int = 512,
        do_sample: Optional[
            bool
        ] = True,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
        temperature: float = 0.7,
        video_mode: Optional[str] = "video",
        fps: Optional[float] = 1.0,
        num_frames: Optional[int] = 10,
        **kwargs: Any,
    ) -> str:
        """Ask a single question about a video and return the model's reply.

        Args:
            prompt: user question.
            video_path: path to the video file.
            max_new_tokens: generation length cap.
            do_sample / temperature: standard generation knobs.
            video_mode: "video" → processor samples by `fps`;
                "frames" → processor samples `num_frames` frames.
            fps / num_frames: mutually exclusive frame-sampling controls;
                the one not selected by `video_mode` is cleared.
            **kwargs: forwarded to `model.generate`.

        Returns:
            The decoded assistant response, stripped of whitespace.
        """
        # The processor accepts either fps or num_frames, not both.
        if video_mode == "frames":
            fps = None
        elif video_mode == "video":
            num_frames = None
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": prompt}
                ],
            },
        ]
        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            do_sample_frames=True,
            fps=fps,
            num_frames=num_frames
        ).to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
        # BUGFIX: decode only the newly generated tokens instead of splitting
        # the full decode on the word "assistant" — the old split broke
        # whenever "assistant" appeared in the user prompt.
        generated = out[:, inputs["input_ids"].shape[-1]:]
        response = self.processor.batch_decode(
            generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        return response.strip()
# def chat_with_confidence(
# self,
# prompt: str,
# video_path: str,
# fps: float = 1.0,
# max_new_tokens: int = 512,
# temperature: float = 0.7,
# do_sample: Optional[
# bool
# ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
# token_choices: Optional[List[str]] = ["Yes", "No"],
# logits_temperature: Optional[float] = 1.0,
# return_confidence: Optional[bool] = False,
# top_k_tokens: Optional[int] = 10,
# debug: Optional[bool] = False,
# ) -> Dict[str, Any]:
# video, _, _ = self.load_video(video_path, fps)
# video = self.image_processor.preprocess(video, return_tensors="pt")[
# "pixel_values"
# ].to(device=self.model.device, dtype=self.dtype)
# video = [video]
# conv_template = (
# "qwen_1_5" # Make sure you use correct chat template for different models
# )
# question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
# conv = copy.deepcopy(conv_templates[conv_template])
# conv.append_message(conv.roles[0], question)
# conv.append_message(conv.roles[1], None)
# prompt_question = conv.get_prompt()
# input_ids = (
# tokenizer_image_token(
# prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
# )
# .unsqueeze(0)
# .to(self.model.device)
# )
# with torch.no_grad():
# outputs = self.model.generate(
# input_ids,
# images=video,
# modalities=["video"],
# do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
# temperature=temperature,
# max_new_tokens=max_new_tokens,
# output_scores=True,
# return_dict_in_generate=True,
# )
# generated_ids = outputs.sequences
# scores = outputs.scores # Tuple of tensors, one per generated token
# print(f"Number of generated tokens: {len(scores)}")
# print(f"Vocabulary size: {scores[0].shape[1]}")
# # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
# if debug:
# print("****Running inference in debug mode****")
# # Print first token scores shape and max/min scores in debug mode
# print(f"Single token scores shape: {scores[0].shape}")
# print(
# f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
# )
# # Print details about top 10 tokens based on logits
# logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
# print(f"\n{'─'*80}")
# print(
# f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
# )
# print(f"{'─'*80}")
# top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
# for i in range(top_k_tokens):
# score = top_k_tokens_scores.values[0, i].item()
# score_index = top_k_tokens_scores.indices[0, i].item()
# token = self.tokenizer.decode(score_index)
# print(f"#{i+1}th Token: {token}")
# print(f"#{i+1}th Token index: {score_index}")
# print(f"#{i+1}th Token score: {score}")
# print("--------------------------------")
# # Decode the text
# output_response = self.tokenizer.batch_decode(
# generated_ids,
# skip_special_tokens=True,
# clean_up_tokenization_spaces=False,
# )[0]
# # Convert scores to probabilities
# # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
# selected_token_probs = []
# selected_token_logits = []
# first_token_probs = torch.softmax(scores[0], dim=-1)
# # Now, find indices of tokens in token_choices and get their probabilities
# for token_choice in token_choices:
# # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
# token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
# 0
# ]
# selected_token_probs.append(first_token_probs[0, token_index].item())
# selected_token_logits.append(scores[0][0, token_index].item())
# # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
# if return_confidence:
# first_token_id = generated_ids[0][
# 0
# ].item() # First token of the first sequence
# confidence = (
# first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
# if sum(selected_token_probs) > 0
# else 0.0
# )
# return {
# "response": output_response,
# "confidence": confidence,
# }
# # Return token logits
# else:
# token_logits = dict(zip(token_choices, selected_token_logits))
# top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
# top_k_tokens_list: List[Tuple[str, int, float]] = []
# for i in range(top_k_tokens):
# logit_index = top_k_logits_indices.indices[0, i].item()
# token = self.tokenizer.decode(logit_index)
# logit = top_k_logits_indices.values[0, i].item()
# top_k_tokens_list.append((token, logit_index, logit))
# return {
# "response": output_response,
# "top_k_tokens": top_k_tokens_list,
# "token_logits": token_logits,
# }
# if __name__ == "__main__":
# model_path = "lmms-lab/LLaVA-Video-7B-Qwen2" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
# device_map = "cuda:0"
# model = LLaVAVideoModel(model_path, device_map=device_map)
# prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
# token_choices = ["Yes", "No"]
# video_path = (
# "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
# )
# generation_config = {
# "max_new_tokens": 128,
# "do_sample": False, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
# "temperature": 0.7,
# "logits_temperature": 1.0,
# "fps": 1.0,
# "return_confidence": False,
# "top_k_tokens": 10,
# "debug": False,
# }
# output = model.chat_with_confidence(
# prompt, video_path, token_choices=token_choices, **generation_config
# )
# response = output["response"]
# print(f"Response: {response}")
# if generation_config["return_confidence"]:
# confidence = output["confidence"]
# print(f"Confidence: {confidence}")
# else:
# # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
# # otherwise, the raw logits are used, which are not filtered.
# logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
# print(f"\n{'─'*80}")
# print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
# print(f"{'─'*80}")
# top_k_tokens = output["top_k_tokens"]
# for i in range(len(top_k_tokens)):
# print(f"Top {i+1} token: {top_k_tokens[i][0]}")
# print(f"Top {i+1} token index: {top_k_tokens[i][1]}")
# print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
# print("--------------------------------")