jena-shreyas commited on
Commit
ef7643d
Β·
1 Parent(s): cf5f08b

Add transformers v5 integration to models

Browse files
Files changed (6) hide show
  1. app.py +2 -2
  2. models/__init__.py +25 -92
  3. models/base.py +8 -8
  4. models/llava_video.py +222 -243
  5. models/qwen2_5vl.py +232 -222
  6. models/qwen3vl.py +431 -477
app.py CHANGED
@@ -12,7 +12,7 @@ from models.base import BaseVideoModel
12
  # ----------------------
13
  # CONFIG
14
  # ----------------------
15
- MODEL_PATH = "lmms-lab/LLaVA-Video-7B-Qwen2"
16
  DEVICE_MAP = "cuda:0"
17
 
18
  VIDEO_DIR = str(Path(__file__).parent / "videos")
@@ -130,7 +130,7 @@ with gr.Blocks(title="Video QA – LLaVa-Video-7B-Qwen2", theme=gr.themes.Soft()
130
 
131
  fps_slider = gr.Slider(
132
  minimum=0.5,
133
- maximum=5.0,
134
  step=0.5,
135
  value=FPS,
136
  label="🎞️ Frames Per Second (FPS)",
 
12
  # ----------------------
13
  # CONFIG
14
  # ----------------------
15
+ MODEL_PATH = "Isotr0py/LLaVA-Video-7B-Qwen2-hf"
16
  DEVICE_MAP = "cuda:0"
17
 
18
  VIDEO_DIR = str(Path(__file__).parent / "videos")
 
130
 
131
  fps_slider = gr.Slider(
132
  minimum=0.5,
133
+ maximum=10.0,
134
  step=0.5,
135
  value=FPS,
136
  label="🎞️ Frames Per Second (FPS)",
models/__init__.py CHANGED
@@ -3,66 +3,27 @@ from packaging import version
3
  import torch
4
  from typing import Optional, Union, Dict
5
 
 
6
  # IMP: Add required versions here
7
- qwen_required_version = version.parse("4.57.0")
8
- internvl_required_version = version.parse("4.45.0")
9
- llava_required_version = version.parse("4.40.0")
10
 
11
  # Conditional imports based on transformers version
12
- try:
13
- import transformers
14
- from transformers.generation.logits_process import LogitsProcessor
15
-
16
- # Check transformers version
17
- transformers_version = version.parse(transformers.__version__)
18
-
19
- QWEN_MODELS_AVAILABLE = False
20
- INTERNVL_MODELS_AVAILABLE = False
21
- LLAVA_MODELS_AVAILABLE = False
22
-
23
- # Qwen condition
24
- if transformers_version >= qwen_required_version:
25
- from .qwen2_5vl import Qwen2_5VLModel
26
- from .qwen3vl import Qwen3VLModel
27
-
28
- QWEN_MODELS_AVAILABLE = True
29
- else:
30
- print(
31
- f"Warning: Qwen models require transformers>=4.57.0, but found {transformers.__version__}. Qwen models will not be available. Please upgrade to transformers>=4.57.0 or switch conda environments to use Qwen models."
32
- )
33
-
34
- # InternVL condition
35
- if transformers_version >= internvl_required_version:
36
- from .internvl import InternVLModel
37
-
38
- INTERNVL_MODELS_AVAILABLE = True
39
- else:
40
- print(
41
- f"Warning: InternVL models require transformers>=4.45.0, but found {transformers.__version__}. InternVL models will not be available. Please downgrade to transformers<=4.45.0 or switch conda environments to use InternVL models."
42
- )
43
 
44
- # LLaVA condition
45
- if transformers_version <= llava_required_version:
46
- from .llava_video import LLaVAVideoModel
47
 
48
- LLAVA_MODELS_AVAILABLE = True
49
- else:
50
- print(
51
- f"Warning: LLaVA models require transformers<=4.40.0, but found {transformers.__version__}. LLaVA models will not be available. Please downgrade to transformers<=4.40.0 or switch conda environments to use LLaVA models."
52
- )
53
- except ImportError:
54
- print(
55
- "Warning: Could not check transformers version. Please re-check transformers installation."
56
- )
57
 
58
- # Build __all__ list dynamically
59
- __all__ = []
60
- if QWEN_MODELS_AVAILABLE:
61
- __all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
62
- if INTERNVL_MODELS_AVAILABLE:
63
- __all__.append("InternVLModel")
64
- if LLAVA_MODELS_AVAILABLE:
65
- __all__.append("LLaVAVideoModel")
66
 
67
 
68
  # Function to get the model by mapping model ID to the correct model class
@@ -71,31 +32,27 @@ def load_model(
71
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
72
  device_map: Optional[Union[str, Dict]] = "auto",
73
  attn_implementation: Optional[str] = "flash_attention_2",
 
 
74
  ) -> BaseVideoModel:
75
  if "LLaVA-Video" in model_path:
76
- if not LLAVA_MODELS_AVAILABLE:
77
- raise ImportError(
78
- f"LLaVA models require transformers<=4.40.0."
79
- f"Please downgrade transformers: pip install transformers<=4.40.0"
80
- )
81
  return LLaVAVideoModel(
82
  model_path,
83
  dtype=dtype,
84
  device_map=device_map,
85
  attn_implementation=attn_implementation,
 
 
86
  )
87
  elif "Qwen" in model_path:
88
- if not QWEN_MODELS_AVAILABLE:
89
- raise ImportError(
90
- f"Qwen models require transformers>=4.57.0."
91
- f"Please upgrade transformers: pip install transformers>=4.57.0"
92
- )
93
  if "Qwen3" in model_path:
94
  return Qwen3VLModel(
95
  model_path,
96
  dtype=dtype,
97
  device_map=device_map,
98
  attn_implementation=attn_implementation,
 
 
99
  )
100
  else:
101
  return Qwen2_5VLModel(
@@ -103,39 +60,15 @@ def load_model(
103
  dtype=dtype,
104
  device_map=device_map,
105
  attn_implementation=attn_implementation,
 
 
106
  )
107
  elif "Intern" in model_path:
108
- if not INTERNVL_MODELS_AVAILABLE:
109
- raise ImportError(
110
- f"InternVL models require transformers>=4.45.0."
111
- f"Please upgrade transformers: pip install transformers>=4.45.0"
112
- )
113
  return InternVLModel(
114
  model_path,
115
  dtype=dtype,
116
  device_map=device_map,
117
  attn_implementation=attn_implementation,
 
 
118
  )
119
-
120
-
121
- class LogitsCaptureProcessor(LogitsProcessor):
122
- """
123
- Custom LogitsProcessor that captures the processed logits right before sampling.
124
- This allows us to see what the actual distribution looks like after all other
125
- processors have been applied.
126
- """
127
-
128
- def __init__(self):
129
- self.captured_logits = []
130
-
131
- def __call__(
132
- self, input_ids: torch.LongTensor, scores: torch.FloatTensor
133
- ) -> torch.FloatTensor:
134
- # Store a copy of the logits at this point in generation
135
- self.captured_logits.append(scores.detach().clone().cpu())
136
- # Return scores unchanged - we're just observing
137
- return scores
138
-
139
- def reset(self):
140
- """Clear captured logits for a new generation"""
141
- self.captured_logits = []
 
3
  import torch
4
  from typing import Optional, Union, Dict
5
 
6
+
7
  # IMP: Add required versions here
8
+ transformers_required_version = version.parse("5.0.0")
 
 
9
 
10
  # Conditional imports based on transformers version
11
+ import transformers
12
+ from transformers import BitsAndBytesConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Check transformers version
15
+ transformers_version = version.parse(transformers.__version__)
 
16
 
17
+ # transformers v5 condition
18
+ if transformers_version >= transformers_required_version:
19
+ from .qwen2_5vl import Qwen2_5VLModel
20
+ from .qwen3vl import Qwen3VLModel
21
+ from .internvl import InternVLModel
22
+ from .llava_video import LLaVAVideoModel
 
 
 
23
 
24
+ TRANSFORMERS_MODELS_AVAILABLE = True
25
+ else:
26
+ raise ValueError(f"Transformers v5 models require transformers>=5.0.0, but found {transformers.__version__}. Transformers v5 models will not be available. Please upgrade to transformers>=5.0.0 or switch conda environments to use Transformers v5 models.")
 
 
 
 
 
27
 
28
 
29
  # Function to get the model by mapping model ID to the correct model class
 
32
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
33
  device_map: Optional[Union[str, Dict]] = "auto",
34
  attn_implementation: Optional[str] = "flash_attention_2",
35
+ load_8bit: Optional[bool] = False,
36
+ load_4bit: Optional[bool] = False,
37
  ) -> BaseVideoModel:
38
  if "LLaVA-Video" in model_path:
 
 
 
 
 
39
  return LLaVAVideoModel(
40
  model_path,
41
  dtype=dtype,
42
  device_map=device_map,
43
  attn_implementation=attn_implementation,
44
+ load_8bit=load_8bit,
45
+ load_4bit=load_4bit,
46
  )
47
  elif "Qwen" in model_path:
 
 
 
 
 
48
  if "Qwen3" in model_path:
49
  return Qwen3VLModel(
50
  model_path,
51
  dtype=dtype,
52
  device_map=device_map,
53
  attn_implementation=attn_implementation,
54
+ load_8bit=load_8bit,
55
+ load_4bit=load_4bit,
56
  )
57
  else:
58
  return Qwen2_5VLModel(
 
60
  dtype=dtype,
61
  device_map=device_map,
62
  attn_implementation=attn_implementation,
63
+ load_8bit=load_8bit,
64
+ load_4bit=load_4bit,
65
  )
66
  elif "Intern" in model_path:
 
 
 
 
 
67
  return InternVLModel(
68
  model_path,
69
  dtype=dtype,
70
  device_map=device_map,
71
  attn_implementation=attn_implementation,
72
+ load_8bit=load_8bit,
73
+ load_4bit=load_4bit,
74
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/base.py CHANGED
@@ -17,11 +17,11 @@ class BaseVideoModel(ABC):
17
  ) -> str:
18
  pass
19
 
20
- @abstractmethod
21
- def chat_with_confidence(
22
- self,
23
- prompt: str,
24
- video_path: str,
25
- generation_config: Optional[Dict[str, Any]] = None,
26
- ) -> Dict[str, Union[str, float]]:
27
- pass
 
17
  ) -> str:
18
  pass
19
 
20
+ # @abstractmethod
21
+ # def chat_with_confidence(
22
+ # self,
23
+ # prompt: str,
24
+ # video_path: str,
25
+ # generation_config: Optional[Dict[str, Any]] = None,
26
+ # ) -> Dict[str, Union[str, float]]:
27
+ # pass
models/llava_video.py CHANGED
@@ -1,26 +1,11 @@
1
  # Run with `conda activate llava`
2
- from llava.model.builder import load_pretrained_model
3
- from llava.mm_utils import (
4
- get_model_name_from_path,
5
- process_images,
6
- tokenizer_image_token,
7
- )
8
- from llava.constants import (
9
- IMAGE_TOKEN_INDEX,
10
- DEFAULT_IMAGE_TOKEN,
11
- DEFAULT_IM_START_TOKEN,
12
- DEFAULT_IM_END_TOKEN,
13
- IGNORE_INDEX,
14
- )
15
- from llava.conversation import conv_templates, SeparatorStyle
16
- from PIL import Image
17
- import requests
18
  import copy
19
  import torch
20
- from typing import Optional, Union, Dict, List, Tuple, Any
21
- import warnings
22
- from decord import VideoReader, cpu
23
  import numpy as np
 
 
 
24
 
25
  # Handle both relative and absolute imports
26
  try:
@@ -30,46 +15,37 @@ except ImportError:
30
 
31
  warnings.filterwarnings("ignore")
32
 
33
-
34
  class LLaVAVideoModel(BaseVideoModel):
35
  def __init__(
36
  self,
37
- model_name: str = "lmms-lab/LLaVA-Video-7B-Qwen2",
38
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
39
  device_map: Optional[Union[str, Dict]] = "auto",
40
  attn_implementation: Optional[str] = "flash_attention_2",
 
 
41
  ):
42
  super().__init__(model_name)
43
- base_model = "llava_qwen"
44
  self.dtype = dtype
45
- # Convert torch dtype to string for safety, since LLaVA-Video only accepts torch_dtype as a string
46
- if dtype == torch.bfloat16:
47
- torch_dtype = "bfloat16"
48
- elif dtype == torch.float16:
49
- torch_dtype = "float16"
50
-
51
- self.tokenizer, self.model, self.image_processor, max_length = (
52
- load_pretrained_model(
53
- model_name,
54
- None,
55
- base_model,
56
- torch_dtype=torch_dtype,
57
- device_map=device_map,
58
- )
59
- ) # Add any other thing you want to pass in llava_model_args
60
- self.model.eval()
61
 
62
- # Ensure all model components are on the same device
63
- # The vision tower and mm_projector may not be on the correct device with device_map using `load_pretrained_model`, so need to explicitly move to the model's device
64
- if hasattr(self.model, "get_vision_tower"):
65
- vision_tower = self.model.get_vision_tower()
66
- if vision_tower is not None:
67
- vision_tower.to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- if hasattr(self.model, "get_model"):
70
- model_inner = self.model.get_model()
71
- if hasattr(model_inner, "mm_projector"):
72
- model_inner.mm_projector.to(self.model.device)
73
 
74
  def load_video(
75
  self,
@@ -101,224 +77,227 @@ class LLaVAVideoModel(BaseVideoModel):
101
  self,
102
  prompt: str,
103
  video_path: str,
104
- fps: float = 1.0,
105
  max_new_tokens: int = 512,
106
  do_sample: Optional[
107
  bool
108
  ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
109
  temperature: float = 0.7,
110
  video_mode: Optional[str] = "video",
111
- video_frames: Optional[int] = 10,
 
112
  **kwargs: Any,
113
  ) -> str:
 
114
  if video_mode == "frames":
115
- video, _, _ = self.load_video(video_path, max_frames_num=video_frames)
116
  elif video_mode == "video":
117
- video, _, _ = self.load_video(video_path, fps)
118
- video = self.image_processor.preprocess(video, return_tensors="pt")[
119
- "pixel_values"
120
- ].to(device=self.model.device, dtype=self.dtype)
121
- video = [video]
122
- conv_template = (
123
- "qwen_1_5" # Make sure you use correct chat template for different models
124
- )
125
- question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
126
- conv = copy.deepcopy(conv_templates[conv_template])
127
- conv.append_message(conv.roles[0], question)
128
- conv.append_message(conv.roles[1], None)
129
- prompt_question = conv.get_prompt()
130
- input_ids = (
131
- tokenizer_image_token(
132
- prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
133
- )
134
- .unsqueeze(0)
135
- .to(self.model.device)
136
- )
137
- cont = self.model.generate(
138
- input_ids,
139
- images=video,
140
- modalities=["video"],
141
- do_sample=do_sample,
142
- temperature=temperature,
143
- max_new_tokens=max_new_tokens,
144
- **kwargs,
145
- )
146
- text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[
147
- 0
148
- ].strip()
149
- return text_outputs
150
 
151
- def chat_with_confidence(
152
- self,
153
- prompt: str,
154
- video_path: str,
155
- fps: float = 1.0,
156
- max_new_tokens: int = 512,
157
- temperature: float = 0.7,
158
- do_sample: Optional[
159
- bool
160
- ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
161
- token_choices: Optional[List[str]] = ["Yes", "No"],
162
- logits_temperature: Optional[float] = 1.0,
163
- return_confidence: Optional[bool] = False,
164
- top_k_tokens: Optional[int] = 10,
165
- debug: Optional[bool] = False,
166
- ) -> Dict[str, Any]:
167
- video, _, _ = self.load_video(video_path, fps)
168
- video = self.image_processor.preprocess(video, return_tensors="pt")[
169
- "pixel_values"
170
- ].to(device=self.model.device, dtype=self.dtype)
171
- video = [video]
172
- conv_template = (
173
- "qwen_1_5" # Make sure you use correct chat template for different models
174
- )
175
- question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
176
- conv = copy.deepcopy(conv_templates[conv_template])
177
- conv.append_message(conv.roles[0], question)
178
- conv.append_message(conv.roles[1], None)
179
- prompt_question = conv.get_prompt()
180
- input_ids = (
181
- tokenizer_image_token(
182
- prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
183
- )
184
- .unsqueeze(0)
185
- .to(self.model.device)
186
- )
187
  with torch.no_grad():
188
- outputs = self.model.generate(
189
- input_ids,
190
- images=video,
191
- modalities=["video"],
192
- do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
193
  temperature=temperature,
194
  max_new_tokens=max_new_tokens,
195
- output_scores=True,
196
- return_dict_in_generate=True,
197
  )
198
- generated_ids = outputs.sequences
199
- scores = outputs.scores # Tuple of tensors, one per generated token
 
200
 
201
- print(f"Number of generated tokens: {len(scores)}")
202
- print(f"Vocabulary size: {scores[0].shape[1]}")
203
- # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
204
- if debug:
205
- print("****Running inference in debug mode****")
206
- # Print first token scores shape and max/min scores in debug mode
207
- print(f"Single token scores shape: {scores[0].shape}")
208
- print(
209
- f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
210
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- # Print details about top 10 tokens based on logits
213
- logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
214
- print(f"\n{'─'*80}")
215
- print(
216
- f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
217
- )
218
- print(f"{'─'*80}")
219
- top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
220
- for i in range(top_k_tokens):
221
- score = top_k_tokens_scores.values[0, i].item()
222
- score_index = top_k_tokens_scores.indices[0, i].item()
223
- token = self.tokenizer.decode(score_index)
224
- print(f"#{i+1}th Token: {token}")
225
- print(f"#{i+1}th Token index: {score_index}")
226
- print(f"#{i+1}th Token score: {score}")
227
- print("--------------------------------")
228
 
229
- # Decode the text
230
- output_response = self.tokenizer.batch_decode(
231
- generated_ids,
232
- skip_special_tokens=True,
233
- clean_up_tokenization_spaces=False,
234
- )[0]
 
 
 
 
 
 
 
 
 
 
235
 
236
- # Convert scores to probabilities
237
- # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
238
- selected_token_probs = []
239
- selected_token_logits = []
240
- first_token_probs = torch.softmax(scores[0], dim=-1)
 
241
 
242
- # Now, find indices of tokens in token_choices and get their probabilities
243
- for token_choice in token_choices:
244
- # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
245
- token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
246
- 0
247
- ]
248
- selected_token_probs.append(first_token_probs[0, token_index].item())
249
- selected_token_logits.append(scores[0][0, token_index].item())
250
 
251
- # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
252
- if return_confidence:
253
- first_token_id = generated_ids[0][
254
- 0
255
- ].item() # First token of the first sequence
256
- confidence = (
257
- first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
258
- if sum(selected_token_probs) > 0
259
- else 0.0
260
- )
261
- return {
262
- "response": output_response,
263
- "confidence": confidence,
264
- }
 
 
 
 
 
 
 
 
 
265
 
266
- # Return token logits
267
- else:
268
- token_logits = dict(zip(token_choices, selected_token_logits))
269
- top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
270
- top_k_tokens_list: List[Tuple[str, int, float]] = []
271
- for i in range(top_k_tokens):
272
- logit_index = top_k_logits_indices.indices[0, i].item()
273
- token = self.tokenizer.decode(logit_index)
274
- logit = top_k_logits_indices.values[0, i].item()
275
- top_k_tokens_list.append((token, logit_index, logit))
276
- return {
277
- "response": output_response,
278
- "top_k_tokens": top_k_tokens_list,
279
- "token_logits": token_logits,
280
- }
281
 
282
 
283
- if __name__ == "__main__":
284
- model_path = "lmms-lab/LLaVA-Video-7B-Qwen2" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
285
- device_map = "cuda:0"
286
- model = LLaVAVideoModel(model_path, device_map=device_map)
287
- prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
288
- token_choices = ["Yes", "No"]
289
- video_path = (
290
- "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
291
- )
292
 
293
- generation_config = {
294
- "max_new_tokens": 128,
295
- "do_sample": False, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
296
- "temperature": 0.7,
297
- "logits_temperature": 1.0,
298
- "fps": 1.0,
299
- "return_confidence": False,
300
- "top_k_tokens": 10,
301
- "debug": False,
302
- }
303
- output = model.chat_with_confidence(
304
- prompt, video_path, token_choices=token_choices, **generation_config
305
- )
306
- response = output["response"]
307
- print(f"Response: {response}")
308
 
309
- if generation_config["return_confidence"]:
310
- confidence = output["confidence"]
311
- print(f"Confidence: {confidence}")
312
- else:
313
- # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
314
- # otherwise, the raw logits are used, which are not filtered.
315
- logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
316
- print(f"\n{'─'*80}")
317
- print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
318
- print(f"{'─'*80}")
319
- top_k_tokens = output["top_k_tokens"]
320
- for i in range(len(top_k_tokens)):
321
- print(f"Top {i+1} token: {top_k_tokens[i][0]}")
322
- print(f"Top {i+1} token index: {top_k_tokens[i][1]}")
323
- print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
324
- print("--------------------------------")
 
1
  # Run with `conda activate llava`
2
+ import warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import copy
4
  import torch
 
 
 
5
  import numpy as np
6
+ from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
7
+ from typing import Optional, Dict, Any, Union, List
8
+ from decord import VideoReader, cpu
9
 
10
  # Handle both relative and absolute imports
11
  try:
 
15
 
16
  warnings.filterwarnings("ignore")
17
 
 
18
  class LLaVAVideoModel(BaseVideoModel):
19
  def __init__(
20
  self,
21
+ model_name: str = "Isotr0py/LLaVA-Video-7B-Qwen2-hf",
22
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
23
  device_map: Optional[Union[str, Dict]] = "auto",
24
  attn_implementation: Optional[str] = "flash_attention_2",
25
+ load_8bit: Optional[bool] = False,
26
+ load_4bit: Optional[bool] = False,
27
  ):
28
  super().__init__(model_name)
 
29
  self.dtype = dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # For quantized models (8-bit or 4-bit), device_map must be "auto" or a dict, not a device string
32
+ quantization_config = None
33
+ if load_8bit or load_4bit:
34
+ quantization_config = BitsAndBytesConfig(
35
+ load_in_8bit=load_8bit,
36
+ load_in_4bit=load_4bit,
37
+ bnb_4bit_quant_type="nf4",
38
+ bnb_4bit_compute_dtype=torch.float16
39
+ )
40
+ self.model = AutoModelForImageTextToText.from_pretrained(
41
+ model_name,
42
+ quantization_config=quantization_config,
43
+ device_map=device_map,
44
+ attn_implementation=attn_implementation,
45
+ dtype=dtype,
46
+ )
47
+ self.processor = AutoProcessor.from_pretrained(model_name)
48
 
 
 
 
 
49
 
50
  def load_video(
51
  self,
 
77
  self,
78
  prompt: str,
79
  video_path: str,
 
80
  max_new_tokens: int = 512,
81
  do_sample: Optional[
82
  bool
83
  ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
84
  temperature: float = 0.7,
85
  video_mode: Optional[str] = "video",
86
+ fps: Optional[float] = 1.0,
87
+ num_frames: Optional[int] = 10,
88
  **kwargs: Any,
89
  ) -> str:
90
+ # Ensure only one of fps or num_frames is provided
91
  if video_mode == "frames":
92
+ fps = None
93
  elif video_mode == "video":
94
+ num_frames = None
95
+ conversation = [
96
+ {
97
+ "role": "user",
98
+ "content": [
99
+ {
100
+ "type": "video",
101
+ "video": video_path,
102
+ },
103
+ {"type": "text", "text": prompt}
104
+ ],
105
+ },
106
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ inputs = self.processor.apply_chat_template(
109
+ conversation,
110
+ add_generation_prompt=True,
111
+ tokenize=True,
112
+ return_dict=True,
113
+ return_tensors="pt",
114
+ do_sample_frames=True,
115
+ fps=fps,
116
+ num_frames=num_frames
117
+ ).to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  with torch.no_grad():
119
+ out = self.model.generate(
120
+ **inputs,
121
+ do_sample=do_sample,
 
 
122
  temperature=temperature,
123
  max_new_tokens=max_new_tokens,
124
+ **kwargs,
 
125
  )
126
+ raw_response = self.processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
127
+ response = raw_response.split("assistant")[1].strip()
128
+ return response
129
 
130
+ # def chat_with_confidence(
131
+ # self,
132
+ # prompt: str,
133
+ # video_path: str,
134
+ # fps: float = 1.0,
135
+ # max_new_tokens: int = 512,
136
+ # temperature: float = 0.7,
137
+ # do_sample: Optional[
138
+ # bool
139
+ # ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
140
+ # token_choices: Optional[List[str]] = ["Yes", "No"],
141
+ # logits_temperature: Optional[float] = 1.0,
142
+ # return_confidence: Optional[bool] = False,
143
+ # top_k_tokens: Optional[int] = 10,
144
+ # debug: Optional[bool] = False,
145
+ # ) -> Dict[str, Any]:
146
+ # video, _, _ = self.load_video(video_path, fps)
147
+ # video = self.image_processor.preprocess(video, return_tensors="pt")[
148
+ # "pixel_values"
149
+ # ].to(device=self.model.device, dtype=self.dtype)
150
+ # video = [video]
151
+ # conv_template = (
152
+ # "qwen_1_5" # Make sure you use correct chat template for different models
153
+ # )
154
+ # question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
155
+ # conv = copy.deepcopy(conv_templates[conv_template])
156
+ # conv.append_message(conv.roles[0], question)
157
+ # conv.append_message(conv.roles[1], None)
158
+ # prompt_question = conv.get_prompt()
159
+ # input_ids = (
160
+ # tokenizer_image_token(
161
+ # prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
162
+ # )
163
+ # .unsqueeze(0)
164
+ # .to(self.model.device)
165
+ # )
166
+ # with torch.no_grad():
167
+ # outputs = self.model.generate(
168
+ # input_ids,
169
+ # images=video,
170
+ # modalities=["video"],
171
+ # do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
172
+ # temperature=temperature,
173
+ # max_new_tokens=max_new_tokens,
174
+ # output_scores=True,
175
+ # return_dict_in_generate=True,
176
+ # )
177
+ # generated_ids = outputs.sequences
178
+ # scores = outputs.scores # Tuple of tensors, one per generated token
179
 
180
+ # print(f"Number of generated tokens: {len(scores)}")
181
+ # print(f"Vocabulary size: {scores[0].shape[1]}")
182
+ # # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
183
+ # if debug:
184
+ # print("****Running inference in debug mode****")
185
+ # # Print first token scores shape and max/min scores in debug mode
186
+ # print(f"Single token scores shape: {scores[0].shape}")
187
+ # print(
188
+ # f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
189
+ # )
 
 
 
 
 
 
190
 
191
+ # # Print details about top 10 tokens based on logits
192
+ # logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
193
+ # print(f"\n{'─'*80}")
194
+ # print(
195
+ # f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
196
+ # )
197
+ # print(f"{'─'*80}")
198
+ # top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
199
+ # for i in range(top_k_tokens):
200
+ # score = top_k_tokens_scores.values[0, i].item()
201
+ # score_index = top_k_tokens_scores.indices[0, i].item()
202
+ # token = self.tokenizer.decode(score_index)
203
+ # print(f"#{i+1}th Token: {token}")
204
+ # print(f"#{i+1}th Token index: {score_index}")
205
+ # print(f"#{i+1}th Token score: {score}")
206
+ # print("--------------------------------")
207
 
208
+ # # Decode the text
209
+ # output_response = self.tokenizer.batch_decode(
210
+ # generated_ids,
211
+ # skip_special_tokens=True,
212
+ # clean_up_tokenization_spaces=False,
213
+ # )[0]
214
 
215
+ # # Convert scores to probabilities
216
+ # # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
217
+ # selected_token_probs = []
218
+ # selected_token_logits = []
219
+ # first_token_probs = torch.softmax(scores[0], dim=-1)
 
 
 
220
 
221
+ # # Now, find indices of tokens in token_choices and get their probabilities
222
+ # for token_choice in token_choices:
223
+ # # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
224
+ # token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
225
+ # 0
226
+ # ]
227
+ # selected_token_probs.append(first_token_probs[0, token_index].item())
228
+ # selected_token_logits.append(scores[0][0, token_index].item())
229
+
230
+ # # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
231
+ # if return_confidence:
232
+ # first_token_id = generated_ids[0][
233
+ # 0
234
+ # ].item() # First token of the first sequence
235
+ # confidence = (
236
+ # first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
237
+ # if sum(selected_token_probs) > 0
238
+ # else 0.0
239
+ # )
240
+ # return {
241
+ # "response": output_response,
242
+ # "confidence": confidence,
243
+ # }
244
 
245
+ # # Return token logits
246
+ # else:
247
+ # token_logits = dict(zip(token_choices, selected_token_logits))
248
+ # top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
249
+ # top_k_tokens_list: List[Tuple[str, int, float]] = []
250
+ # for i in range(top_k_tokens):
251
+ # logit_index = top_k_logits_indices.indices[0, i].item()
252
+ # token = self.tokenizer.decode(logit_index)
253
+ # logit = top_k_logits_indices.values[0, i].item()
254
+ # top_k_tokens_list.append((token, logit_index, logit))
255
+ # return {
256
+ # "response": output_response,
257
+ # "top_k_tokens": top_k_tokens_list,
258
+ # "token_logits": token_logits,
259
+ # }
260
 
261
 
262
+ # if __name__ == "__main__":
263
+ # model_path = "lmms-lab/LLaVA-Video-7B-Qwen2" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
264
+ # device_map = "cuda:0"
265
+ # model = LLaVAVideoModel(model_path, device_map=device_map)
266
+ # prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
267
+ # token_choices = ["Yes", "No"]
268
+ # video_path = (
269
+ # "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
270
+ # )
271
 
272
+ # generation_config = {
273
+ # "max_new_tokens": 128,
274
+ # "do_sample": False, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
275
+ # "temperature": 0.7,
276
+ # "logits_temperature": 1.0,
277
+ # "fps": 1.0,
278
+ # "return_confidence": False,
279
+ # "top_k_tokens": 10,
280
+ # "debug": False,
281
+ # }
282
+ # output = model.chat_with_confidence(
283
+ # prompt, video_path, token_choices=token_choices, **generation_config
284
+ # )
285
+ # response = output["response"]
286
+ # print(f"Response: {response}")
287
 
288
+ # if generation_config["return_confidence"]:
289
+ # confidence = output["confidence"]
290
+ # print(f"Confidence: {confidence}")
291
+ # else:
292
+ # # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
293
+ # # otherwise, the raw logits are used, which are not filtered.
294
+ # logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
295
+ # print(f"\n{'─'*80}")
296
+ # print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
297
+ # print(f"{'─'*80}")
298
+ # top_k_tokens = output["top_k_tokens"]
299
+ # for i in range(len(top_k_tokens)):
300
+ # print(f"Top {i+1} token: {top_k_tokens[i][0]}")
301
+ # print(f"Top {i+1} token index: {top_k_tokens[i][1]}")
302
+ # print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
303
+ # print("--------------------------------")
models/qwen2_5vl.py CHANGED
@@ -2,11 +2,12 @@
2
 
3
  import torch
4
  from transformers import (
5
- Qwen2_5_VLForConditionalGeneration,
6
  AutoProcessor,
 
7
  )
8
  from typing import Optional, Dict, Any, Union, List
9
- from qwen_vl_utils import process_vision_info
10
 
11
  # Handle both relative and absolute imports
12
  try:
@@ -22,10 +23,22 @@ class Qwen2_5VLModel(BaseVideoModel):
22
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
23
  device_map: Optional[Union[str, Dict]] = "auto",
24
  attn_implementation: Optional[str] = "flash_attention_2",
 
 
25
  ):
26
  super().__init__(model_name)
27
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
 
 
 
 
 
 
 
 
28
  model_name,
 
29
  dtype=dtype,
30
  device_map=device_map,
31
  attn_implementation=attn_implementation,
@@ -36,257 +49,254 @@ class Qwen2_5VLModel(BaseVideoModel):
36
  self,
37
  prompt: str,
38
  video_path: str,
39
- fps: float = 1.0,
40
  temperature: float = 0.7,
41
  max_new_tokens: int = 512,
42
  do_sample: Optional[bool] = True,
 
 
 
43
  **kwargs: Any,
44
  ) -> str:
 
 
 
 
 
45
  # Messages containing a local video path and a text query
46
- messages = [
47
  {
48
  "role": "user",
49
  "content": [
50
  {
51
- "type": "video",
52
  "video": video_path,
53
- # "max_pixels": 360 * 420,
54
- "fps": fps,
55
  },
56
- {"type": "text", "text": prompt},
57
  ],
58
- }
59
  ]
60
 
61
- text = self.processor.apply_chat_template(
62
- messages, tokenize=False, add_generation_prompt=True
63
- )
64
- image_inputs, video_inputs, video_kwargs = process_vision_info(
65
- messages, return_video_kwargs=True
66
- )
67
- inputs = self.processor(
68
- text=[text],
69
- images=image_inputs,
70
- videos=video_inputs,
71
- padding=True,
72
  return_tensors="pt",
73
- **video_kwargs,
74
- )
75
- inputs = inputs.to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- # Inference
78
- generated_ids = self.model.generate(
79
- **inputs,
80
- do_sample=do_sample,
81
- temperature=temperature,
82
- max_new_tokens=max_new_tokens,
83
- **kwargs,
84
- )
85
- generated_ids_trimmed = [
86
- out_ids[len(in_ids) :]
87
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
88
- ]
89
- output_response = self.processor.batch_decode(
90
- generated_ids_trimmed,
91
- skip_special_tokens=True,
92
- clean_up_tokenization_spaces=False,
93
- )[0]
94
- return output_response
95
 
96
- def chat_with_confidence(
97
- self,
98
- prompt: str,
99
- video_path: str,
100
- fps: float = 1.0,
101
- max_new_tokens: int = 512,
102
- temperature: float = 0.7,
103
- token_choices: Optional[List[str]] = ["Yes", "No"],
104
- logits_temperature: Optional[float] = 1.0,
105
- return_confidence: Optional[bool] = False,
106
- debug: Optional[bool] = False,
107
- ) -> Dict[str, Any]:
108
- """
109
- Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
110
 
111
- Args:
112
- prompt (str): The text prompt to generate a response for.
113
- video_path (str): The path to the video file.
114
- fps (float, optional): The frames per second of the video. Defaults to 1.0.
115
- max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 128.
116
- temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
117
- logits_temperature (float, optional): The logits temperature to use for generation. Defaults to 1.0.
118
- token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
119
- return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
120
- debug (bool, optional): Whether to run in debug mode. Defaults to False.
121
 
122
- Returns:
123
- Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
 
 
 
 
 
 
 
124
 
125
- e.g., return_confidence: False
126
- Output:
127
- {
128
- "response": "Yes",
129
- "logits": {
130
- "Yes": 12.0,
131
- "No": 9.0
132
- }
133
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- e.g., return_confidence: True
136
- Output:
137
- {
138
- "response": "Yes",
139
- "confidence": 0.9999
140
- }
141
- """
142
- # Messages containing a local video path and a text query
143
- messages = [
144
- {
145
- "role": "user",
146
- "content": [
147
- {
148
- "type": "video",
149
- "video": video_path,
150
- # "max_pixels": 360 * 420,
151
- "fps": fps,
152
- },
153
- {"type": "text", "text": prompt},
154
- ],
155
- }
156
- ]
157
-
158
- text = self.processor.apply_chat_template(
159
- messages, tokenize=False, add_generation_prompt=True
160
- )
161
- image_inputs, video_inputs, video_kwargs = process_vision_info(
162
- messages, return_video_kwargs=True
163
- )
164
- inputs = self.processor(
165
- text=[text],
166
- images=image_inputs,
167
- videos=video_inputs,
168
- padding=True,
169
- return_tensors="pt",
170
- **video_kwargs,
171
- )
172
- inputs = inputs.to(self.model.device)
173
 
174
- # Inference with scores
175
- with torch.no_grad():
176
- outputs = self.model.generate(
177
- **inputs,
178
- temperature=temperature,
179
- max_new_tokens=max_new_tokens,
180
- output_scores=True,
181
- return_dict_in_generate=True,
182
- )
183
 
184
- generated_ids = outputs.sequences
185
- scores = outputs.scores # Tuple of tensors, one per generated token
186
- scores = tuple(
187
- s / logits_temperature for s in scores
188
- ) # Scales the logits by a factor for normalization during reporting
189
 
190
- print(f"Number of generated tokens: {len(scores)}")
191
- print(f"Vocabulary size: {scores[0].shape[1]}")
192
- # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
193
- if debug:
194
- print("****Running inference in debug mode****")
195
- # Print first token scores shape and max/min scores in debug mode
196
- print(f"Single token scores shape: {scores[0].shape}")
197
- print(
198
- f"First token max/min scores: {scores[0].max().item()}, {scores[0].min().item()}"
199
- )
200
- # Print details about top 3 tokens
201
- top_3_tokens = torch.topk(scores[0], k=3, dim=-1)
202
- for i in range(3):
203
- print(
204
- f"Pos 0 | {i+1}th Token: {self.processor.decode(top_3_tokens.indices[0, i].item())}"
205
- )
206
- print(
207
- f"Pos 0 | {i+1}th Token logit: {top_3_tokens.values[0, i].item()}"
208
- )
209
 
210
- # Trim the prompt tokens from generated sequences
211
- generated_ids_trimmed = [
212
- out_ids[len(in_ids) :]
213
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
214
- ]
215
 
216
- # Decode the text
217
- output_response = self.processor.batch_decode(
218
- generated_ids_trimmed,
219
- skip_special_tokens=True,
220
- clean_up_tokenization_spaces=False,
221
- )[0]
222
 
223
- # Convert scores to probabilities
224
- # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
225
- selected_token_probs = []
226
- selected_token_logits = []
227
- first_token_probs = torch.softmax(scores[0], dim=-1)
228
 
229
- # Now, find indices of tokens in token_choices and get their probabilities
230
- for token_choice in token_choices:
231
- # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
232
- token_index = self.processor.tokenizer.encode(
233
- token_choice, add_special_tokens=False
234
- )[0]
235
- selected_token_probs.append(first_token_probs[0, token_index].item())
236
- selected_token_logits.append(scores[0][0, token_index].item())
237
 
238
- # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
239
- if return_confidence:
240
- first_token_id = generated_ids_trimmed[0][
241
- 0
242
- ].item() # First token of the first sequence
243
- confidence = (
244
- first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
245
- if sum(selected_token_probs) > 0
246
- else 0.0
247
- )
248
- return {
249
- "response": output_response,
250
- "confidence": confidence,
251
- }
252
 
253
- # Retrn token logits
254
- else:
255
- token_logits = dict(zip(token_choices, selected_token_logits))
256
- return {
257
- "response": output_response,
258
- "logits": token_logits,
259
- }
260
 
261
 
262
- if __name__ == "__main__":
263
- model_path = "Qwen/Qwen2.5-VL-7B-Instruct" # "Qwen/Qwen2.5-VL-7B-Instruct"
264
- model = Qwen2_5VLModel(model_path)
265
- prompt = (
266
- "Which of the following exist in the video? Answer in A or B.\nA: Hand\nB: Face"
267
- )
268
- token_choices = ["A", "B"]
269
- ext = ".webm"
270
- video_path = "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917" + ext
271
 
272
- generation_config = {
273
- "max_new_tokens": 128,
274
- "temperature": 0.7,
275
- "logits_temperature": 5.0,
276
- "fps": 3.0,
277
- "return_confidence": False,
278
- "debug": True,
279
- }
280
- output = model.chat_with_confidence(
281
- prompt, video_path, token_choices=token_choices, **generation_config
282
- )
283
- response = output["response"]
284
- print(f"Response: {response}")
285
 
286
- if generation_config["return_confidence"]:
287
- confidence = output["confidence"]
288
- print(f"Confidence: {confidence}")
289
- else:
290
- selected_token_logits = output["logits"]
291
- print(f"Selected token logits: {selected_token_logits}")
292
- print(f"Logits temperature: {generation_config['logits_temperature']}")
 
2
 
3
  import torch
4
  from transformers import (
5
+ AutoModelForImageTextToText,
6
  AutoProcessor,
7
+ BitsAndBytesConfig,
8
  )
9
  from typing import Optional, Dict, Any, Union, List
10
+ # from qwen_vl_utils import process_vision_info
11
 
12
  # Handle both relative and absolute imports
13
  try:
 
23
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
24
  device_map: Optional[Union[str, Dict]] = "auto",
25
  attn_implementation: Optional[str] = "flash_attention_2",
26
+ load_8bit: Optional[bool] = False,
27
+ load_4bit: Optional[bool] = False,
28
  ):
29
  super().__init__(model_name)
30
+ self.dtype = dtype
31
+ quantization_config = None
32
+ if load_8bit or load_4bit:
33
+ quantization_config = BitsAndBytesConfig(
34
+ load_in_8bit=load_8bit,
35
+ load_in_4bit=load_4bit,
36
+ bnb_4bit_quant_type="nf4",
37
+ bnb_4bit_compute_dtype=torch.float16
38
+ )
39
+ self.model = AutoModelForImageTextToText.from_pretrained(
40
  model_name,
41
+ quantization_config=quantization_config,
42
  dtype=dtype,
43
  device_map=device_map,
44
  attn_implementation=attn_implementation,
 
49
  self,
50
  prompt: str,
51
  video_path: str,
 
52
  temperature: float = 0.7,
53
  max_new_tokens: int = 512,
54
  do_sample: Optional[bool] = True,
55
+ fps: Optional[float] = 1.0,
56
+ num_frames: Optional[int] = 10,
57
+ video_mode: Optional[str] = "video",
58
  **kwargs: Any,
59
  ) -> str:
60
+ # Ensure only one of fps or num_frames is provided
61
+ if video_mode == "frames":
62
+ fps = None
63
+ elif video_mode == "video":
64
+ num_frames = None
65
  # Messages containing a local video path and a text query
66
+ conversation = [
67
  {
68
  "role": "user",
69
  "content": [
70
  {
71
+ "type": "video",
72
  "video": video_path,
 
 
73
  },
74
+ {"type": "text", "text": prompt}
75
  ],
76
+ },
77
  ]
78
 
79
+ inputs = self.processor.apply_chat_template(
80
+ conversation,
81
+ add_generation_prompt=True,
82
+ tokenize=True,
83
+ return_dict=True,
 
 
 
 
 
 
84
  return_tensors="pt",
85
+ do_sample_frames=True,
86
+ fps=fps,
87
+ num_frames=num_frames
88
+ ).to(self.model.device)
89
+ with torch.no_grad():
90
+ out = self.model.generate(
91
+ **inputs,
92
+ do_sample=do_sample,
93
+ temperature=temperature,
94
+ max_new_tokens=max_new_tokens,
95
+ **kwargs,
96
+ )
97
+ raw_response = self.processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
98
+ response = raw_response.split("assistant")[1].strip()
99
+ return response
100
+
101
 
102
+ # def chat_with_confidence(
103
+ # self,
104
+ # prompt: str,
105
+ # video_path: str,
106
+ # fps: Optional[float] = 1.0,
107
+ # num_frames: Optional[int] = 10,
108
+ # max_new_tokens: int = 512,
109
+ # temperature: float = 0.7,
110
+ # do_sample: Optional[bool] = True,
111
+ # video_mode: Optional[str] = "video",
112
+ # token_choices: Optional[List[str]] = ["Yes", "No"],
113
+ # logits_temperature: Optional[float] = 1.0,
114
+ # return_confidence: Optional[bool] = False,
115
+ # debug: Optional[bool] = False,
116
+ # **kwargs: Any,
117
+ # ) -> Dict[str, Any]:
118
+ # """
119
+ # Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
120
 
121
+ # Args:
122
+ # prompt (str): The text prompt to generate a response for.
123
+ # video_path (str): The path to the video file.
124
+ # fps (float, optional): The frames per second of the video. Defaults to 1.0.
125
+ # max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
126
+ # temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
127
+ # logits_temperature (float, optional): The logits temperature to use for generation. Defaults to 1.0.
128
+ # token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
129
+ # return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
130
+ # debug (bool, optional): Whether to run in debug mode. Defaults to False.
 
 
 
 
131
 
132
+ # Returns:
133
+ # Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
 
 
 
 
 
 
 
 
134
 
135
+ # e.g., return_confidence: False
136
+ # Output:
137
+ # {
138
+ # "response": "Yes",
139
+ # "logits": {
140
+ # "Yes": 12.0,
141
+ # "No": 9.0
142
+ # }
143
+ # }
144
 
145
+ # e.g., return_confidence: True
146
+ # Output:
147
+ # {
148
+ # "response": "Yes",
149
+ # "confidence": 0.9999
150
+ # }
151
+ # """
152
+ # # Messages containing a local video path and a text query
153
+ # messages = [
154
+ # {
155
+ # "role": "user",
156
+ # "content": [
157
+ # {
158
+ # "type": "video",
159
+ # "video": video_path,
160
+ # # "max_pixels": 360 * 420,
161
+ # "fps": fps,
162
+ # },
163
+ # {"type": "text", "text": prompt},
164
+ # ],
165
+ # }
166
+ # ]
167
 
168
+ # text = self.processor.apply_chat_template(
169
+ # messages, tokenize=False, add_generation_prompt=True
170
+ # )
171
+ # image_inputs, video_inputs, video_kwargs = process_vision_info(
172
+ # messages, return_video_kwargs=True
173
+ # )
174
+ # inputs = self.processor(
175
+ # text=[text],
176
+ # images=image_inputs,
177
+ # videos=video_inputs,
178
+ # padding=True,
179
+ # return_tensors="pt",
180
+ # **video_kwargs,
181
+ # )
182
+ # inputs = inputs.to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ # # Inference with scores
185
+ # with torch.no_grad():
186
+ # outputs = self.model.generate(
187
+ # **inputs,
188
+ # temperature=temperature,
189
+ # max_new_tokens=max_new_tokens,
190
+ # output_scores=True,
191
+ # return_dict_in_generate=True,
192
+ # )
193
 
194
+ # generated_ids = outputs.sequences
195
+ # scores = outputs.scores # Tuple of tensors, one per generated token
196
+ # scores = tuple(
197
+ # s / logits_temperature for s in scores
198
+ # ) # Scales the logits by a factor for normalization during reporting
199
 
200
+ # print(f"Number of generated tokens: {len(scores)}")
201
+ # print(f"Vocabulary size: {scores[0].shape[1]}")
202
+ # # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
203
+ # if debug:
204
+ # print("****Running inference in debug mode****")
205
+ # # Print first token scores shape and max/min scores in debug mode
206
+ # print(f"Single token scores shape: {scores[0].shape}")
207
+ # print(
208
+ # f"First token max/min scores: {scores[0].max().item()}, {scores[0].min().item()}"
209
+ # )
210
+ # # Print details about top 3 tokens
211
+ # top_3_tokens = torch.topk(scores[0], k=3, dim=-1)
212
+ # for i in range(3):
213
+ # print(
214
+ # f"Pos 0 | {i+1}th Token: {self.processor.decode(top_3_tokens.indices[0, i].item())}"
215
+ # )
216
+ # print(
217
+ # f"Pos 0 | {i+1}th Token logit: {top_3_tokens.values[0, i].item()}"
218
+ # )
219
 
220
+ # # Trim the prompt tokens from generated sequences
221
+ # generated_ids_trimmed = [
222
+ # out_ids[len(in_ids) :]
223
+ # for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
224
+ # ]
225
 
226
+ # # Decode the text
227
+ # output_response = self.processor.batch_decode(
228
+ # generated_ids_trimmed,
229
+ # skip_special_tokens=True,
230
+ # clean_up_tokenization_spaces=False,
231
+ # )[0]
232
 
233
+ # # Convert scores to probabilities
234
+ # # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
235
+ # selected_token_probs = []
236
+ # selected_token_logits = []
237
+ # first_token_probs = torch.softmax(scores[0], dim=-1)
238
 
239
+ # # Now, find indices of tokens in token_choices and get their probabilities
240
+ # for token_choice in token_choices:
241
+ # # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
242
+ # token_index = self.processor.tokenizer.encode(
243
+ # token_choice, add_special_tokens=False
244
+ # )[0]
245
+ # selected_token_probs.append(first_token_probs[0, token_index].item())
246
+ # selected_token_logits.append(scores[0][0, token_index].item())
247
 
248
+ # # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
249
+ # if return_confidence:
250
+ # first_token_id = generated_ids_trimmed[0][
251
+ # 0
252
+ # ].item() # First token of the first sequence
253
+ # confidence = (
254
+ # first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
255
+ # if sum(selected_token_probs) > 0
256
+ # else 0.0
257
+ # )
258
+ # return {
259
+ # "response": output_response,
260
+ # "confidence": confidence,
261
+ # }
262
 
263
+ # # Return token logits
264
+ # else:
265
+ # token_logits = dict(zip(token_choices, selected_token_logits))
266
+ # return {
267
+ # "response": output_response,
268
+ # "logits": token_logits,
269
+ # }
270
 
271
 
272
+ # if __name__ == "__main__":
273
+ # model_path = "Qwen/Qwen2.5-VL-7B-Instruct" # "Qwen/Qwen2.5-VL-7B-Instruct"
274
+ # model = Qwen2_5VLModel(model_path)
275
+ # prompt = (
276
+ # "Which of the following exist in the video? Answer in A or B.\nA: Hand\nB: Face"
277
+ # )
278
+ # token_choices = ["A", "B"]
279
+ # ext = ".webm"
280
+ # video_path = "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917" + ext
281
 
282
+ # generation_config = {
283
+ # "max_new_tokens": 128,
284
+ # "temperature": 0.7,
285
+ # "logits_temperature": 5.0,
286
+ # "fps": 3.0,
287
+ # "return_confidence": False,
288
+ # "debug": True,
289
+ # }
290
+ # output = model.chat_with_confidence(
291
+ # prompt, video_path, token_choices=token_choices, **generation_config
292
+ # )
293
+ # response = output["response"]
294
+ # print(f"Response: {response}")
295
 
296
+ # if generation_config["return_confidence"]:
297
+ # confidence = output["confidence"]
298
+ # print(f"Confidence: {confidence}")
299
+ # else:
300
+ # selected_token_logits = output["logits"]
301
+ # print(f"Selected token logits: {selected_token_logits}")
302
+ # print(f"Logits temperature: {generation_config['logits_temperature']}")
models/qwen3vl.py CHANGED
@@ -2,14 +2,11 @@
2
 
3
  import torch
4
  from transformers import (
5
- Qwen3VLForConditionalGeneration,
6
  AutoProcessor,
 
7
  )
8
  from typing import Optional, Dict, Any, Union, List, Tuple
9
- from qwen_vl_utils import process_vision_info
10
- import cv2
11
- import numpy as np
12
- from PIL import Image
13
 
14
  # Handle both relative and absolute imports
15
  try:
@@ -18,50 +15,31 @@ except ImportError:
18
  from base import BaseVideoModel
19
 
20
 
21
- def downsample_video(video_path, max_dim=720, num_frames=10):
22
- vidcap = cv2.VideoCapture(video_path)
23
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
24
- frames = []
25
- frame_indices = np.linspace(
26
- 0, total_frames - 1, min(total_frames, num_frames), dtype=int
27
- )
28
-
29
- for i in frame_indices:
30
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
31
- success, image = vidcap.read()
32
- if success:
33
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
34
-
35
- h, w = image.shape[:2]
36
- scale = max_dim / max(h, w)
37
- if scale < 1:
38
- image = cv2.resize(
39
- image,
40
- (int(w * scale), int(h * scale)),
41
- interpolation=cv2.INTER_AREA,
42
- )
43
-
44
- pil_image = Image.fromarray(image)
45
- frames.append(pil_image)
46
-
47
- vidcap.release()
48
- return frames
49
-
50
-
51
  class Qwen3VLModel(BaseVideoModel):
52
  def __init__(
53
  self,
54
- model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
55
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
56
  device_map: Optional[Union[str, Dict]] = "auto",
57
  attn_implementation: Optional[str] = "flash_attention_2",
 
 
58
  ):
59
  super().__init__(model_name)
60
- self.model = Qwen3VLForConditionalGeneration.from_pretrained(
 
 
 
 
 
 
 
 
61
  model_name,
62
- dtype=dtype,
63
  device_map=device_map,
64
  attn_implementation=attn_implementation,
 
65
  )
66
  self.processor = AutoProcessor.from_pretrained(model_name)
67
 
@@ -69,467 +47,443 @@ class Qwen3VLModel(BaseVideoModel):
69
  self,
70
  prompt: str,
71
  video_path: str,
72
- fps: float = 1.0,
73
  temperature: float = 0.7,
74
  do_sample: Optional[
75
  bool
76
  ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
77
  max_new_tokens: int = 512,
78
  video_mode: Optional[str] = "video", # Choose from "video" or "frames"
79
- video_frames: Optional[int] = 10,
 
80
  **kwargs: Any,
81
  ) -> str:
82
- # Messages containing a local video path and a text query
83
- messages = [
84
- {
85
- "role": "user",
86
- "content": [
87
- {"type": "text", "text": prompt},
88
- ],
89
- }
90
- ]
91
- if video_mode == "video":
92
- messages[0]["content"].append(
93
- {
94
- "type": "video",
95
- "video": video_path,
96
- # "max_pixels": 360 * 420,
97
- "fps": fps,
98
- }
99
- )
100
- inputs = self.processor.apply_chat_template(
101
- messages,
102
- tokenize=True,
103
- add_generation_prompt=True,
104
- return_dict=True,
105
- return_tensors="pt",
106
- )
107
-
108
- elif video_mode == "frames":
109
- frames = downsample_video(video_path, max_dim=720, num_frames=video_frames)
110
- images_for_processor = []
111
- for frame in frames:
112
- messages[0]["content"].append({"type": "image"})
113
- images_for_processor.append(frame)
114
- prompt_full = self.processor.apply_chat_template(
115
- messages, tokenize=False, add_generation_prompt=True
116
- )
117
- inputs = self.processor(
118
- text=[prompt_full],
119
- images=images_for_processor,
120
- return_tensors="pt",
121
- padding=True,
122
- )
123
-
124
- inputs = inputs.to(self.model.device)
125
-
126
- generated_ids = self.model.generate(
127
- **inputs,
128
- max_new_tokens=max_new_tokens,
129
- temperature=temperature,
130
- do_sample=do_sample,
131
- **kwargs,
132
- )
133
-
134
- generated_ids_trimmed = [
135
- out_ids[len(in_ids) :]
136
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
137
- ]
138
-
139
- output_response = self.processor.batch_decode(
140
- generated_ids_trimmed,
141
- skip_special_tokens=True,
142
- clean_up_tokenization_spaces=False,
143
- )[0]
144
-
145
- return output_response
146
-
147
- def chat_with_confidence(
148
- self,
149
- prompt: str,
150
- video_path: str,
151
- fps: float = 1.0,
152
- max_new_tokens: int = 512,
153
- temperature: float = 0.7,
154
- do_sample: Optional[
155
- bool
156
- ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
157
- token_choices: Optional[List[str]] = ["Yes", "No"],
158
- logits_temperature: Optional[float] = 1.0,
159
- return_confidence: Optional[bool] = False,
160
- top_k_tokens: Optional[int] = 10,
161
- debug: Optional[bool] = False,
162
- **kwargs: Any,
163
- ) -> Dict[str, Any]:
164
- """
165
- Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
166
-
167
- Args:
168
- prompt (str): The text prompt to generate a response for.
169
- video_path (str): The path to the video file.
170
- temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
171
- max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
172
- token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
173
- generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
174
- return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
175
- top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
176
- debug (bool, optional): Whether to run in debug mode. Defaults to False.
177
-
178
- Returns:
179
- Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
180
-
181
- e.g., return_confidence: False
182
- Output:
183
- {
184
- "response": "Yes",
185
- "top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)],
186
- }
187
-
188
- e.g., return_confidence: True
189
- Output:
190
- {
191
- "response": "Yes",
192
- "confidence": 0.9999
193
- }
194
- """
195
- # Messages containing a local video path and a text query
196
- messages = [
197
  {
198
  "role": "user",
199
  "content": [
200
  {
201
- "type": "video",
202
  "video": video_path,
203
- # "max_pixels": 360 * 420,
204
- "fps": fps,
205
  },
206
- {"type": "text", "text": prompt},
207
  ],
208
- }
209
  ]
210
 
211
  inputs = self.processor.apply_chat_template(
212
- messages,
213
- tokenize=True,
214
- add_generation_prompt=True,
215
- return_dict=True,
216
  return_tensors="pt",
217
- )
218
-
219
- inputs = inputs.to(self.model.device)
220
-
221
- # In debug mode, inspect what logits processors will be used
222
- if debug:
223
- print("\n" + "=" * 80)
224
- print("INSPECTING GENERATION CONFIG & WARPERS")
225
- print("=" * 80)
226
- # Get the generation config to see what processors will be added
227
- gen_config = self.model.generation_config
228
- print(f"Generation config attributes:")
229
- print(f" Processor-related:")
230
- print(
231
- f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
232
- )
233
- print(
234
- f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
235
- )
236
- print(
237
- f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
238
- )
239
- print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
240
- print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
241
- print(
242
- f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
243
- )
244
- print(
245
- f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
246
- )
247
- print(f" Warper-related (THESE MASK TOKENS TO -INF):")
248
- print(f" - temperature: {temperature} (passed as arg)")
249
- print(
250
- f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
251
- )
252
- print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
253
- print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
254
- print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
255
- print(
256
- f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
257
- )
258
- print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
259
- print(
260
- f"\n ⚠️ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
261
- )
262
- print("=" * 80 + "\n")
263
-
264
- # Inference with scores
265
  with torch.no_grad():
266
- outputs = self.model.generate(
267
  **inputs,
 
268
  temperature=temperature,
269
  max_new_tokens=max_new_tokens,
270
- do_sample=do_sample,
271
- output_scores=True,
272
- output_logits=True, # Get TRUE raw logits before any processing
273
- return_dict_in_generate=True,
274
  **kwargs,
275
  )
276
-
277
- generated_ids = outputs.sequences
278
- scores = outputs.scores # Tuple of tensors - PROCESSED logits used for sampling
279
- logits = (
280
- outputs.logits if hasattr(outputs, "logits") else None
281
- ) # TRUE raw logits from model
282
-
283
- scores = tuple(
284
- s / logits_temperature for s in scores
285
- ) # Scales the logits by a factor for normalization during reporting
286
-
287
- print(f"Number of generated tokens: {len(scores)}")
288
- print(f"Vocabulary size: {scores[0].shape[1]}")
289
-
290
- # Check if logits differ from scores
291
- if debug and logits is not None:
292
- print(f"\n[IMPORTANT] output_logits available: True")
293
- print(
294
- f"[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):"
295
- )
296
- logits_raw = logits[0] / logits_temperature # First token's raw logits
297
- scores_first = scores[0] # First token's processed scores
298
-
299
- logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
300
- max_diff = logits_diff.max().item()
301
- if max_diff > 0.001:
302
- print(
303
- f"[IMPORTANT] ⚠️ outputs.scores ARE DIFFERENT from outputs.logits!"
304
- )
305
- print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
306
- print(
307
- f"[IMPORTANT] This means outputs.scores are PROCESSED, not raw!"
308
- )
309
- else:
310
- print(f"[IMPORTANT] βœ“ outputs.scores == outputs.logits (both are raw)")
311
- elif debug:
312
- print(
313
- f"\n[IMPORTANT] output_logits not available in this transformers version"
314
- )
315
-
316
- # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
317
- if debug:
318
- print("\n" + "=" * 80)
319
- print("****Running inference in debug mode****")
320
- print("=" * 80)
321
-
322
- # Use truly raw logits if available, otherwise use scores
323
- raw_logits_to_show = (
324
- logits[0] / logits_temperature if logits is not None else scores[0]
325
- )
326
- logits_label = (
327
- "TRUE RAW LOGITS (from outputs.logits)"
328
- if logits is not None
329
- else "LOGITS (from outputs.scores)"
330
- )
331
-
332
- # Print first token scores shape and max/min scores in debug mode
333
- print(
334
- f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}"
335
- )
336
- print(
337
- f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}"
338
- )
339
-
340
- # Print details about top 3 tokens from RAW logits
341
- print(f"\n{'─'*80}")
342
- print(f"TOP 3 TOKENS FROM {logits_label}:")
343
- print(f"{'─'*80}")
344
- top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
345
- for i in range(3):
346
- token_id = top_3_tokens.indices[0, i].item()
347
- token_text = self.processor.decode(token_id)
348
- token_logit = top_3_tokens.values[0, i].item()
349
- print(
350
- f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
351
- )
352
-
353
- # Now compare with POST-PROCESSED logits (outputs.scores)
354
- scores_first = scores[0] / logits_temperature
355
- print(f"\n{'─'*80}")
356
- print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
357
- print(f"{'─'*80}")
358
- print(
359
- f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}"
360
- )
361
-
362
- top_3_processed = torch.topk(scores_first, k=3, dim=-1)
363
- for i in range(3):
364
- token_id = top_3_processed.indices[0, i].item()
365
- token_text = self.processor.decode(token_id)
366
- token_logit = top_3_processed.values[0, i].item()
367
- print(
368
- f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
369
- )
370
-
371
- # Check if the distributions differ (compare against truly raw logits if available)
372
- print(f"\n{'─'*80}")
373
- print("DIFFERENCE ANALYSIS (Raw β†’ Post-Processed):")
374
- print(f"{'─'*80}")
375
- logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
376
- max_diff = logit_diff.max().item()
377
- num_changed = (logit_diff > 0.001).sum().item()
378
-
379
- print(f" Max logit difference: {max_diff:.6f}")
380
- print(
381
- f" Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}"
382
- )
383
-
384
- if max_diff > 0.001:
385
- print(f"\n ⚠️ LOGITS WERE MODIFIED BY PROCESSORS!")
386
- # Show which tokens changed the most
387
- top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
388
- print(f"\n Top 5 most changed tokens:")
389
- for i in range(min(5, len(top_changes.indices))):
390
- token_id = top_changes.indices[i].item()
391
- token_text = self.processor.decode(token_id)
392
- raw_logit = raw_logits_to_show[0, token_id].item()
393
- processed_logit = scores_first[0, token_id].item()
394
- diff = top_changes.values[i].item()
395
- print(f" Token='{token_text}' | ID={token_id}")
396
- print(
397
- f" Raw: {raw_logit:.4f} β†’ Processed: {processed_logit:.4f} (Ξ”={diff:.4f})"
398
- )
399
- else:
400
- print(f" βœ“ No significant modifications detected")
401
-
402
- # Show what token was actually selected
403
- print(f"\n{'─'*80}")
404
- print("ACTUALLY GENERATED TOKEN:")
405
- print(f"{'─'*80}")
406
- first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
407
- first_generated_token = self.processor.decode(first_generated_id)
408
- raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
409
-
410
- print(f" Token: '{first_generated_token}' | ID={first_generated_id}")
411
- print(f" Raw logit: {raw_logit_for_generated:.4f}")
412
-
413
- processed_logit_for_generated = scores_first[0, first_generated_id].item()
414
- print(f" Post-processed logit: {processed_logit_for_generated:.4f}")
415
-
416
- # Check if this token is in top-k of raw logits
417
- top_k_raw_indices = torch.topk(
418
- raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
419
- ).indices[0]
420
- is_in_top10_raw = first_generated_id in top_k_raw_indices
421
- print(f" In top-10 of RAW logits: {is_in_top10_raw}")
422
-
423
- if not is_in_top10_raw:
424
- print(
425
- f"\n 🚨 CRITICAL: Generated token was NOT in top-10 of raw logits!"
426
- )
427
- print(
428
- f" This proves that logits processors modified the distribution."
429
- )
430
- # Find the rank of the generated token in raw logits
431
- sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
432
- raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[
433
- 0
434
- ].item() + 1
435
- print(f" Raw logits rank: {raw_rank}")
436
-
437
- print("=" * 80 + "\n")
438
-
439
- # Trim the prompt tokens from generated sequences
440
- generated_ids_trimmed = [
441
- out_ids[len(in_ids) :]
442
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
443
- ]
444
-
445
- # Decode the text
446
- output_response = self.processor.batch_decode(
447
- generated_ids_trimmed,
448
- skip_special_tokens=True,
449
- clean_up_tokenization_spaces=False,
450
- )[0]
451
-
452
- # Convert scores to probabilities
453
- # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
454
- selected_token_probs = []
455
- selected_token_logits = []
456
- first_token_probs = torch.softmax(scores[0], dim=-1)
457
-
458
- # Now, find indices of tokens in token_choices and get their probabilities
459
- for token_choice in token_choices:
460
- # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
461
- token_index = self.processor.tokenizer.encode(
462
- token_choice, add_special_tokens=False
463
- )[0]
464
- selected_token_probs.append(first_token_probs[0, token_index].item())
465
- selected_token_logits.append(scores[0][0, token_index].item())
466
-
467
- # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
468
- if return_confidence:
469
- first_token_id = generated_ids_trimmed[0][
470
- 0
471
- ].item() # First token of the first sequence
472
- confidence = (
473
- first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
474
- if sum(selected_token_probs) > 0
475
- else 0.0
476
- )
477
- return {
478
- "response": output_response,
479
- "confidence": confidence,
480
- }
481
-
482
- # Return token logits
483
- else:
484
- token_logits = dict(zip(token_choices, selected_token_logits))
485
- top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
486
- top_k_tokens_list: List[Tuple[str, int, float]] = []
487
- for i in range(top_k_tokens):
488
- logit_index = top_k_logits_indices.indices[0, i].item()
489
- token = self.processor.decode(logit_index)
490
- logit = top_k_logits_indices.values[0, i].item()
491
- top_k_tokens_list.append((token, logit_index, logit))
492
- return {
493
- "response": output_response,
494
- "top_k_tokens": top_k_tokens_list,
495
- "token_logits": token_logits,
496
- }
497
-
498
-
499
- if __name__ == "__main__":
500
- model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
501
- model = Qwen3VLModel(model_path)
502
- prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
503
- token_choices = ["Yes", "No"]
504
- video_path = (
505
- "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
506
- )
507
-
508
- generation_config = {
509
- "max_new_tokens": 128,
510
- "do_sample": True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
511
- "temperature": 0.7,
512
- "logits_temperature": 1.0,
513
- "fps": 1.0,
514
- "return_confidence": False,
515
- "top_k_tokens": 10,
516
- "debug": False,
517
- }
518
- output = model.chat_with_confidence(
519
- prompt, video_path, token_choices=token_choices, **generation_config
520
- )
521
- response = output["response"]
522
- print(f"Response: {response}")
523
-
524
- if generation_config["return_confidence"]:
525
- confidence = output["confidence"]
526
- print(f"Confidence: {confidence}")
527
- else:
528
- # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
529
- # otherwise, the raw logits are used, which are not filtered.
530
- logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
531
- top_k_tokens = output["top_k_tokens"]
532
- for i in range(len(top_k_tokens)):
533
- print(f"Top {i+1} token: {top_k_tokens[i][0]}")
534
- print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
535
- print("--------------------------------")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import torch
4
  from transformers import (
5
+ AutoModelForImageTextToText,
6
  AutoProcessor,
7
+ BitsAndBytesConfig,
8
  )
9
  from typing import Optional, Dict, Any, Union, List, Tuple
 
 
 
 
10
 
11
  # Handle both relative and absolute imports
12
  try:
 
15
  from base import BaseVideoModel
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class Qwen3VLModel(BaseVideoModel):
19
  def __init__(
20
  self,
21
+ model_name: str = "Qwen/Qwen3-VL-4B-Instruct",
22
  dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
23
  device_map: Optional[Union[str, Dict]] = "auto",
24
  attn_implementation: Optional[str] = "flash_attention_2",
25
+ load_8bit: Optional[bool] = False,
26
+ load_4bit: Optional[bool] = False,
27
  ):
28
  super().__init__(model_name)
29
+ quantization_config = None
30
+ if load_8bit or load_4bit:
31
+ quantization_config = BitsAndBytesConfig(
32
+ load_in_8bit=load_8bit,
33
+ load_in_4bit=load_4bit,
34
+ bnb_4bit_quant_type="nf4",
35
+ bnb_4bit_compute_dtype=torch.float16
36
+ )
37
+ self.model = AutoModelForImageTextToText.from_pretrained(
38
  model_name,
39
+ quantization_config=quantization_config,
40
  device_map=device_map,
41
  attn_implementation=attn_implementation,
42
+ dtype=dtype,
43
  )
44
  self.processor = AutoProcessor.from_pretrained(model_name)
45
 
 
47
  self,
48
  prompt: str,
49
  video_path: str,
 
50
  temperature: float = 0.7,
51
  do_sample: Optional[
52
  bool
53
  ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
54
  max_new_tokens: int = 512,
55
  video_mode: Optional[str] = "video", # Choose from "video" or "frames"
56
+ fps: Optional[float] = 1.0,
57
+ num_frames: Optional[int] = 10,
58
  **kwargs: Any,
59
  ) -> str:
60
+ # Ensure only one of fps or num_frames is provided
61
+ if video_mode == "frames":
62
+ fps = None
63
+ elif video_mode == "video":
64
+ num_frames = None
65
+ conversation = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  {
67
  "role": "user",
68
  "content": [
69
  {
70
+ "type": "video",
71
  "video": video_path,
 
 
72
  },
73
+ {"type": "text", "text": prompt}
74
  ],
75
+ },
76
  ]
77
 
78
  inputs = self.processor.apply_chat_template(
79
+ conversation,
80
+ add_generation_prompt=True,
81
+ tokenize=True,
82
+ return_dict=True,
83
  return_tensors="pt",
84
+ do_sample_frames=True,
85
+ fps=fps,
86
+ num_frames=num_frames
87
+ ).to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  with torch.no_grad():
89
+ out = self.model.generate(
90
  **inputs,
91
+ do_sample=do_sample,
92
  temperature=temperature,
93
  max_new_tokens=max_new_tokens,
 
 
 
 
94
  **kwargs,
95
  )
96
+ raw_response = self.processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
97
+ response = raw_response.split("assistant")[1].strip()
98
+ return response
99
+
100
+ # def chat_with_confidence(
101
+ # self,
102
+ # prompt: str,
103
+ # video_path: str,
104
+ # max_new_tokens: int = 512,
105
+ # temperature: float = 0.7,
106
+ # do_sample: Optional[
107
+ # bool
108
+ # ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
109
+ # fps: Optional[float] = 1.0,
110
+ # num_frames: Optional[int] = 10,
111
+ # token_choices: Optional[List[str]] = ["Yes", "No"],
112
+ # logits_temperature: Optional[float] = 1.0,
113
+ # return_confidence: Optional[bool] = False,
114
+ # top_k_tokens: Optional[int] = 10,
115
+ # debug: Optional[bool] = False,
116
+ # **kwargs: Any,
117
+ # ) -> Dict[str, Any]:
118
+ # """
119
+ # Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
120
+
121
+ # Args:
122
+ # prompt (str): The text prompt to generate a response for.
123
+ # video_path (str): The path to the video file.
124
+ # temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
125
+ # max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
126
+ # token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
127
+ # generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
128
+ # return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
129
+ # top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
130
+ # debug (bool, optional): Whether to run in debug mode. Defaults to False.
131
+
132
+ # Returns:
133
+ # Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
134
+
135
+ # e.g., return_confidence: False
136
+ # Output:
137
+ # {
138
+ # "response": "Yes",
139
+ # "top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)],
140
+ # }
141
+
142
+ # e.g., return_confidence: True
143
+ # Output:
144
+ # {
145
+ # "response": "Yes",
146
+ # "confidence": 0.9999
147
+ # }
148
+ # """
149
+ # # Messages containing a local video path and a text query
150
+ # messages = [
151
+ # {
152
+ # "role": "user",
153
+ # "content": [
154
+ # {
155
+ # "type": "video",
156
+ # "video": video_path,
157
+ # # "max_pixels": 360 * 420,
158
+ # "fps": fps,
159
+ # },
160
+ # {"type": "text", "text": prompt},
161
+ # ],
162
+ # }
163
+ # ]
164
+
165
+ # inputs = self.processor.apply_chat_template(
166
+ # messages,
167
+ # tokenize=True,
168
+ # add_generation_prompt=True,
169
+ # return_dict=True,
170
+ # return_tensors="pt",
171
+ # )
172
+
173
+ # inputs = inputs.to(self.model.device)
174
+
175
+ # # In debug mode, inspect what logits processors will be used
176
+ # if debug:
177
+ # print("\n" + "=" * 80)
178
+ # print("INSPECTING GENERATION CONFIG & WARPERS")
179
+ # print("=" * 80)
180
+ # # Get the generation config to see what processors will be added
181
+ # gen_config = self.model.generation_config
182
+ # print(f"Generation config attributes:")
183
+ # print(f" Processor-related:")
184
+ # print(
185
+ # f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
186
+ # )
187
+ # print(
188
+ # f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
189
+ # )
190
+ # print(
191
+ # f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
192
+ # )
193
+ # print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
194
+ # print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
195
+ # print(
196
+ # f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
197
+ # )
198
+ # print(
199
+ # f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
200
+ # )
201
+ # print(f" Warper-related (THESE MASK TOKENS TO -INF):")
202
+ # print(f" - temperature: {temperature} (passed as arg)")
203
+ # print(
204
+ # f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
205
+ # )
206
+ # print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
207
+ # print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
208
+ # print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
209
+ # print(
210
+ # f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
211
+ # )
212
+ # print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
213
+ # print(
214
+ # f"\n ⚠️ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
215
+ # )
216
+ # print("=" * 80 + "\n")
217
+
218
+ # # Inference with scores
219
+ # with torch.no_grad():
220
+ # outputs = self.model.generate(
221
+ # **inputs,
222
+ # temperature=temperature,
223
+ # max_new_tokens=max_new_tokens,
224
+ # do_sample=do_sample,
225
+ # output_scores=True,
226
+ # output_logits=True, # Get TRUE raw logits before any processing
227
+ # return_dict_in_generate=True,
228
+ # **kwargs,
229
+ # )
230
+
231
+ # generated_ids = outputs.sequences
232
+ # scores = outputs.scores # Tuple of tensors - PROCESSED logits used for sampling
233
+ # logits = (
234
+ # outputs.logits if hasattr(outputs, "logits") else None
235
+ # ) # TRUE raw logits from model
236
+
237
+ # scores = tuple(
238
+ # s / logits_temperature for s in scores
239
+ # ) # Scales the logits by a factor for normalization during reporting
240
+
241
+ # print(f"Number of generated tokens: {len(scores)}")
242
+ # print(f"Vocabulary size: {scores[0].shape[1]}")
243
+
244
+ # # Check if logits differ from scores
245
+ # if debug and logits is not None:
246
+ # print(f"\n[IMPORTANT] output_logits available: True")
247
+ # print(
248
+ # f"[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):"
249
+ # )
250
+ # logits_raw = logits[0] / logits_temperature # First token's raw logits
251
+ # scores_first = scores[0] # First token's processed scores
252
+
253
+ # logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
254
+ # max_diff = logits_diff.max().item()
255
+ # if max_diff > 0.001:
256
+ # print(
257
+ # f"[IMPORTANT] ⚠️ outputs.scores ARE DIFFERENT from outputs.logits!"
258
+ # )
259
+ # print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
260
+ # print(
261
+ # f"[IMPORTANT] This means outputs.scores are PROCESSED, not raw!"
262
+ # )
263
+ # else:
264
+ # print(f"[IMPORTANT] βœ“ outputs.scores == outputs.logits (both are raw)")
265
+ # elif debug:
266
+ # print(
267
+ # f"\n[IMPORTANT] output_logits not available in this transformers version"
268
+ # )
269
+
270
+ # # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
271
+ # if debug:
272
+ # print("\n" + "=" * 80)
273
+ # print("****Running inference in debug mode****")
274
+ # print("=" * 80)
275
+
276
+ # # Use truly raw logits if available, otherwise use scores
277
+ # raw_logits_to_show = (
278
+ # logits[0] / logits_temperature if logits is not None else scores[0]
279
+ # )
280
+ # logits_label = (
281
+ # "TRUE RAW LOGITS (from outputs.logits)"
282
+ # if logits is not None
283
+ # else "LOGITS (from outputs.scores)"
284
+ # )
285
+
286
+ # # Print first token scores shape and max/min scores in debug mode
287
+ # print(
288
+ # f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}"
289
+ # )
290
+ # print(
291
+ # f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}"
292
+ # )
293
+
294
+ # # Print details about top 3 tokens from RAW logits
295
+ # print(f"\n{'─'*80}")
296
+ # print(f"TOP 3 TOKENS FROM {logits_label}:")
297
+ # print(f"{'─'*80}")
298
+ # top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
299
+ # for i in range(3):
300
+ # token_id = top_3_tokens.indices[0, i].item()
301
+ # token_text = self.processor.decode(token_id)
302
+ # token_logit = top_3_tokens.values[0, i].item()
303
+ # print(
304
+ # f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
305
+ # )
306
+
307
+ # # Now compare with POST-PROCESSED logits (outputs.scores)
308
+ # scores_first = scores[0] / logits_temperature
309
+ # print(f"\n{'─'*80}")
310
+ # print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
311
+ # print(f"{'─'*80}")
312
+ # print(
313
+ # f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}"
314
+ # )
315
+
316
+ # top_3_processed = torch.topk(scores_first, k=3, dim=-1)
317
+ # for i in range(3):
318
+ # token_id = top_3_processed.indices[0, i].item()
319
+ # token_text = self.processor.decode(token_id)
320
+ # token_logit = top_3_processed.values[0, i].item()
321
+ # print(
322
+ # f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
323
+ # )
324
+
325
+ # # Check if the distributions differ (compare against truly raw logits if available)
326
+ # print(f"\n{'─'*80}")
327
+ # print("DIFFERENCE ANALYSIS (Raw β†’ Post-Processed):")
328
+ # print(f"{'─'*80}")
329
+ # logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
330
+ # max_diff = logit_diff.max().item()
331
+ # num_changed = (logit_diff > 0.001).sum().item()
332
+
333
+ # print(f" Max logit difference: {max_diff:.6f}")
334
+ # print(
335
+ # f" Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}"
336
+ # )
337
+
338
+ # if max_diff > 0.001:
339
+ # print(f"\n ⚠️ LOGITS WERE MODIFIED BY PROCESSORS!")
340
+ # # Show which tokens changed the most
341
+ # top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
342
+ # print(f"\n Top 5 most changed tokens:")
343
+ # for i in range(min(5, len(top_changes.indices))):
344
+ # token_id = top_changes.indices[i].item()
345
+ # token_text = self.processor.decode(token_id)
346
+ # raw_logit = raw_logits_to_show[0, token_id].item()
347
+ # processed_logit = scores_first[0, token_id].item()
348
+ # diff = top_changes.values[i].item()
349
+ # print(f" Token='{token_text}' | ID={token_id}")
350
+ # print(
351
+ # f" Raw: {raw_logit:.4f} β†’ Processed: {processed_logit:.4f} (Ξ”={diff:.4f})"
352
+ # )
353
+ # else:
354
+ # print(f" βœ“ No significant modifications detected")
355
+
356
+ # # Show what token was actually selected
357
+ # print(f"\n{'─'*80}")
358
+ # print("ACTUALLY GENERATED TOKEN:")
359
+ # print(f"{'─'*80}")
360
+ # first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
361
+ # first_generated_token = self.processor.decode(first_generated_id)
362
+ # raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
363
+
364
+ # print(f" Token: '{first_generated_token}' | ID={first_generated_id}")
365
+ # print(f" Raw logit: {raw_logit_for_generated:.4f}")
366
+
367
+ # processed_logit_for_generated = scores_first[0, first_generated_id].item()
368
+ # print(f" Post-processed logit: {processed_logit_for_generated:.4f}")
369
+
370
+ # # Check if this token is in top-k of raw logits
371
+ # top_k_raw_indices = torch.topk(
372
+ # raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
373
+ # ).indices[0]
374
+ # is_in_top10_raw = first_generated_id in top_k_raw_indices
375
+ # print(f" In top-10 of RAW logits: {is_in_top10_raw}")
376
+
377
+ # if not is_in_top10_raw:
378
+ # print(
379
+ # f"\n 🚨 CRITICAL: Generated token was NOT in top-10 of raw logits!"
380
+ # )
381
+ # print(
382
+ # f" This proves that logits processors modified the distribution."
383
+ # )
384
+ # # Find the rank of the generated token in raw logits
385
+ # sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
386
+ # raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[
387
+ # 0
388
+ # ].item() + 1
389
+ # print(f" Raw logits rank: {raw_rank}")
390
+
391
+ # print("=" * 80 + "\n")
392
+
393
+ # # Trim the prompt tokens from generated sequences
394
+ # generated_ids_trimmed = [
395
+ # out_ids[len(in_ids) :]
396
+ # for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
397
+ # ]
398
+
399
+ # # Decode the text
400
+ # output_response = self.processor.batch_decode(
401
+ # generated_ids_trimmed,
402
+ # skip_special_tokens=True,
403
+ # clean_up_tokenization_spaces=False,
404
+ # )[0]
405
+
406
+ # # Convert scores to probabilities
407
+ # # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
408
+ # selected_token_probs = []
409
+ # selected_token_logits = []
410
+ # first_token_probs = torch.softmax(scores[0], dim=-1)
411
+
412
+ # # Now, find indices of tokens in token_choices and get their probabilities
413
+ # for token_choice in token_choices:
414
+ # # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
415
+ # token_index = self.processor.tokenizer.encode(
416
+ # token_choice, add_special_tokens=False
417
+ # )[0]
418
+ # selected_token_probs.append(first_token_probs[0, token_index].item())
419
+ # selected_token_logits.append(scores[0][0, token_index].item())
420
+
421
+ # # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
422
+ # if return_confidence:
423
+ # first_token_id = generated_ids_trimmed[0][
424
+ # 0
425
+ # ].item() # First token of the first sequence
426
+ # confidence = (
427
+ # first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
428
+ # if sum(selected_token_probs) > 0
429
+ # else 0.0
430
+ # )
431
+ # return {
432
+ # "response": output_response,
433
+ # "confidence": confidence,
434
+ # }
435
+
436
+ # # Return token logits
437
+ # else:
438
+ # token_logits = dict(zip(token_choices, selected_token_logits))
439
+ # top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
440
+ # top_k_tokens_list: List[Tuple[str, int, float]] = []
441
+ # for i in range(top_k_tokens):
442
+ # logit_index = top_k_logits_indices.indices[0, i].item()
443
+ # token = self.processor.decode(logit_index)
444
+ # logit = top_k_logits_indices.values[0, i].item()
445
+ # top_k_tokens_list.append((token, logit_index, logit))
446
+ # return {
447
+ # "response": output_response,
448
+ # "top_k_tokens": top_k_tokens_list,
449
+ # "token_logits": token_logits,
450
+ # }
451
+
452
+
453
+ # if __name__ == "__main__":
454
+ # model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
455
+ # model = Qwen3VLModel(model_path)
456
+ # prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
457
+ # token_choices = ["Yes", "No"]
458
+ # video_path = (
459
+ # "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
460
+ # )
461
+
462
+ # generation_config = {
463
+ # "max_new_tokens": 128,
464
+ # "do_sample": True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
465
+ # "temperature": 0.7,
466
+ # "logits_temperature": 1.0,
467
+ # "fps": 1.0,
468
+ # "return_confidence": False,
469
+ # "top_k_tokens": 10,
470
+ # "debug": False,
471
+ # }
472
+ # output = model.chat_with_confidence(
473
+ # prompt, video_path, token_choices=token_choices, **generation_config
474
+ # )
475
+ # response = output["response"]
476
+ # print(f"Response: {response}")
477
+
478
+ # if generation_config["return_confidence"]:
479
+ # confidence = output["confidence"]
480
+ # print(f"Confidence: {confidence}")
481
+ # else:
482
+ # # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
483
+ # # otherwise, the raw logits are used, which are not filtered.
484
+ # logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
485
+ # top_k_tokens = output["top_k_tokens"]
486
+ # for i in range(len(top_k_tokens)):
487
+ # print(f"Top {i+1} token: {top_k_tokens[i][0]}")
488
+ # print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
489
+ # print("--------------------------------")