# This script requires transformers==4.57.0
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    BitsAndBytesConfig,
)
# Handle both relative and absolute imports
try:
from .base import BaseVideoModel
except ImportError:
from base import BaseVideoModel
class Qwen3VLModel(BaseVideoModel):
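    """Wrapper around a Qwen3-VL image-text-to-text checkpoint for single-turn video chat."""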
def __init__(
self,
model_name: str = "Qwen/Qwen3-VL-4B-Instruct",
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
device_map: Optional[Union[str, Dict]] = "auto",
attn_implementation: Optional[str] = "flash_attention_2",
load_8bit: Optional[bool] = False,
load_4bit: Optional[bool] = False,
):
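        """Load the Qwen3-VL checkpoint and its processor.

        Args:
            model_name: Hugging Face model id or local path.
            dtype: Torch dtype for the model weights.
            device_map: Device placement strategy passed to `from_pretrained`.
            attn_implementation: Attention backend, e.g. "flash_attention_2" or "sdpa".
            load_8bit: Load the model with bitsandbytes 8-bit quantization.
            load_4bit: Load the model with bitsandbytes NF4 4-bit quantization.
        """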
super().__init__(model_name)
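        # Optional bitsandbytes quantization: plain 8-bit, or 4-bit NF4 with fp16 compute.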
quantization_config = None
if load_8bit or load_4bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=load_8bit,
load_in_4bit=load_4bit,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
self.model = AutoModelForImageTextToText.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map=device_map,
attn_implementation=attn_implementation,
dtype=dtype,
)
self.processor = AutoProcessor.from_pretrained(model_name)
def chat(
self,
prompt: str,
video_path: str,
temperature: float = 0.7,
        do_sample: Optional[bool] = True,  # False switches to greedy decoding, which makes temperature, top-k, and top-p irrelevant
max_new_tokens: int = 512,
video_mode: Optional[str] = "video", # Choose from "video" or "frames"
fps: Optional[float] = 1.0,
num_frames: Optional[int] = 10,
**kwargs: Any,
) -> str:
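        """Answer a text prompt about a single video and return the decoded reply.

        The video is sampled either at a rate of `fps` frames per second
        ("video" mode) or as a fixed number of frames via `num_frames`
        ("frames" mode); only the argument matching `video_mode` is forwarded
        to the processor.
        """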
        # Forward only the frame-sampling argument that matches video_mode:
        # "frames" keeps num_frames (fixed frame count), "video" keeps fps (sampling rate).
if video_mode == "frames":
fps = None
elif video_mode == "video":
num_frames = None
conversation = [
{
"role": "user",
"content": [
{
"type": "video",
"video": video_path,
},
{"type": "text", "text": prompt}
],
},
]
inputs = self.processor.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
do_sample_frames=True,
fps=fps,
num_frames=num_frames
).to(self.model.device)
with torch.no_grad():
out = self.model.generate(
**inputs,
do_sample=do_sample,
temperature=temperature,
max_new_tokens=max_new_tokens,
**kwargs,
)
        # Decode only the newly generated tokens (the prompt tokens are trimmed off)
        generated_ids = out[:, inputs.input_ids.shape[1] :]
        response = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0].strip()
        return response
# def chat_with_confidence(
# self,
# prompt: str,
# video_path: str,
# max_new_tokens: int = 512,
# temperature: float = 0.7,
# do_sample: Optional[
# bool
# ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
# fps: Optional[float] = 1.0,
# num_frames: Optional[int] = 10,
# token_choices: Optional[List[str]] = ["Yes", "No"],
# logits_temperature: Optional[float] = 1.0,
# return_confidence: Optional[bool] = False,
# top_k_tokens: Optional[int] = 10,
# debug: Optional[bool] = False,
# **kwargs: Any,
# ) -> Dict[str, Any]:
# """
# Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
# Args:
# prompt (str): The text prompt to generate a response for.
# video_path (str): The path to the video file.
# temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
# max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
# token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
# generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
# return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
# top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
# debug (bool, optional): Whether to run in debug mode. Defaults to False.
# Returns:
# Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
# e.g., return_confidence: False
# Output:
# {
# "response": "Yes",
# "top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)],
# }
# e.g., return_confidence: True
# Output:
# {
# "response": "Yes",
# "confidence": 0.9999
# }
# """
# # Messages containing a local video path and a text query
# messages = [
# {
# "role": "user",
# "content": [
# {
# "type": "video",
# "video": video_path,
# # "max_pixels": 360 * 420,
# "fps": fps,
# },
# {"type": "text", "text": prompt},
# ],
# }
# ]
# inputs = self.processor.apply_chat_template(
# messages,
# tokenize=True,
# add_generation_prompt=True,
# return_dict=True,
# return_tensors="pt",
# )
# inputs = inputs.to(self.model.device)
# # In debug mode, inspect what logits processors will be used
# if debug:
# print("\n" + "=" * 80)
# print("INSPECTING GENERATION CONFIG & WARPERS")
# print("=" * 80)
# # Get the generation config to see what processors will be added
# gen_config = self.model.generation_config
# print(f"Generation config attributes:")
# print(f" Processor-related:")
# print(
# f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
# )
# print(
# f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
# )
# print(
# f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
# )
# print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
# print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
# print(
# f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
# )
# print(
# f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
# )
# print(f" Warper-related (THESE MASK TOKENS TO -INF):")
# print(f" - temperature: {temperature} (passed as arg)")
# print(
# f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
# )
# print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
# print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
# print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
# print(
# f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
# )
# print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
# print(
# f"\n ⚠️ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
# )
# print("=" * 80 + "\n")
# # Inference with scores
# with torch.no_grad():
# outputs = self.model.generate(
# **inputs,
# temperature=temperature,
# max_new_tokens=max_new_tokens,
# do_sample=do_sample,
# output_scores=True,
# output_logits=True, # Get TRUE raw logits before any processing
# return_dict_in_generate=True,
# **kwargs,
# )
# generated_ids = outputs.sequences
# scores = outputs.scores # Tuple of tensors - PROCESSED logits used for sampling
# logits = (
# outputs.logits if hasattr(outputs, "logits") else None
# ) # TRUE raw logits from model
# scores = tuple(
# s / logits_temperature for s in scores
# ) # Scales the logits by a factor for normalization during reporting
# print(f"Number of generated tokens: {len(scores)}")
# print(f"Vocabulary size: {scores[0].shape[1]}")
# # Check if logits differ from scores
# if debug and logits is not None:
# print(f"\n[IMPORTANT] output_logits available: True")
# print(
# f"[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):"
# )
# logits_raw = logits[0] / logits_temperature # First token's raw logits
# scores_first = scores[0] # First token's processed scores
# logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
# max_diff = logits_diff.max().item()
# if max_diff > 0.001:
# print(
# f"[IMPORTANT] ⚠️ outputs.scores ARE DIFFERENT from outputs.logits!"
# )
# print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
# print(
# f"[IMPORTANT] This means outputs.scores are PROCESSED, not raw!"
# )
# else:
# print(f"[IMPORTANT] ✓ outputs.scores == outputs.logits (both are raw)")
# elif debug:
# print(
# f"\n[IMPORTANT] output_logits not available in this transformers version"
# )
# # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
# if debug:
# print("\n" + "=" * 80)
# print("****Running inference in debug mode****")
# print("=" * 80)
# # Use truly raw logits if available, otherwise use scores
# raw_logits_to_show = (
# logits[0] / logits_temperature if logits is not None else scores[0]
# )
# logits_label = (
# "TRUE RAW LOGITS (from outputs.logits)"
# if logits is not None
# else "LOGITS (from outputs.scores)"
# )
# # Print first token scores shape and max/min scores in debug mode
# print(
# f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}"
# )
# print(
# f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}"
# )
# # Print details about top 3 tokens from RAW logits
# print(f"\n{'─'*80}")
# print(f"TOP 3 TOKENS FROM {logits_label}:")
# print(f"{'─'*80}")
# top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
# for i in range(3):
# token_id = top_3_tokens.indices[0, i].item()
# token_text = self.processor.decode(token_id)
# token_logit = top_3_tokens.values[0, i].item()
# print(
# f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
# )
# # Now compare with POST-PROCESSED logits (outputs.scores)
# scores_first = scores[0] / logits_temperature
# print(f"\n{'─'*80}")
# print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
# print(f"{'─'*80}")
# print(
# f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}"
# )
# top_3_processed = torch.topk(scores_first, k=3, dim=-1)
# for i in range(3):
# token_id = top_3_processed.indices[0, i].item()
# token_text = self.processor.decode(token_id)
# token_logit = top_3_processed.values[0, i].item()
# print(
# f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
# )
# # Check if the distributions differ (compare against truly raw logits if available)
# print(f"\n{'─'*80}")
# print("DIFFERENCE ANALYSIS (Raw → Post-Processed):")
# print(f"{'─'*80}")
# logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
# max_diff = logit_diff.max().item()
# num_changed = (logit_diff > 0.001).sum().item()
# print(f" Max logit difference: {max_diff:.6f}")
# print(
# f" Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}"
# )
# if max_diff > 0.001:
# print(f"\n ⚠️ LOGITS WERE MODIFIED BY PROCESSORS!")
# # Show which tokens changed the most
# top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
# print(f"\n Top 5 most changed tokens:")
# for i in range(min(5, len(top_changes.indices))):
# token_id = top_changes.indices[i].item()
# token_text = self.processor.decode(token_id)
# raw_logit = raw_logits_to_show[0, token_id].item()
# processed_logit = scores_first[0, token_id].item()
# diff = top_changes.values[i].item()
# print(f" Token='{token_text}' | ID={token_id}")
# print(
# f" Raw: {raw_logit:.4f} → Processed: {processed_logit:.4f} (Δ={diff:.4f})"
# )
# else:
# print(f" ✓ No significant modifications detected")
# # Show what token was actually selected
# print(f"\n{'─'*80}")
# print("ACTUALLY GENERATED TOKEN:")
# print(f"{'─'*80}")
# first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
# first_generated_token = self.processor.decode(first_generated_id)
# raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
# print(f" Token: '{first_generated_token}' | ID={first_generated_id}")
# print(f" Raw logit: {raw_logit_for_generated:.4f}")
# processed_logit_for_generated = scores_first[0, first_generated_id].item()
# print(f" Post-processed logit: {processed_logit_for_generated:.4f}")
# # Check if this token is in top-k of raw logits
# top_k_raw_indices = torch.topk(
# raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
# ).indices[0]
# is_in_top10_raw = first_generated_id in top_k_raw_indices
# print(f" In top-10 of RAW logits: {is_in_top10_raw}")
# if not is_in_top10_raw:
# print(
# f"\n 🚨 CRITICAL: Generated token was NOT in top-10 of raw logits!"
# )
# print(
# f" This proves that logits processors modified the distribution."
# )
# # Find the rank of the generated token in raw logits
# sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
# raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[
# 0
# ].item() + 1
# print(f" Raw logits rank: {raw_rank}")
# print("=" * 80 + "\n")
# # Trim the prompt tokens from generated sequences
# generated_ids_trimmed = [
# out_ids[len(in_ids) :]
# for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
# ]
# # Decode the text
# output_response = self.processor.batch_decode(
# generated_ids_trimmed,
# skip_special_tokens=True,
# clean_up_tokenization_spaces=False,
# )[0]
# # Convert scores to probabilities
# # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
# selected_token_probs = []
# selected_token_logits = []
# first_token_probs = torch.softmax(scores[0], dim=-1)
# # Now, find indices of tokens in token_choices and get their probabilities
# for token_choice in token_choices:
# # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
# token_index = self.processor.tokenizer.encode(
# token_choice, add_special_tokens=False
# )[0]
# selected_token_probs.append(first_token_probs[0, token_index].item())
# selected_token_logits.append(scores[0][0, token_index].item())
# # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
# if return_confidence:
# first_token_id = generated_ids_trimmed[0][
# 0
# ].item() # First token of the first sequence
# confidence = (
# first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
# if sum(selected_token_probs) > 0
# else 0.0
# )
# return {
# "response": output_response,
# "confidence": confidence,
# }
# # Return token logits
# else:
# token_logits = dict(zip(token_choices, selected_token_logits))
# top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
# top_k_tokens_list: List[Tuple[str, int, float]] = []
# for i in range(top_k_tokens):
# logit_index = top_k_logits_indices.indices[0, i].item()
# token = self.processor.decode(logit_index)
# logit = top_k_logits_indices.values[0, i].item()
# top_k_tokens_list.append((token, logit_index, logit))
# return {
# "response": output_response,
# "top_k_tokens": top_k_tokens_list,
# "token_logits": token_logits,
# }
# if __name__ == "__main__":
# model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
# model = Qwen3VLModel(model_path)
# prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
# token_choices = ["Yes", "No"]
# video_path = (
# "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
# )
# generation_config = {
# "max_new_tokens": 128,
# "do_sample": True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
# "temperature": 0.7,
# "logits_temperature": 1.0,
# "fps": 1.0,
# "return_confidence": False,
# "top_k_tokens": 10,
# "debug": False,
# }
# output = model.chat_with_confidence(
# prompt, video_path, token_choices=token_choices, **generation_config
# )
# response = output["response"]
# print(f"Response: {response}")
# if generation_config["return_confidence"]:
# confidence = output["confidence"]
# print(f"Confidence: {confidence}")
# else:
# # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
# # otherwise, the raw logits are used, which are not filtered.
# logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
# top_k_tokens = output["top_k_tokens"]
# for i in range(len(top_k_tokens)):
# print(f"Top {i+1} token: {top_k_tokens[i][0]}")
# print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
# print("--------------------------------")