# This script requires transformers==4.57.0
import torch
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    BitsAndBytesConfig,
)
from typing import Optional, Dict, Any, Union, List, Tuple

# Handle both relative and absolute imports
try:
    from .base import BaseVideoModel
except ImportError:
    from base import BaseVideoModel

class Qwen3VLModel(BaseVideoModel):
    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-VL-4B-Instruct",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
        load_8bit: Optional[bool] = False,
        load_4bit: Optional[bool] = False,
    ):
        super().__init__(model_name)
        quantization_config = None
        if load_8bit or load_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=load_8bit,
                load_in_4bit=load_4bit,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            attn_implementation=attn_implementation,
            dtype=dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
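
    # Example (sketch, not from the original): constructing the wrapper with
    # 4-bit NF4 quantization. Assumes `bitsandbytes` is installed; the "sdpa"
    # fallback is an assumption for machines without flash-attn.
    #
    #   model = Qwen3VLModel(
    #       "Qwen/Qwen3-VL-4B-Instruct",
    #       load_4bit=True,
    #       attn_implementation="sdpa",
    #   )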

    def chat(
        self,
        prompt: str,
        video_path: str,
        temperature: float = 0.7,
        do_sample: Optional[
            bool
        ] = True,  # False switches to greedy decoding, which ignores temperature, top-k, and top-p
        max_new_tokens: int = 512,
        video_mode: Optional[str] = "video",  # choose "video" (sample by fps) or "frames" (sample num_frames)
        fps: Optional[float] = 1.0,
        num_frames: Optional[int] = 10,
        **kwargs: Any,
    ) -> str:
        # The processor accepts only one of fps / num_frames, so null out
        # whichever one does not apply to the chosen video_mode
        if video_mode == "frames":
            fps = None
        elif video_mode == "video":
            num_frames = None
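        # For example (assumption about the processor's frame sampler):
        # video_mode="frames" with num_frames=8 requests 8 frames total, while
        # video_mode="video" with fps=1.0 requests one frame per second of video.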
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            do_sample_frames=True,
            fps=fps,
            num_frames=num_frames,
        ).to(self.model.device)
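        # Note (assumption): with do_sample_frames=True the processor performs
        # the frame sampling itself, honoring whichever of fps / num_frames is
        # not None; exact keyword support depends on the processor version.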
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
        # Decode only the newly generated tokens; splitting the full decode on
        # the literal string "assistant" breaks if that word appears in the
        # prompt or the answer
        generated = out[:, inputs["input_ids"].shape[1]:]
        response = self.processor.batch_decode(
            generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        return response.strip()
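
    # Example (sketch, not from the original; the video path is hypothetical):
    #
    #   model = Qwen3VLModel()
    #   answer = model.chat(
    #       "Describe the action in this clip.",
    #       "/path/to/video.mp4",
    #       video_mode="frames",
    #       num_frames=8,
    #   )
    #   print(answer)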

    # def chat_with_confidence(
    #     self,
    #     prompt: str,
    #     video_path: str,
    #     max_new_tokens: int = 512,
    #     temperature: float = 0.7,
    #     do_sample: Optional[
    #         bool
    #     ] = True,  # False switches to greedy decoding, which ignores temperature, top-k, and top-p
    #     fps: Optional[float] = 1.0,
    #     num_frames: Optional[int] = 10,
    #     token_choices: Optional[List[str]] = ["Yes", "No"],
    #     logits_temperature: Optional[float] = 1.0,
    #     return_confidence: Optional[bool] = False,
    #     top_k_tokens: Optional[int] = 10,
    #     debug: Optional[bool] = False,
    #     **kwargs: Any,
    # ) -> Dict[str, Any]:
| # """ | |
| # Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits. | |
| # Args: | |
| # prompt (str): The text prompt to generate a response for. | |
| # video_path (str): The path to the video file. | |
| # temperature (float, optional): The temperature to use for generation. Defaults to 0.7. | |
| # max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512. | |
| # token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"]. | |
| # generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None. | |
| # return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False. | |
| # top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False. | |
| # debug (bool, optional): Whether to run in debug mode. Defaults to False. | |
| # Returns: | |
| # Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits. | |
| # e.g., return_confidence: False | |
| # Output: | |
| # { | |
| # "response": "Yes", | |
| # "top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)], | |
| # } | |
| # e.g., return_confidence: True | |
| # Output: | |
| # { | |
| # "response": "Yes", | |
| # "confidence": 0.9999 | |
| # } | |
| # """ | |
    #     # Messages containing a local video path and a text query
    #     messages = [
    #         {
    #             "role": "user",
    #             "content": [
    #                 {
    #                     "type": "video",
    #                     "video": video_path,
    #                     # "max_pixels": 360 * 420,
    #                     "fps": fps,
    #                 },
    #                 {"type": "text", "text": prompt},
    #             ],
    #         }
    #     ]
    #     inputs = self.processor.apply_chat_template(
    #         messages,
    #         tokenize=True,
    #         add_generation_prompt=True,
    #         return_dict=True,
    #         return_tensors="pt",
    #     )
    #     inputs = inputs.to(self.model.device)
    #     # In debug mode, inspect which logits processors will be used
    #     if debug:
    #         print("\n" + "=" * 80)
    #         print("INSPECTING GENERATION CONFIG & WARPERS")
    #         print("=" * 80)
    #         # Read the generation config to see which processors will be added
    #         gen_config = self.model.generation_config
    #         print("Generation config attributes:")
    #         print("  Processor-related:")
    #         print(f"    - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}")
    #         print(f"    - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}")
    #         print(f"    - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}")
    #         print(f"    - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
    #         print(f"    - min_length: {getattr(gen_config, 'min_length', None)}")
    #         print(f"    - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}")
    #         print(f"    - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}")
    #         print("  Warper-related (THESE MASK TOKENS TO -INF):")
    #         print(f"    - temperature: {temperature} (passed as arg)")
    #         print(f"    - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}")
    #         print(f"    - top_k: {getattr(gen_config, 'top_k', None)}")
    #         print(f"    - top_p: {getattr(gen_config, 'top_p', None)}")
    #         print(f"    - typical_p: {getattr(gen_config, 'typical_p', None)}")
    #         print(f"    - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}")
    #         print(f"    - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
    #         print("\n  ⚠️ If top_k or top_p are set, they will mask non-selected tokens to -inf!")
    #         print("=" * 80 + "\n")
    #     # Inference with scores
    #     with torch.no_grad():
    #         outputs = self.model.generate(
    #             **inputs,
    #             temperature=temperature,
    #             max_new_tokens=max_new_tokens,
    #             do_sample=do_sample,
    #             output_scores=True,
    #             output_logits=True,  # get TRUE raw logits before any processing
    #             return_dict_in_generate=True,
    #             **kwargs,
    #         )
    #     generated_ids = outputs.sequences
    #     scores = outputs.scores  # tuple of tensors - PROCESSED logits used for sampling
    #     logits = (
    #         outputs.logits if hasattr(outputs, "logits") else None
    #     )  # TRUE raw logits from the model
    #     scores = tuple(
    #         s / logits_temperature for s in scores
    #     )  # scale the logits for normalization during reporting
    #     if debug:  # keep these diagnostics behind the debug flag like the rest
    #         print(f"Number of generated tokens: {len(scores)}")
    #         print(f"Vocabulary size: {scores[0].shape[1]}")
    #     # Check whether the raw logits differ from the processed scores
    #     if debug and logits is not None:
    #         print("\n[IMPORTANT] output_logits available: True")
    #         print("[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):")
    #         logits_raw = logits[0] / logits_temperature  # first token's raw logits
    #         scores_first = scores[0]  # first token's processed scores
    #         logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
    #         max_diff = logits_diff.max().item()
    #         if max_diff > 0.001:
    #             print("[IMPORTANT] ⚠️ outputs.scores ARE DIFFERENT from outputs.logits!")
    #             print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
    #             print("[IMPORTANT] This means outputs.scores are PROCESSED, not raw!")
    #         else:
    #             print("[IMPORTANT] ✓ outputs.scores == outputs.logits (both are raw)")
    #     elif debug:
    #         print("\n[IMPORTANT] output_logits not available in this transformers version")
    #     # In debug mode, print the top 3 first-position tokens (i.e., scores[0]) with their logits
    #     if debug:
    #         print("\n" + "=" * 80)
    #         print("****Running inference in debug mode****")
    #         print("=" * 80)
    #         # Use truly raw logits if available, otherwise fall back to scores
    #         raw_logits_to_show = (
    #             logits[0] / logits_temperature if logits is not None else scores[0]
    #         )
    #         logits_label = (
    #             "TRUE RAW LOGITS (from outputs.logits)"
    #             if logits is not None
    #             else "LOGITS (from outputs.scores)"
    #         )
    #         # Print the first-token score shape and max/min values
    #         print(f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}")
    #         print(f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}")
    #         # Details of the top 3 tokens from the RAW logits
    #         print(f"\n{'─' * 80}")
    #         print(f"TOP 3 TOKENS FROM {logits_label}:")
    #         print(f"{'─' * 80}")
    #         top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
    #         for i in range(3):
    #             token_id = top_3_tokens.indices[0, i].item()
    #             token_text = self.processor.decode(token_id)
    #             token_logit = top_3_tokens.values[0, i].item()
    #             print(f"  #{i + 1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}")
    #         # Now compare with the POST-PROCESSED logits (outputs.scores);
    #         # note: scores were already divided by logits_temperature above,
    #         # so do not divide a second time here
    #         scores_first = scores[0]
    #         print(f"\n{'─' * 80}")
    #         print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
    #         print(f"{'─' * 80}")
    #         print(f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}")
    #         top_3_processed = torch.topk(scores_first, k=3, dim=-1)
    #         for i in range(3):
    #             token_id = top_3_processed.indices[0, i].item()
    #             token_text = self.processor.decode(token_id)
    #             token_logit = top_3_processed.values[0, i].item()
    #             print(f"  #{i + 1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}")
    #         # Check whether the distributions differ (against the truly raw logits if available)
    #         print(f"\n{'─' * 80}")
    #         print("DIFFERENCE ANALYSIS (Raw → Post-Processed):")
    #         print(f"{'─' * 80}")
    #         logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
    #         max_diff = logit_diff.max().item()
    #         num_changed = (logit_diff > 0.001).sum().item()
    #         print(f"  Max logit difference: {max_diff:.6f}")
    #         print(f"  Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}")
    #         if max_diff > 0.001:
    #             print("\n  ⚠️ LOGITS WERE MODIFIED BY PROCESSORS!")
    #             # Show which tokens changed the most
    #             top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
    #             print("\n  Top 5 most changed tokens:")
    #             for i in range(min(5, len(top_changes.indices))):
    #                 token_id = top_changes.indices[i].item()
    #                 token_text = self.processor.decode(token_id)
    #                 raw_logit = raw_logits_to_show[0, token_id].item()
    #                 processed_logit = scores_first[0, token_id].item()
    #                 diff = top_changes.values[i].item()
    #                 print(f"    Token='{token_text}' | ID={token_id}")
    #                 print(f"    Raw: {raw_logit:.4f} → Processed: {processed_logit:.4f} (Δ={diff:.4f})")
    #         else:
    #             print("  ✓ No significant modifications detected")
    #         # Show which token was actually selected
    #         print(f"\n{'─' * 80}")
    #         print("ACTUALLY GENERATED TOKEN:")
    #         print(f"{'─' * 80}")
    #         first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
    #         first_generated_token = self.processor.decode(first_generated_id)
    #         raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
    #         print(f"  Token: '{first_generated_token}' | ID={first_generated_id}")
    #         print(f"  Raw logit: {raw_logit_for_generated:.4f}")
    #         processed_logit_for_generated = scores_first[0, first_generated_id].item()
    #         print(f"  Post-processed logit: {processed_logit_for_generated:.4f}")
    #         # Check whether this token is in the top-k of the raw logits
    #         top_k_raw_indices = torch.topk(
    #             raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
    #         ).indices[0]
    #         is_in_top10_raw = first_generated_id in top_k_raw_indices
    #         print(f"  In top-10 of RAW logits: {is_in_top10_raw}")
    #         if not is_in_top10_raw:
    #             print("\n  🚨 CRITICAL: Generated token was NOT in top-10 of raw logits!")
    #             print("  This proves that logits processors modified the distribution.")
    #             # Find the rank of the generated token in the raw logits
    #             sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
    #             raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[0].item() + 1
    #             print(f"  Raw logits rank: {raw_rank}")
    #         print("=" * 80 + "\n")
    #     # Trim the prompt tokens from the generated sequences
    #     generated_ids_trimmed = [
    #         out_ids[len(in_ids):]
    #         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    #     ]
    #     # Decode the text
    #     output_response = self.processor.batch_decode(
    #         generated_ids_trimmed,
    #         skip_special_tokens=True,
    #         clean_up_tokenization_spaces=False,
    #     )[0]
    #     # Convert scores to probabilities;
    #     # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
    #     selected_token_probs = []
    #     selected_token_logits = []
    #     first_token_probs = torch.softmax(scores[0], dim=-1)
    #     # Find the index of each choice in token_choices and record its probability and logit
    #     for token_choice in token_choices:
    #         # encode returns a list of sub-token ids; take the first one, so
    #         # multi-token choices are approximated by their first piece
    #         token_index = self.processor.tokenizer.encode(
    #             token_choice, add_special_tokens=False
    #         )[0]
    #         selected_token_probs.append(first_token_probs[0, token_index].item())
    #         selected_token_logits.append(scores[0][0, token_index].item())
    #     # Compute confidence as the generated first token's probability divided
    #     # by the sum of the probabilities of all tokens in token_choices
    #     if return_confidence:
    #         first_token_id = generated_ids_trimmed[0][0].item()  # first token of the first sequence
    #         confidence = (
    #             first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
    #             if sum(selected_token_probs) > 0
    #             else 0.0
    #         )
    #         return {
    #             "response": output_response,
    #             "confidence": confidence,
    #         }
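    #     # Worked sketch of the confidence formula above (numbers are made up):
    #     # with token_choices = ["Yes", "No"], P("Yes") = 0.90 and P("No") = 0.05
    #     # at the first position, and "Yes" generated first,
    #     # confidence = 0.90 / (0.90 + 0.05) ≈ 0.947.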
    #     # Otherwise, return the per-choice logits and the top-k first-position tokens
    #     else:
    #         token_logits = dict(zip(token_choices, selected_token_logits))
    #         top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
    #         top_k_tokens_list: List[Tuple[str, int, float]] = []
    #         for i in range(top_k_tokens):
    #             logit_index = top_k_logits_indices.indices[0, i].item()
    #             token = self.processor.decode(logit_index)
    #             logit = top_k_logits_indices.values[0, i].item()
    #             top_k_tokens_list.append((token, logit_index, logit))
    #         return {
    #             "response": output_response,
    #             "top_k_tokens": top_k_tokens_list,
    #             "token_logits": token_logits,
    #         }

# if __name__ == "__main__":
#     model_path = "Qwen/Qwen3-VL-4B-Instruct"  # or "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
#     model = Qwen3VLModel(model_path)
#     prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
#     token_choices = ["Yes", "No"]
#     video_path = (
#         "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
#     )
#     generation_config = {
#         "max_new_tokens": 128,
#         "do_sample": True,  # False switches to greedy decoding, which ignores temperature, top-k, and top-p but allows raw logits to be returned
#         "temperature": 0.7,
#         "logits_temperature": 1.0,
#         "fps": 1.0,
#         "return_confidence": False,
#         "top_k_tokens": 10,
#         "debug": False,
#     }
#     output = model.chat_with_confidence(
#         prompt, video_path, token_choices=token_choices, **generation_config
#     )
#     response = output["response"]
#     print(f"Response: {response}")
#     if generation_config["return_confidence"]:
#         confidence = output["confidence"]
#         print(f"Confidence: {confidence}")
#     else:
#         # With do_sample=True the logits pass through warpers that mask
#         # low-ranked tokens to -inf; otherwise the raw, unfiltered logits are reported
#         logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
#         print(f"Logits type: {logits_type}")
#         top_k_tokens = output["top_k_tokens"]
#         for i in range(len(top_k_tokens)):
#             print(f"Top {i + 1} token: {top_k_tokens[i][0]}")
#             print(f"Top {i + 1} token logit: {top_k_tokens[i][2]}")
#             print("--------------------------------")