jena-shreyas committed on
Commit
cf5f08b
·
1 Parent(s): 40a4325

Add correct models/ repo

Browse files
models/.gitkeep ADDED
File without changes
models/__init__.py CHANGED
@@ -3,51 +3,64 @@ from packaging import version
3
  import torch
4
  from typing import Optional, Union, Dict
5
 
6
- # Required versions
7
  qwen_required_version = version.parse("4.57.0")
 
8
  llava_required_version = version.parse("4.40.0")
9
 
10
  # Conditional imports based on transformers version
11
  try:
12
  import transformers
13
- # More robust import path for newer transformers
 
 
14
  transformers_version = version.parse(transformers.__version__)
15
 
16
  QWEN_MODELS_AVAILABLE = False
 
17
  LLAVA_MODELS_AVAILABLE = False
18
 
19
  # Qwen condition
20
  if transformers_version >= qwen_required_version:
21
- from .qwen2_5 import Qwen2_5VLModel
22
  from .qwen3vl import Qwen3VLModel
 
23
  QWEN_MODELS_AVAILABLE = True
24
  else:
25
  print(
26
- f"Warning: Qwen models require transformers>=4.57.0, "
27
- f"but found {transformers.__version__}. "
28
- f"Qwen models will not be available."
 
 
 
 
 
 
 
 
29
  )
30
 
31
  # LLaVA condition
32
  if transformers_version <= llava_required_version:
33
  from .llava_video import LLaVAVideoModel
 
34
  LLAVA_MODELS_AVAILABLE = True
35
  else:
36
  print(
37
- f"Warning: LLaVA models require transformers<=4.40.0, "
38
- f"but found {transformers.__version__}. "
39
- f"LLaVA models will not be available."
40
  )
41
-
42
- except ImportError as e:
43
- print("Warning: Could not import transformers correctly.")
44
- raise e
45
-
46
 
47
  # Build __all__ list dynamically
48
  __all__ = []
49
  if QWEN_MODELS_AVAILABLE:
50
  __all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
 
 
51
  if LLAVA_MODELS_AVAILABLE:
52
  __all__.append("LLaVAVideoModel")
53
 
@@ -59,12 +72,11 @@ def load_model(
59
  device_map: Optional[Union[str, Dict]] = "auto",
60
  attn_implementation: Optional[str] = "flash_attention_2",
61
  ) -> BaseVideoModel:
62
-
63
  if "LLaVA-Video" in model_path:
64
  if not LLAVA_MODELS_AVAILABLE:
65
  raise ImportError(
66
- "LLaVA models require transformers<=4.40.0. "
67
- "Please downgrade transformers."
68
  )
69
  return LLaVAVideoModel(
70
  model_path,
@@ -72,14 +84,12 @@ def load_model(
72
  device_map=device_map,
73
  attn_implementation=attn_implementation,
74
  )
75
-
76
  elif "Qwen" in model_path:
77
  if not QWEN_MODELS_AVAILABLE:
78
  raise ImportError(
79
- "Qwen models require transformers>=4.57.0. "
80
- "Please upgrade transformers."
81
  )
82
-
83
  if "Qwen3" in model_path:
84
  return Qwen3VLModel(
85
  model_path,
@@ -94,7 +104,38 @@ def load_model(
94
  device_map=device_map,
95
  attn_implementation=attn_implementation,
96
  )
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- else:
99
- raise ValueError(f"Unsupported model path: {model_path}")
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch
4
  from typing import Optional, Union, Dict
5
 
6
# IMP: Add required versions here
# Version thresholds compared against the installed `transformers` release.
# Qwen and InternVL need *at least* their version; LLaVA needs *at most* its
# version, so not every family can be available in the same environment.
qwen_required_version = version.parse("4.57.0")
internvl_required_version = version.parse("4.45.0")
llava_required_version = version.parse("4.40.0")

# Availability flags are initialised *before* the try block so that they are
# always defined, even when `import transformers` itself fails. Previously
# they were assigned inside the try block, so an early ImportError led to a
# NameError further down when the flags were read.
QWEN_MODELS_AVAILABLE = False
INTERNVL_MODELS_AVAILABLE = False
LLAVA_MODELS_AVAILABLE = False

# Placeholder base class so that classes deriving from `LogitsProcessor`
# later in this module can still be defined when transformers is missing.
# It is replaced by the real class as soon as the import below succeeds.
LogitsProcessor = object

# Conditional imports based on transformers version
try:
    import transformers
    from transformers.generation.logits_process import LogitsProcessor

    # Check transformers version
    transformers_version = version.parse(transformers.__version__)

    # Qwen condition
    if transformers_version >= qwen_required_version:
        from .qwen2_5vl import Qwen2_5VLModel
        from .qwen3vl import Qwen3VLModel

        QWEN_MODELS_AVAILABLE = True
    else:
        print(
            "Warning: Qwen models require transformers>=4.57.0, "
            f"but found {transformers.__version__}. "
            "Qwen models will not be available. "
            "Please upgrade to transformers>=4.57.0 or switch conda "
            "environments to use Qwen models."
        )

    # InternVL condition
    if transformers_version >= internvl_required_version:
        from .internvl import InternVLModel

        INTERNVL_MODELS_AVAILABLE = True
    else:
        # The gate is a minimum version, so the remedy is an upgrade (the
        # earlier wording told users to downgrade to <=4.45.0).
        print(
            "Warning: InternVL models require transformers>=4.45.0, "
            f"but found {transformers.__version__}. "
            "InternVL models will not be available. "
            "Please upgrade to transformers>=4.45.0 or switch conda "
            "environments to use InternVL models."
        )

    # LLaVA condition
    if transformers_version <= llava_required_version:
        from .llava_video import LLaVAVideoModel

        LLAVA_MODELS_AVAILABLE = True
    else:
        print(
            "Warning: LLaVA models require transformers<=4.40.0, "
            f"but found {transformers.__version__}. "
            "LLaVA models will not be available. "
            "Please downgrade to transformers<=4.40.0 or switch conda "
            "environments to use LLaVA models."
        )
except ImportError:
    # NOTE: this handler also catches ImportErrors raised while importing the
    # model submodules above, not only a missing transformers package. Any
    # family whose flag was not set before the failure stays unavailable.
    print(
        "Warning: Could not check transformers version. Please re-check transformers installation."
    )

# Build __all__ list dynamically
# Only model classes that were actually imported are exported.
__all__ = []
if QWEN_MODELS_AVAILABLE:
    __all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
if INTERNVL_MODELS_AVAILABLE:
    __all__.append("InternVLModel")
if LLAVA_MODELS_AVAILABLE:
    __all__.append("LLaVAVideoModel")
66
 
 
72
  device_map: Optional[Union[str, Dict]] = "auto",
73
  attn_implementation: Optional[str] = "flash_attention_2",
74
  ) -> BaseVideoModel:
 
75
  if "LLaVA-Video" in model_path:
76
  if not LLAVA_MODELS_AVAILABLE:
77
  raise ImportError(
78
+ f"LLaVA models require transformers<=4.40.0."
79
+ f"Please downgrade transformers: pip install transformers<=4.40.0"
80
  )
81
  return LLaVAVideoModel(
82
  model_path,
 
84
  device_map=device_map,
85
  attn_implementation=attn_implementation,
86
  )
 
87
  elif "Qwen" in model_path:
88
  if not QWEN_MODELS_AVAILABLE:
89
  raise ImportError(
90
+ f"Qwen models require transformers>=4.57.0."
91
+ f"Please upgrade transformers: pip install transformers>=4.57.0"
92
  )
 
93
  if "Qwen3" in model_path:
94
  return Qwen3VLModel(
95
  model_path,
 
104
  device_map=device_map,
105
  attn_implementation=attn_implementation,
106
  )
107
+ elif "Intern" in model_path:
108
+ if not INTERNVL_MODELS_AVAILABLE:
109
+ raise ImportError(
110
+ f"InternVL models require transformers>=4.45.0."
111
+ f"Please upgrade transformers: pip install transformers>=4.45.0"
112
+ )
113
+ return InternVLModel(
114
+ model_path,
115
+ dtype=dtype,
116
+ device_map=device_map,
117
+ attn_implementation=attn_implementation,
118
+ )
119
 
 
 
120
 
121
class LogitsCaptureProcessor(LogitsProcessor):
    """
    Pass-through logits processor that records the scores it is called with.

    Each call stores a detached CPU copy of ``scores`` in
    ``captured_logits`` and hands the incoming tensor back untouched, so the
    processor only observes generation and never alters it. What the recorded
    scores represent depends on where this processor is placed relative to
    the other processors in the generation pipeline.
    """

    def __init__(self):
        # One entry is appended per invocation of the processor.
        self.captured_logits = []

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Detach + clone so the stored tensor is independent of the live
        # generation tensors, then move it to CPU.
        snapshot = scores.detach().clone().cpu()
        self.captured_logits.append(snapshot)
        # Observation only: the scores are returned exactly as received.
        return scores

    def reset(self):
        """Drop everything captured so far, ready for a new generation."""
        # Rebinds to a fresh list; a list previously obtained by a caller is
        # left untouched.
        self.captured_logits = []
models/llava_video.py CHANGED
@@ -17,8 +17,7 @@ from PIL import Image
17
  import requests
18
  import copy
19
  import torch
20
- import sys
21
- from typing import Optional, Union, Dict, List, Any
22
  import warnings
23
  from decord import VideoReader, cpu
24
  import numpy as np
@@ -56,7 +55,6 @@ class LLaVAVideoModel(BaseVideoModel):
56
  base_model,
57
  torch_dtype=torch_dtype,
58
  device_map=device_map,
59
- attn_implementation=attn_implementation,
60
  )
61
  ) # Add any other thing you want to pass in llava_model_args
62
  self.model.eval()
@@ -105,10 +103,18 @@ class LLaVAVideoModel(BaseVideoModel):
105
  video_path: str,
106
  fps: float = 1.0,
107
  max_new_tokens: int = 512,
 
 
 
108
  temperature: float = 0.7,
 
 
109
  **kwargs: Any,
110
  ) -> str:
111
- video, _, _ = self.load_video(video_path, fps)
 
 
 
112
  video = self.image_processor.preprocess(video, return_tensors="pt")[
113
  "pixel_values"
114
  ].to(device=self.model.device, dtype=self.dtype)
@@ -132,7 +138,7 @@ class LLaVAVideoModel(BaseVideoModel):
132
  input_ids,
133
  images=video,
134
  modalities=["video"],
135
- do_sample=False,
136
  temperature=temperature,
137
  max_new_tokens=max_new_tokens,
138
  **kwargs,
@@ -149,9 +155,170 @@ class LLaVAVideoModel(BaseVideoModel):
149
  fps: float = 1.0,
150
  max_new_tokens: int = 512,
151
  temperature: float = 0.7,
 
 
 
152
  token_choices: Optional[List[str]] = ["Yes", "No"],
153
  logits_temperature: Optional[float] = 1.0,
154
  return_confidence: Optional[bool] = False,
 
155
  debug: Optional[bool] = False,
156
  ) -> Dict[str, Any]:
157
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  import requests
18
  import copy
19
  import torch
20
+ from typing import Optional, Union, Dict, List, Tuple, Any
 
21
  import warnings
22
  from decord import VideoReader, cpu
23
  import numpy as np
 
55
  base_model,
56
  torch_dtype=torch_dtype,
57
  device_map=device_map,
 
58
  )
59
  ) # Add any other thing you want to pass in llava_model_args
60
  self.model.eval()
 
103
  video_path: str,
104
  fps: float = 1.0,
105
  max_new_tokens: int = 512,
106
+ do_sample: Optional[
107
+ bool
108
+ ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
109
  temperature: float = 0.7,
110
+ video_mode: Optional[str] = "video",
111
+ video_frames: Optional[int] = 10,
112
  **kwargs: Any,
113
  ) -> str:
114
+ if video_mode == "frames":
115
+ video, _, _ = self.load_video(video_path, max_frames_num=video_frames)
116
+ elif video_mode == "video":
117
+ video, _, _ = self.load_video(video_path, fps)
118
  video = self.image_processor.preprocess(video, return_tensors="pt")[
119
  "pixel_values"
120
  ].to(device=self.model.device, dtype=self.dtype)
 
138
  input_ids,
139
  images=video,
140
  modalities=["video"],
141
+ do_sample=do_sample,
142
  temperature=temperature,
143
  max_new_tokens=max_new_tokens,
144
  **kwargs,
 
155
  fps: float = 1.0,
156
  max_new_tokens: int = 512,
157
  temperature: float = 0.7,
158
+ do_sample: Optional[
159
+ bool
160
+ ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
161
  token_choices: Optional[List[str]] = ["Yes", "No"],
162
  logits_temperature: Optional[float] = 1.0,
163
  return_confidence: Optional[bool] = False,
164
+ top_k_tokens: Optional[int] = 10,
165
  debug: Optional[bool] = False,
166
  ) -> Dict[str, Any]:
167
+ video, _, _ = self.load_video(video_path, fps)
168
+ video = self.image_processor.preprocess(video, return_tensors="pt")[
169
+ "pixel_values"
170
+ ].to(device=self.model.device, dtype=self.dtype)
171
+ video = [video]
172
+ conv_template = (
173
+ "qwen_1_5" # Make sure you use correct chat template for different models
174
+ )
175
+ question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
176
+ conv = copy.deepcopy(conv_templates[conv_template])
177
+ conv.append_message(conv.roles[0], question)
178
+ conv.append_message(conv.roles[1], None)
179
+ prompt_question = conv.get_prompt()
180
+ input_ids = (
181
+ tokenizer_image_token(
182
+ prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
183
+ )
184
+ .unsqueeze(0)
185
+ .to(self.model.device)
186
+ )
187
+ with torch.no_grad():
188
+ outputs = self.model.generate(
189
+ input_ids,
190
+ images=video,
191
+ modalities=["video"],
192
+ do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
193
+ temperature=temperature,
194
+ max_new_tokens=max_new_tokens,
195
+ output_scores=True,
196
+ return_dict_in_generate=True,
197
+ )
198
+ generated_ids = outputs.sequences
199
+ scores = outputs.scores # Tuple of tensors, one per generated token
200
+
201
+ print(f"Number of generated tokens: {len(scores)}")
202
+ print(f"Vocabulary size: {scores[0].shape[1]}")
203
+ # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
204
+ if debug:
205
+ print("****Running inference in debug mode****")
206
+ # Print first token scores shape and max/min scores in debug mode
207
+ print(f"Single token scores shape: {scores[0].shape}")
208
+ print(
209
+ f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
210
+ )
211
+
212
+ # Print details about top 10 tokens based on logits
213
+ logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
214
+ print(f"\n{'─'*80}")
215
+ print(
216
+ f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
217
+ )
218
+ print(f"{'─'*80}")
219
+ top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
220
+ for i in range(top_k_tokens):
221
+ score = top_k_tokens_scores.values[0, i].item()
222
+ score_index = top_k_tokens_scores.indices[0, i].item()
223
+ token = self.tokenizer.decode(score_index)
224
+ print(f"#{i+1}th Token: {token}")
225
+ print(f"#{i+1}th Token index: {score_index}")
226
+ print(f"#{i+1}th Token score: {score}")
227
+ print("--------------------------------")
228
+
229
+ # Decode the text
230
+ output_response = self.tokenizer.batch_decode(
231
+ generated_ids,
232
+ skip_special_tokens=True,
233
+ clean_up_tokenization_spaces=False,
234
+ )[0]
235
+
236
+ # Convert scores to probabilities
237
+ # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
238
+ selected_token_probs = []
239
+ selected_token_logits = []
240
+ first_token_probs = torch.softmax(scores[0], dim=-1)
241
+
242
+ # Now, find indices of tokens in token_choices and get their probabilities
243
+ for token_choice in token_choices:
244
+ # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
245
+ token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
246
+ 0
247
+ ]
248
+ selected_token_probs.append(first_token_probs[0, token_index].item())
249
+ selected_token_logits.append(scores[0][0, token_index].item())
250
+
251
+ # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
252
+ if return_confidence:
253
+ first_token_id = generated_ids[0][
254
+ 0
255
+ ].item() # First token of the first sequence
256
+ confidence = (
257
+ first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
258
+ if sum(selected_token_probs) > 0
259
+ else 0.0
260
+ )
261
+ return {
262
+ "response": output_response,
263
+ "confidence": confidence,
264
+ }
265
+
266
+ # Return token logits
267
+ else:
268
+ token_logits = dict(zip(token_choices, selected_token_logits))
269
+ top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
270
+ top_k_tokens_list: List[Tuple[str, int, float]] = []
271
+ for i in range(top_k_tokens):
272
+ logit_index = top_k_logits_indices.indices[0, i].item()
273
+ token = self.tokenizer.decode(logit_index)
274
+ logit = top_k_logits_indices.values[0, i].item()
275
+ top_k_tokens_list.append((token, logit_index, logit))
276
+ return {
277
+ "response": output_response,
278
+ "top_k_tokens": top_k_tokens_list,
279
+ "token_logits": token_logits,
280
+ }
281
+
282
+
283
if __name__ == "__main__":
    # Manual smoke test: load a LLaVA-Video checkpoint, ask a Yes/No question
    # about a single clip and print either the confidence or the top-k logits
    # for the first generated token.
    model_path = "lmms-lab/LLaVA-Video-7B-Qwen2"  # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
    device_map = "cuda:0"
    model = LLaVAVideoModel(model_path, device_map=device_map)

    prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
    token_choices = ["Yes", "No"]
    video_path = (
        "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
    )

    generation_config = {
        "max_new_tokens": 128,
        "do_sample": False,  # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
        "temperature": 0.7,
        "logits_temperature": 1.0,
        "fps": 1.0,
        "return_confidence": False,
        "top_k_tokens": 10,
        "debug": False,
    }

    output = model.chat_with_confidence(
        prompt, video_path, token_choices=token_choices, **generation_config
    )
    response = output["response"]
    print(f"Response: {response}")

    if generation_config["return_confidence"]:
        confidence = output["confidence"]
        print(f"Confidence: {confidence}")
    else:
        # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
        # otherwise, the raw logits are used, which are not filtered.
        if generation_config["do_sample"]:
            logits_type = "POST-PROCESSED"
        else:
            logits_type = "RAW"

        rule = "─" * 80
        print(f"\n{rule}")
        print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
        print(f"{rule}")

        # Each entry is a (token text, token index, logit) triple.
        top_k_tokens = output["top_k_tokens"]
        for rank, entry in enumerate(top_k_tokens, start=1):
            print(f"Top {rank} token: {entry[0]}")
            print(f"Top {rank} token index: {entry[1]}")
            print(f"Top {rank} token logit: {entry[2]}")
            print("--------------------------------")
models/{qwen2_5.py → qwen2_5vl.py} RENAMED
@@ -39,6 +39,8 @@ class Qwen2_5VLModel(BaseVideoModel):
39
  fps: float = 1.0,
40
  temperature: float = 0.7,
41
  max_new_tokens: int = 512,
 
 
42
  ) -> str:
43
  # Messages containing a local video path and a text query
44
  messages = [
@@ -75,8 +77,10 @@ class Qwen2_5VLModel(BaseVideoModel):
75
  # Inference
76
  generated_ids = self.model.generate(
77
  **inputs,
 
78
  temperature=temperature,
79
  max_new_tokens=max_new_tokens,
 
80
  )
81
  generated_ids_trimmed = [
82
  out_ids[len(in_ids) :]
 
39
  fps: float = 1.0,
40
  temperature: float = 0.7,
41
  max_new_tokens: int = 512,
42
+ do_sample: Optional[bool] = True,
43
+ **kwargs: Any,
44
  ) -> str:
45
  # Messages containing a local video path and a text query
46
  messages = [
 
77
  # Inference
78
  generated_ids = self.model.generate(
79
  **inputs,
80
+ do_sample=do_sample,
81
  temperature=temperature,
82
  max_new_tokens=max_new_tokens,
83
+ **kwargs,
84
  )
85
  generated_ids_trimmed = [
86
  out_ids[len(in_ids) :]
models/qwen3vl.py CHANGED
@@ -5,8 +5,11 @@ from transformers import (
5
  Qwen3VLForConditionalGeneration,
6
  AutoProcessor,
7
  )
8
- from typing import Optional, Dict, Any, Union, List
9
  from qwen_vl_utils import process_vision_info
 
 
 
10
 
11
  # Handle both relative and absolute imports
12
  try:
@@ -15,6 +18,36 @@ except ImportError:
15
  from base import BaseVideoModel
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class Qwen3VLModel(BaseVideoModel):
19
  def __init__(
20
  self,
@@ -38,31 +71,55 @@ class Qwen3VLModel(BaseVideoModel):
38
  video_path: str,
39
  fps: float = 1.0,
40
  temperature: float = 0.7,
 
 
 
41
  max_new_tokens: int = 512,
 
 
 
42
  ) -> str:
43
  # Messages containing a local video path and a text query
44
  messages = [
45
  {
46
  "role": "user",
47
  "content": [
48
- {
49
- "type": "video",
50
- "video": video_path,
51
- # "max_pixels": 360 * 420,
52
- "fps": fps,
53
- },
54
  {"type": "text", "text": prompt},
55
  ],
56
  }
57
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- inputs = self.processor.apply_chat_template(
60
- messages,
61
- tokenize=True,
62
- add_generation_prompt=True,
63
- return_dict=True,
64
- return_tensors="pt",
65
- )
 
 
 
 
 
 
 
 
66
 
67
  inputs = inputs.to(self.model.device)
68
 
@@ -70,6 +127,8 @@ class Qwen3VLModel(BaseVideoModel):
70
  **inputs,
71
  max_new_tokens=max_new_tokens,
72
  temperature=temperature,
 
 
73
  )
74
 
75
  generated_ids_trimmed = [
@@ -92,13 +151,18 @@ class Qwen3VLModel(BaseVideoModel):
92
  fps: float = 1.0,
93
  max_new_tokens: int = 512,
94
  temperature: float = 0.7,
 
 
 
95
  token_choices: Optional[List[str]] = ["Yes", "No"],
96
  logits_temperature: Optional[float] = 1.0,
97
  return_confidence: Optional[bool] = False,
 
98
  debug: Optional[bool] = False,
 
99
  ) -> Dict[str, Any]:
100
  """
101
- Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
102
 
103
  Args:
104
  prompt (str): The text prompt to generate a response for.
@@ -108,19 +172,17 @@ class Qwen3VLModel(BaseVideoModel):
108
  token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
109
  generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
110
  return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
 
111
  debug (bool, optional): Whether to run in debug mode. Defaults to False.
112
 
113
  Returns:
114
- Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
115
 
116
  e.g., return_confidence: False
117
  Output:
118
  {
119
  "response": "Yes",
120
- "logits": {
121
- "Yes": 12.0,
122
- "No": 9.0
123
- }
124
  }
125
 
126
  e.g., return_confidence: True
@@ -146,68 +208,233 @@ class Qwen3VLModel(BaseVideoModel):
146
  }
147
  ]
148
 
149
- text = self.processor.apply_chat_template(
150
- messages, tokenize=False, add_generation_prompt=True
151
- )
152
- image_inputs, videos, video_kwargs = process_vision_info(
153
  messages,
154
- image_patch_size=16,
155
- return_video_kwargs=True,
156
- return_video_metadata=True,
157
- )
158
- # Extract out videos and video metadata
159
- if videos is not None:
160
- videos, video_metadatas = zip(*videos)
161
- videos, video_metadatas = list(videos), list(video_metadatas)
162
- else:
163
- video_metadatas = None
164
-
165
- inputs = self.processor(
166
- text=text,
167
- images=image_inputs,
168
- videos=videos,
169
- video_metadata=video_metadatas,
170
  return_tensors="pt",
171
- do_resize=False,
172
- **video_kwargs,
173
  )
 
174
  inputs = inputs.to(self.model.device)
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  # Inference with scores
177
  with torch.no_grad():
178
  outputs = self.model.generate(
179
  **inputs,
180
  temperature=temperature,
181
  max_new_tokens=max_new_tokens,
 
182
  output_scores=True,
 
183
  return_dict_in_generate=True,
 
184
  )
185
 
186
  generated_ids = outputs.sequences
187
- scores = outputs.scores # Tuple of tensors, one per generated token
 
 
 
 
188
  scores = tuple(
189
  s / logits_temperature for s in scores
190
  ) # Scales the logits by a factor for normalization during reporting
191
 
192
  print(f"Number of generated tokens: {len(scores)}")
193
  print(f"Vocabulary size: {scores[0].shape[1]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
195
  if debug:
 
196
  print("****Running inference in debug mode****")
 
 
 
 
 
 
 
 
 
 
 
 
197
  # Print first token scores shape and max/min scores in debug mode
198
- print(f"Single token scores shape: {scores[0].shape}")
199
  print(
200
- f"First token max/min scores: {scores[0].max().item()}, {scores[0].min().item()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  )
202
- # Print details about top 3 tokens
203
- top_3_tokens = torch.topk(scores[0], k=3, dim=-1)
204
  for i in range(3):
 
 
 
205
  print(
206
- f"Pos 0 | {i+1}th Token: {self.processor.decode(top_3_tokens.indices[0, i].item())}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  )
208
  print(
209
- f"Pos 0 | {i+1}th Token logit: {top_3_tokens.values[0, i].item()}"
210
  )
 
 
 
 
 
 
 
 
211
 
212
  # Trim the prompt tokens from generated sequences
213
  generated_ids_trimmed = [
@@ -252,37 +479,41 @@ class Qwen3VLModel(BaseVideoModel):
252
  "confidence": confidence,
253
  }
254
 
255
- # Retrn token logits
256
  else:
257
  token_logits = dict(zip(token_choices, selected_token_logits))
 
 
 
 
 
 
 
258
  return {
259
  "response": output_response,
260
- "logits": token_logits,
 
261
  }
262
 
263
 
264
  if __name__ == "__main__":
265
  model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
266
  model = Qwen3VLModel(model_path)
267
- prompt = "Describe this video."
268
- ext = ".mp4"
269
  video_path = (
270
- "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/1586" + ext
271
  )
272
- response = model.chat(prompt, video_path)
273
- print("Response: ", response)
274
-
275
- token_choices = ["A", "B"]
276
- ext = ".webm"
277
- video_path = "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917" + ext
278
 
279
  generation_config = {
280
  "max_new_tokens": 128,
 
281
  "temperature": 0.7,
282
- "logits_temperature": 5.0,
283
- "fps": 3.0,
284
  "return_confidence": False,
285
- "debug": True,
 
286
  }
287
  output = model.chat_with_confidence(
288
  prompt, video_path, token_choices=token_choices, **generation_config
@@ -294,6 +525,11 @@ if __name__ == "__main__":
294
  confidence = output["confidence"]
295
  print(f"Confidence: {confidence}")
296
  else:
297
- selected_token_logits = output["logits"]
298
- print(f"Selected token logits: {selected_token_logits}")
299
- print(f"Logits temperature: {generation_config['logits_temperature']}")
 
 
 
 
 
 
5
  Qwen3VLForConditionalGeneration,
6
  AutoProcessor,
7
  )
8
+ from typing import Optional, Dict, Any, Union, List, Tuple
9
  from qwen_vl_utils import process_vision_info
10
+ import cv2
11
+ import numpy as np
12
+ from PIL import Image
13
 
14
  # Handle both relative and absolute imports
15
  try:
 
18
  from base import BaseVideoModel
19
 
20
 
21
def downsample_video(video_path, max_dim=720, num_frames=10):
    """Sample evenly spaced frames from a video and return them as RGB PIL images.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        max_dim: Upper bound for the longer side of each returned frame. Frames
            whose longer side exceeds it are shrunk with the aspect ratio
            preserved; smaller frames are returned at their native size.
        num_frames: Maximum number of frames to sample. Fewer frames are
            returned when the video is shorter than this or when individual
            frames fail to decode.

    Returns:
        A list of ``PIL.Image.Image`` objects in RGB channel order. The list is
        empty when ``num_frames`` is 0 or the capture reports no frames
        (for example, when the file could not be opened).

    Raises:
        ValueError: If ``num_frames`` is negative.
    """
    if num_frames < 0:
        raise ValueError(f"num_frames must be non-negative, got {num_frames}")
    if num_frames == 0:
        return []

    frames = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A capture that failed to open reports 0 frames; guard against any
        # non-positive count so np.linspace never sees a negative sample count.
        if total_frames <= 0:
            return frames

        frame_indices = np.linspace(
            0, total_frames - 1, min(total_frames, num_frames), dtype=int
        )

        for frame_index in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
            success, image = vidcap.read()
            if not success:
                # Skip frames that fail to decode instead of aborting.
                continue

            # OpenCV decodes to BGR; PIL expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            h, w = image.shape[:2]
            scale = max_dim / max(h, w)
            if scale < 1:
                # Clamp to 1 px so extreme aspect ratios cannot round to 0.
                new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
                image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)

            frames.append(Image.fromarray(image))
    finally:
        # Always free the decoder handle, even if a conversion step raises.
        vidcap.release()

    return frames
49
+
50
+
51
  class Qwen3VLModel(BaseVideoModel):
52
  def __init__(
53
  self,
 
71
  video_path: str,
72
  fps: float = 1.0,
73
  temperature: float = 0.7,
74
+ do_sample: Optional[
75
+ bool
76
+ ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
77
  max_new_tokens: int = 512,
78
+ video_mode: Optional[str] = "video", # Choose from "video" or "frames"
79
+ video_frames: Optional[int] = 10,
80
+ **kwargs: Any,
81
  ) -> str:
82
  # Messages containing a local video path and a text query
83
  messages = [
84
  {
85
  "role": "user",
86
  "content": [
 
 
 
 
 
 
87
  {"type": "text", "text": prompt},
88
  ],
89
  }
90
  ]
91
+ if video_mode == "video":
92
+ messages[0]["content"].append(
93
+ {
94
+ "type": "video",
95
+ "video": video_path,
96
+ # "max_pixels": 360 * 420,
97
+ "fps": fps,
98
+ }
99
+ )
100
+ inputs = self.processor.apply_chat_template(
101
+ messages,
102
+ tokenize=True,
103
+ add_generation_prompt=True,
104
+ return_dict=True,
105
+ return_tensors="pt",
106
+ )
107
 
108
+ elif video_mode == "frames":
109
+ frames = downsample_video(video_path, max_dim=720, num_frames=video_frames)
110
+ images_for_processor = []
111
+ for frame in frames:
112
+ messages[0]["content"].append({"type": "image"})
113
+ images_for_processor.append(frame)
114
+ prompt_full = self.processor.apply_chat_template(
115
+ messages, tokenize=False, add_generation_prompt=True
116
+ )
117
+ inputs = self.processor(
118
+ text=[prompt_full],
119
+ images=images_for_processor,
120
+ return_tensors="pt",
121
+ padding=True,
122
+ )
123
 
124
  inputs = inputs.to(self.model.device)
125
 
 
127
  **inputs,
128
  max_new_tokens=max_new_tokens,
129
  temperature=temperature,
130
+ do_sample=do_sample,
131
+ **kwargs,
132
  )
133
 
134
  generated_ids_trimmed = [
 
151
  fps: float = 1.0,
152
  max_new_tokens: int = 512,
153
  temperature: float = 0.7,
154
+ do_sample: Optional[
155
+ bool
156
+ ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
157
  token_choices: Optional[List[str]] = ["Yes", "No"],
158
  logits_temperature: Optional[float] = 1.0,
159
  return_confidence: Optional[bool] = False,
160
+ top_k_tokens: Optional[int] = 10,
161
  debug: Optional[bool] = False,
162
+ **kwargs: Any,
163
  ) -> Dict[str, Any]:
164
  """
165
+ Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
166
 
167
  Args:
168
  prompt (str): The text prompt to generate a response for.
 
172
  token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
173
  generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
174
  return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
175
+ top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
176
  debug (bool, optional): Whether to run in debug mode. Defaults to False.
177
 
178
  Returns:
179
+ Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
180
 
181
  e.g., return_confidence: False
182
  Output:
183
  {
184
  "response": "Yes",
185
+ "top_k_tokens": [("Yes", 12, 12.0), ("No", 9, 9.0)],  # (token, token_id, logit)
 
 
 
186
  }
187
 
188
  e.g., return_confidence: True
 
208
  }
209
  ]
210
 
211
+ inputs = self.processor.apply_chat_template(
 
 
 
212
  messages,
213
+ tokenize=True,
214
+ add_generation_prompt=True,
215
+ return_dict=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  return_tensors="pt",
 
 
217
  )
218
+
219
  inputs = inputs.to(self.model.device)
220
 
221
+ # In debug mode, inspect what logits processors will be used
222
+ if debug:
223
+ print("\n" + "=" * 80)
224
+ print("INSPECTING GENERATION CONFIG & WARPERS")
225
+ print("=" * 80)
226
+ # Get the generation config to see what processors will be added
227
+ gen_config = self.model.generation_config
228
+ print(f"Generation config attributes:")
229
+ print(f" Processor-related:")
230
+ print(
231
+ f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
232
+ )
233
+ print(
234
+ f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
235
+ )
236
+ print(
237
+ f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
238
+ )
239
+ print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
240
+ print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
241
+ print(
242
+ f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
243
+ )
244
+ print(
245
+ f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
246
+ )
247
+ print(f" Warper-related (THESE MASK TOKENS TO -INF):")
248
+ print(f" - temperature: {temperature} (passed as arg)")
249
+ print(
250
+ f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
251
+ )
252
+ print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
253
+ print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
254
+ print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
255
+ print(
256
+ f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
257
+ )
258
+ print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
259
+ print(
260
+ f"\n ⚠️ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
261
+ )
262
+ print("=" * 80 + "\n")
263
+
264
  # Inference with scores
265
  with torch.no_grad():
266
  outputs = self.model.generate(
267
  **inputs,
268
  temperature=temperature,
269
  max_new_tokens=max_new_tokens,
270
+ do_sample=do_sample,
271
  output_scores=True,
272
+ output_logits=True, # Get TRUE raw logits before any processing
273
  return_dict_in_generate=True,
274
+ **kwargs,
275
  )
276
 
277
  generated_ids = outputs.sequences
278
+ scores = outputs.scores # Tuple of tensors - PROCESSED logits used for sampling
279
+ logits = (
280
+ outputs.logits if hasattr(outputs, "logits") else None
281
+ ) # TRUE raw logits from model
282
+
283
  scores = tuple(
284
  s / logits_temperature for s in scores
285
  ) # Scales the logits by a factor for normalization during reporting
286
 
287
  print(f"Number of generated tokens: {len(scores)}")
288
  print(f"Vocabulary size: {scores[0].shape[1]}")
289
+
290
+ # Check if logits differ from scores
291
+ if debug and logits is not None:
292
+ print(f"\n[IMPORTANT] output_logits available: True")
293
+ print(
294
+ f"[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):"
295
+ )
296
+ logits_raw = logits[0] / logits_temperature # First token's raw logits
297
+ scores_first = scores[0] # First token's processed scores
298
+
299
+ logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
300
+ max_diff = logits_diff.max().item()
301
+ if max_diff > 0.001:
302
+ print(
303
+ f"[IMPORTANT] ⚠️ outputs.scores ARE DIFFERENT from outputs.logits!"
304
+ )
305
+ print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
306
+ print(
307
+ f"[IMPORTANT] This means outputs.scores are PROCESSED, not raw!"
308
+ )
309
+ else:
310
+ print(f"[IMPORTANT] ✓ outputs.scores == outputs.logits (both are raw)")
311
+ elif debug:
312
+ print(
313
+ f"\n[IMPORTANT] output_logits not available in this transformers version"
314
+ )
315
+
316
  # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
317
  if debug:
318
+ print("\n" + "=" * 80)
319
  print("****Running inference in debug mode****")
320
+ print("=" * 80)
321
+
322
+ # Use truly raw logits if available, otherwise use scores
323
+ raw_logits_to_show = (
324
+ logits[0] / logits_temperature if logits is not None else scores[0]
325
+ )
326
+ logits_label = (
327
+ "TRUE RAW LOGITS (from outputs.logits)"
328
+ if logits is not None
329
+ else "LOGITS (from outputs.scores)"
330
+ )
331
+
332
  # Print first token scores shape and max/min scores in debug mode
 
333
  print(
334
+ f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}"
335
+ )
336
+ print(
337
+ f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}"
338
+ )
339
+
340
+ # Print details about top 3 tokens from RAW logits
341
+ print(f"\n{'─'*80}")
342
+ print(f"TOP 3 TOKENS FROM {logits_label}:")
343
+ print(f"{'─'*80}")
344
+ top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
345
+ for i in range(3):
346
+ token_id = top_3_tokens.indices[0, i].item()
347
+ token_text = self.processor.decode(token_id)
348
+ token_logit = top_3_tokens.values[0, i].item()
349
+ print(
350
+ f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
351
+ )
352
+
353
+ # Now compare with POST-PROCESSED logits (outputs.scores)
354
+ scores_first = scores[0] / logits_temperature
355
+ print(f"\n{'─'*80}")
356
+ print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
357
+ print(f"{'─'*80}")
358
+ print(
359
+ f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}"
360
  )
361
+
362
+ top_3_processed = torch.topk(scores_first, k=3, dim=-1)
363
  for i in range(3):
364
+ token_id = top_3_processed.indices[0, i].item()
365
+ token_text = self.processor.decode(token_id)
366
+ token_logit = top_3_processed.values[0, i].item()
367
  print(
368
+ f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
369
+ )
370
+
371
+ # Check if the distributions differ (compare against truly raw logits if available)
372
+ print(f"\n{'─'*80}")
373
+ print("DIFFERENCE ANALYSIS (Raw → Post-Processed):")
374
+ print(f"{'─'*80}")
375
+ logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
376
+ max_diff = logit_diff.max().item()
377
+ num_changed = (logit_diff > 0.001).sum().item()
378
+
379
+ print(f" Max logit difference: {max_diff:.6f}")
380
+ print(
381
+ f" Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}"
382
+ )
383
+
384
+ if max_diff > 0.001:
385
+ print(f"\n ⚠️ LOGITS WERE MODIFIED BY PROCESSORS!")
386
+ # Show which tokens changed the most
387
+ top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
388
+ print(f"\n Top 5 most changed tokens:")
389
+ for i in range(min(5, len(top_changes.indices))):
390
+ token_id = top_changes.indices[i].item()
391
+ token_text = self.processor.decode(token_id)
392
+ raw_logit = raw_logits_to_show[0, token_id].item()
393
+ processed_logit = scores_first[0, token_id].item()
394
+ diff = top_changes.values[i].item()
395
+ print(f" Token='{token_text}' | ID={token_id}")
396
+ print(
397
+ f" Raw: {raw_logit:.4f} → Processed: {processed_logit:.4f} (Δ={diff:.4f})"
398
+ )
399
+ else:
400
+ print(f" ✓ No significant modifications detected")
401
+
402
+ # Show what token was actually selected
403
+ print(f"\n{'─'*80}")
404
+ print("ACTUALLY GENERATED TOKEN:")
405
+ print(f"{'─'*80}")
406
+ first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
407
+ first_generated_token = self.processor.decode(first_generated_id)
408
+ raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
409
+
410
+ print(f" Token: '{first_generated_token}' | ID={first_generated_id}")
411
+ print(f" Raw logit: {raw_logit_for_generated:.4f}")
412
+
413
+ processed_logit_for_generated = scores_first[0, first_generated_id].item()
414
+ print(f" Post-processed logit: {processed_logit_for_generated:.4f}")
415
+
416
+ # Check if this token is in top-k of raw logits
417
+ top_k_raw_indices = torch.topk(
418
+ raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
419
+ ).indices[0]
420
+ is_in_top10_raw = first_generated_id in top_k_raw_indices
421
+ print(f" In top-10 of RAW logits: {is_in_top10_raw}")
422
+
423
+ if not is_in_top10_raw:
424
+ print(
425
+ f"\n 🚨 CRITICAL: Generated token was NOT in top-10 of raw logits!"
426
  )
427
  print(
428
+ f" This proves that logits processors modified the distribution."
429
  )
430
+ # Find the rank of the generated token in raw logits
431
+ sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
432
+ raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[
433
+ 0
434
+ ].item() + 1
435
+ print(f" Raw logits rank: {raw_rank}")
436
+
437
+ print("=" * 80 + "\n")
438
 
439
  # Trim the prompt tokens from generated sequences
440
  generated_ids_trimmed = [
 
479
  "confidence": confidence,
480
  }
481
 
482
+ # Return token logits
483
  else:
484
  token_logits = dict(zip(token_choices, selected_token_logits))
485
+ top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
486
+ top_k_tokens_list: List[Tuple[str, int, float]] = []
487
+ for i in range(top_k_tokens):
488
+ logit_index = top_k_logits_indices.indices[0, i].item()
489
+ token = self.processor.decode(logit_index)
490
+ logit = top_k_logits_indices.values[0, i].item()
491
+ top_k_tokens_list.append((token, logit_index, logit))
492
  return {
493
  "response": output_response,
494
+ "top_k_tokens": top_k_tokens_list,
495
+ "token_logits": token_logits,
496
  }
497
 
498
 
499
  if __name__ == "__main__":
500
  model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
501
  model = Qwen3VLModel(model_path)
502
+ prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
503
+ token_choices = ["Yes", "No"]
504
  video_path = (
505
+ "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
506
  )
 
 
 
 
 
 
507
 
508
  generation_config = {
509
  "max_new_tokens": 128,
510
+ "do_sample": True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
511
  "temperature": 0.7,
512
+ "logits_temperature": 1.0,
513
+ "fps": 1.0,
514
  "return_confidence": False,
515
+ "top_k_tokens": 10,
516
+ "debug": False,
517
  }
518
  output = model.chat_with_confidence(
519
  prompt, video_path, token_choices=token_choices, **generation_config
 
525
  confidence = output["confidence"]
526
  print(f"Confidence: {confidence}")
527
  else:
528
+ # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
529
+ # otherwise, the raw logits are used, which are not filtered.
530
+ logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
531
+ top_k_tokens = output["top_k_tokens"]
532
+ for i in range(len(top_k_tokens)):
533
+ print(f"Top {i+1} token: {top_k_tokens[i][0]}")
534
+ print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
535
+ print("--------------------------------")