jena-shreyas committed on
Commit
80ceab0
·
1 Parent(s): ae099b5

Initial commit without videos

Browse files
Files changed (10) hide show
  1. .gitattributes +4 -0
  2. .gitignore +3 -0
  3. app.py +414 -0
  4. models/__init__.py +121 -0
  5. models/base.py +27 -0
  6. models/internvl.py +44 -0
  7. models/llava_video.py +154 -0
  8. models/qwen2_5.py +288 -0
  9. models/qwen3vl.py +299 -0
  10. requirements.txt +149 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ *.webm filter=lfs diff=lfs merge=lfs -text
38
+ *.avi filter=lfs diff=lfs merge=lfs -text
39
+ *.mov filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .gradio/
2
+ models/__pycache__/
3
+ SETUP_VIDEO_LFS.md
app.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import sys
from pathlib import Path
import gradio as gr

# Allow importing your models package
sys.path.insert(0, str(Path(__file__).parent))

from models import load_model
from models.base import BaseVideoModel

# ----------------------
# CONFIG
# ----------------------
MODEL_PATH = "lmms-lab/LLaVA-Video-7B-Qwen2"
DEVICE_MAP = "cuda:0"

VIDEO_DIR = str(Path(__file__).parent / "videos")

FPS = 1.0
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.01

# ----------------------
# Load model ONCE (kept in memory for the lifetime of the app)
# ----------------------
print("Loading LLaVa-Video-7B-Qwen2...")
model: BaseVideoModel = load_model(
    MODEL_PATH,
    device_map=DEVICE_MAP,
)
print("Model loaded.")

# ----------------------
# Collect video IDs
# ----------------------
# BUGFIX: the repo can be checked out without a videos/ directory (this
# commit ships no videos), in which case os.listdir() raised
# FileNotFoundError at startup. Fall back to an empty list instead.
if os.path.isdir(VIDEO_DIR):
    VIDEO_IDS = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(VIDEO_DIR)
        if f.endswith(".webm")
    )
else:
    print(f"Warning: video directory not found: {VIDEO_DIR}")
    VIDEO_IDS = []

# ----------------------
# Helpers
# ----------------------
def get_video_path(video_id: str):
    """Return the absolute path of <video_id>.webm, or None if it does not exist."""
    if not video_id:
        return None
    path = os.path.join(VIDEO_DIR, video_id + ".webm")
    return path if os.path.exists(path) else None

# ----------------------
# Inference function
# ----------------------
def video_qa(video_id: str, prompt: str) -> str:
    """Run the loaded model on (video_id, prompt) and return its answer.

    Returns a user-facing error string (prefixed with ❌) instead of raising,
    so the Gradio UI never crashes on bad input or inference failures.
    """
    if not video_id:
        return "❌ Please select a video ID."

    if not prompt.strip():
        return "❌ Please enter a prompt."

    video_path = get_video_path(video_id)
    if video_path is None:
        return f"❌ Video not found: {video_id}.webm"

    try:
        response = model.chat(
            prompt=prompt,
            video_path=video_path,
            fps=FPS,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
        )
        return response

    except Exception as e:
        # Surface the failure in the answer box rather than killing the app.
        return f"❌ Error during inference: {str(e)}"

# ----------------------
# Gradio UI
# ----------------------
with gr.Blocks(title="Video QA – LLaVa-Video-7B-Qwen2") as demo:
    gr.Markdown("## 🎥 Video Question Answering (LLaVa-Video-7B-Qwen2)")

    with gr.Row():
        # LEFT COLUMN
        with gr.Column(scale=1):
            video_id = gr.Dropdown(
                choices=VIDEO_IDS,
                label="Video ID",
                filterable=True,
                interactive=True
            )

            video_player = gr.Video(
                label="Selected Video",
                autoplay=True,
                height=240
            )

        # RIGHT COLUMN
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask a question about the selected video",
                lines=4
            )
            answer = gr.Textbox(
                label="Model Answer",
                lines=8
            )
            run = gr.Button("Run Inference 🚀")

    # Update video player when dropdown changes
    video_id.change(
        fn=get_video_path,
        inputs=video_id,
        outputs=video_player
    )

    # Run inference
    run.click(
        fn=video_qa,
        inputs=[video_id, prompt],
        outputs=answer
    )


demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True
)
136
+
137
+ # #---------------
138
+ # #---------------
139
+ # #---------------
140
+ # # Feb 5, 2026
141
+ # #---------------
142
+ # import os
143
+ # import sys
144
+ # import json
145
+ # from pathlib import Path
146
+ # import gradio as gr
147
+
148
+ # # Allow importing your models package
149
+ # sys.path.insert(0, str(Path(__file__).parent))
150
+
151
+ # from models import load_model
152
+ # from models.base import BaseVideoModel
153
+
154
+ # # ----------------------
155
+ # # CONFIG
156
+ # # ----------------------
157
+ # QWEN_MODEL_PATH = "Qwen/Qwen3-VL-4B-Instruct"
158
+ # LLAVA_MODEL_PATH = "lmms-lab/LLaVA-Video-7B-Qwen2"
159
+ # DEVICE_MAP_QWEN = "cuda:0"
160
+ # DEVICE_MAP_LLAVA = "cuda:0" # Both models on same GPU
161
+
162
+ # VIDEO_DIR = "/home/raman/Gradio_Qwen3vl4bInstruct/videos"
163
+ # LABELS_JSON = "/home/raman/Gradio_Qwen3vl4bInstruct/SSv2_prepost_sampled.json"
164
+
165
+ # DEFAULT_FPS = 1.0
166
+ # MAX_NEW_TOKENS = 512
167
+ # TEMPERATURE = 0.01
168
+
169
+ # # ----------------------
170
+ # # Load video labels
171
+ # # ----------------------
172
+ # print("Loading video labels...")
173
+ # video_labels = {}
174
+ # try:
175
+ # with open(LABELS_JSON, 'r') as f:
176
+ # labels_data = json.load(f)
177
+ # for item in labels_data:
178
+ # video_labels[item['id']] = {
179
+ # 'label': item['label'],
180
+ # 'template': item.get('template', ''),
181
+ # 'action_group': item.get('action_group', '')
182
+ # }
183
+ # print(f"Loaded {len(video_labels)} video labels.")
184
+ # except Exception as e:
185
+ # print(f"Warning: Could not load labels JSON: {e}")
186
+
187
+ # # ----------------------
188
+ # # Load models
189
+ # # ----------------------
190
+ # print("Loading Qwen3-VL-4B-Instruct...")
191
+ # qwen_model: BaseVideoModel = load_model(
192
+ # QWEN_MODEL_PATH,
193
+ # device_map=DEVICE_MAP_QWEN,
194
+ # )
195
+ # print("Qwen model loaded.")
196
+
197
+ # print("Loading LLaVA-Video-7B...")
198
+ # llava_model: BaseVideoModel = load_model(
199
+ # LLAVA_MODEL_PATH,
200
+ # device_map=DEVICE_MAP_LLAVA,
201
+ # )
202
+ # print("LLaVA model loaded.")
203
+
204
+ # # ----------------------
205
+ # # Collect video IDs
206
+ # # ----------------------
207
+ # VIDEO_IDS = sorted([
208
+ # os.path.splitext(f)[0]
209
+ # for f in os.listdir(VIDEO_DIR)
210
+ # if f.endswith(".mp4")
211
+ # ])
212
+
213
+ # print(f"Found {len(VIDEO_IDS)} videos.")
214
+
215
+ # # ----------------------
216
+ # # Helpers
217
+ # # ----------------------
218
+ # def get_video_path(video_id: str):
219
+ # if not video_id:
220
+ # return None
221
+ # path = os.path.join(VIDEO_DIR, video_id + ".mp4")
222
+ # return path if os.path.exists(path) else None
223
+
224
+ # def get_video_label(video_id: str):
225
+ # if not video_id:
226
+ # return ""
227
+ # info = video_labels.get(video_id, {})
228
+ # label = info.get('label', 'No label available')
229
+ # action_group = info.get('action_group', '')
230
+
231
+ # if action_group:
232
+ # return f"**Label:** {label}\n\n**Action Group:** {action_group}"
233
+ # return f"**Label:** {label}"
234
+
235
+ # def update_video_info(video_id: str):
236
+ # """Returns video path and label when video is selected"""
237
+ # video_path = get_video_path(video_id)
238
+ # label = get_video_label(video_id)
239
+ # return video_path, label
240
+
241
+ # # ----------------------
242
+ # # Inference functions
243
+ # # ----------------------
244
+ # def qwen_inference(video_id: str, prompt: str, fps: float) -> str:
245
+ # if not video_id:
246
+ # return "❌ Please select a video ID."
247
+
248
+ # if not prompt.strip():
249
+ # return "❌ Please enter a prompt."
250
+
251
+ # video_path = get_video_path(video_id)
252
+ # if video_path is None:
253
+ # return f"❌ Video not found: {video_id}.mp4"
254
+
255
+ # try:
256
+ # response = qwen_model.chat(
257
+ # prompt=prompt,
258
+ # video_path=video_path,
259
+ # fps=fps,
260
+ # max_new_tokens=MAX_NEW_TOKENS,
261
+ # temperature=TEMPERATURE,
262
+ # )
263
+ # return response
264
+
265
+ # except Exception as e:
266
+ # return f"❌ Error during Qwen inference: {str(e)}"
267
+
268
+ # def llava_inference(video_id: str, prompt: str, fps: float) -> str:
269
+ # if not video_id:
270
+ # return "❌ Please select a video ID."
271
+
272
+ # if not prompt.strip():
273
+ # return "❌ Please enter a prompt."
274
+
275
+ # video_path = get_video_path(video_id)
276
+ # if video_path is None:
277
+ # return f"❌ Video not found: {video_id}.mp4"
278
+
279
+ # try:
280
+ # response = llava_model.chat(
281
+ # prompt=prompt,
282
+ # video_path=video_path,
283
+ # fps=fps,
284
+ # max_new_tokens=MAX_NEW_TOKENS,
285
+ # temperature=TEMPERATURE,
286
+ # )
287
+ # return response
288
+
289
+ # except Exception as e:
290
+ # return f"❌ Error during LLaVA inference: {str(e)}"
291
+
292
+ # # ----------------------
293
+ # # Gradio UI
294
+ # # ----------------------
295
+ # with gr.Blocks(title="Video QA – Qwen3-VL & LLaVA-Video", theme=gr.themes.Soft()) as demo:
296
+ # gr.Markdown("# 🎥 Video Question Answering Demo")
297
+ # gr.Markdown("Compare **Qwen3-VL-4B-Instruct** and **LLaVA-Video-7B-Qwen2** on the same videos")
298
+
299
+ # # TOP SECTION: Video Selection and Display
300
+ # with gr.Row():
301
+ # with gr.Column(scale=1):
302
+ # video_id = gr.Dropdown(
303
+ # choices=VIDEO_IDS,
304
+ # label="📁 Select Video ID",
305
+ # filterable=True,
306
+ # interactive=True,
307
+ # value=VIDEO_IDS[0] if VIDEO_IDS else None
308
+ # )
309
+
310
+ # video_label = gr.Markdown(
311
+ # value=get_video_label(VIDEO_IDS[0]) if VIDEO_IDS else "",
312
+ # label="Video Information"
313
+ # )
314
+
315
+ # fps_slider = gr.Slider(
316
+ # minimum=0.5,
317
+ # maximum=5.0,
318
+ # step=0.5,
319
+ # value=DEFAULT_FPS,
320
+ # label="🎞️ Frames Per Second (FPS)",
321
+ # info="Higher FPS = more frames analyzed (slower but more detailed)"
322
+ # )
323
+
324
+ # with gr.Column(scale=2):
325
+ # video_player = gr.Video(
326
+ # label="Selected Video",
327
+ # autoplay=False,
328
+ # height=360,
329
+ # value=get_video_path(VIDEO_IDS[0]) if VIDEO_IDS else None
330
+ # )
331
+
332
+ # gr.Markdown("---")
333
+
334
+ # # BOTTOM SECTION: Two Models Side by Side
335
+ # with gr.Row():
336
+ # # QWEN COLUMN
337
+ # with gr.Column(scale=1):
338
+ # gr.Markdown("### 🤖 Qwen3-VL-4B-Instruct")
339
+
340
+ # qwen_prompt = gr.Textbox(
341
+ # label="Prompt",
342
+ # placeholder="Ask a question about the video...",
343
+ # lines=4,
344
+ # value="Describe what is happening in this video."
345
+ # )
346
+
347
+ # qwen_answer = gr.Textbox(
348
+ # label="Qwen Answer",
349
+ # lines=10,
350
+ # interactive=False
351
+ # )
352
+
353
+ # qwen_run = gr.Button("🚀 Run Qwen Inference", variant="primary")
354
+
355
+ # # LLAVA COLUMN
356
+ # with gr.Column(scale=1):
357
+ # gr.Markdown("### 🎬 LLaVA-Video-7B-Qwen2")
358
+
359
+ # llava_prompt = gr.Textbox(
360
+ # label="Prompt",
361
+ # placeholder="Ask a question about the video...",
362
+ # lines=4,
363
+ # value="Describe what is happening in this video."
364
+ # )
365
+
366
+ # llava_answer = gr.Textbox(
367
+ # label="LLaVA Answer",
368
+ # lines=10,
369
+ # interactive=False
370
+ # )
371
+
372
+ # llava_run = gr.Button("🚀 Run LLaVA Inference", variant="primary")
373
+
374
+ # # Model info footer
375
+ # gr.Markdown("""
376
+ # ---
377
+ # **Model Information:**
378
+ # - **Qwen3-VL-4B-Instruct**: 4B parameter vision-language model
379
+ # - **LLaVA-Video-7B-Qwen2**: 7B parameter video understanding model
380
+
381
+ # **Settings:** Max Tokens={}, Temperature={}
382
+ # """.format(MAX_NEW_TOKENS, TEMPERATURE))
383
+
384
+ # # ----------------------
385
+ # # Event Handlers
386
+ # # ----------------------
387
+
388
+ # # Update video player and label when dropdown changes
389
+ # video_id.change(
390
+ # fn=update_video_info,
391
+ # inputs=video_id,
392
+ # outputs=[video_player, video_label]
393
+ # )
394
+
395
+ # # Run Qwen inference
396
+ # qwen_run.click(
397
+ # fn=qwen_inference,
398
+ # inputs=[video_id, qwen_prompt, fps_slider],
399
+ # outputs=qwen_answer
400
+ # )
401
+
402
+ # # Run LLaVA inference
403
+ # llava_run.click(
404
+ # fn=llava_inference,
405
+ # inputs=[video_id, llava_prompt, fps_slider],
406
+ # outputs=llava_answer
407
+ # )
408
+
409
+ # # Launch
410
+ # demo.launch(
411
+ # server_name="0.0.0.0",
412
+ # server_port=7860,
413
+ # share=True
414
+ # )
models/__init__.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import Optional, Union, Dict

import torch
from packaging import version

from .base import BaseVideoModel

# Version constraints for the two model families: Qwen needs a recent
# transformers, LLaVA-Video needs an old one, so at most one family is
# importable in a given environment.
qwen_required_version = version.parse("4.57.0")
llava_required_version = version.parse("4.40.0")

# Conditionally import whichever model wrappers the installed transformers
# version can actually support.
try:
    import transformers
    # More robust import path for newer transformers
    from transformers.generation import LogitsProcessor

    transformers_version = version.parse(transformers.__version__)

    QWEN_MODELS_AVAILABLE = False
    LLAVA_MODELS_AVAILABLE = False

    # Qwen wrappers require transformers >= 4.57.0.
    if transformers_version >= qwen_required_version:
        from .qwen2_5 import Qwen2_5VLModel
        from .qwen3vl import Qwen3VLModel

        QWEN_MODELS_AVAILABLE = True
    else:
        print(
            f"Warning: Qwen models require transformers>=4.57.0, "
            f"but found {transformers.__version__}. "
            f"Qwen models will not be available."
        )

    # LLaVA wrappers require transformers <= 4.40.0.
    if transformers_version <= llava_required_version:
        from .llava_video import LLaVAVideoModel

        LLAVA_MODELS_AVAILABLE = True
    else:
        print(
            f"Warning: LLaVA models require transformers<=4.40.0, "
            f"but found {transformers.__version__}. "
            f"LLaVA models will not be available."
        )

except ImportError as e:
    print("Warning: Could not import transformers correctly.")
    raise e


# Advertise only the classes that actually imported.
__all__ = []
if QWEN_MODELS_AVAILABLE:
    __all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
if LLAVA_MODELS_AVAILABLE:
    __all__.append("LLaVAVideoModel")
# Map a HuggingFace model id onto the matching wrapper class.
def load_model(
    model_path: str,
    dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
    device_map: Optional[Union[str, Dict]] = "auto",
    attn_implementation: Optional[str] = "flash_attention_2",
) -> BaseVideoModel:
    """Instantiate the wrapper class that matches *model_path*.

    Raises ImportError when the installed transformers version cannot run
    the requested family, and ValueError for an unrecognised model id.
    """
    # Constructor kwargs shared by every wrapper class.
    common_kwargs = dict(
        dtype=dtype,
        device_map=device_map,
        attn_implementation=attn_implementation,
    )

    if "LLaVA-Video" in model_path:
        if not LLAVA_MODELS_AVAILABLE:
            raise ImportError(
                "LLaVA models require transformers<=4.40.0. "
                "Please downgrade transformers."
            )
        return LLaVAVideoModel(model_path, **common_kwargs)

    if "Qwen" in model_path:
        if not QWEN_MODELS_AVAILABLE:
            raise ImportError(
                "Qwen models require transformers>=4.57.0. "
                "Please upgrade transformers."
            )
        # Qwen3 ids get the Qwen3 wrapper; everything else falls back to 2.5.
        model_cls = Qwen3VLModel if "Qwen3" in model_path else Qwen2_5VLModel
        return model_cls(model_path, **common_kwargs)

    raise ValueError(f"Unsupported model path: {model_path}")
class LogitsCaptureProcessor(LogitsProcessor):
    """LogitsProcessor that records the fully-processed logits right before
    sampling — one CPU tensor per generated token."""

    def __init__(self):
        # One entry per generation step, stored on CPU.
        self.captured_logits = []

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Detach + clone so later in-place edits by other processors cannot
        # mutate what we stored; move to CPU to avoid pinning GPU memory.
        self.captured_logits.append(scores.detach().clone().cpu())
        return scores

    def reset(self):
        """Discard everything captured so far."""
        self.captured_logits = []
models/base.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from abc import ABC, abstractmethod
from typing import Dict, Optional, Union, Any


class BaseVideoModel(ABC):
    """Common interface for the video question-answering model wrappers."""

    def __init__(self, model_name: str):
        # HuggingFace model id; subclasses populate `model` and `processor`.
        self.model_name = model_name
        self.model = None
        self.processor = None

    @abstractmethod
    def chat(
        self,
        prompt: str,
        video_path: str,
        generation_config: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Answer *prompt* about the video at *video_path*."""

    @abstractmethod
    def chat_with_confidence(
        self,
        prompt: str,
        video_path: str,
        generation_config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Union[str, float]]:
        """Like chat(), but also report a confidence score or token logits."""
models/internvl.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import math
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
from typing import Optional, Dict, Any, Union, List
from .base import BaseVideoModel

# ImageNet normalisation constants (standard for InternVL preprocessing).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLModel(BaseVideoModel):
    """Placeholder wrapper for InternVL checkpoints.

    Loads the model/tokenizer, but both chat methods are not implemented yet.
    """

    def __init__(self, model_name: str = "OpenGVLab/InternVL3_5-8B"):
        super().__init__(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def chat(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        # TODO: implement InternVL video chat.
        pass

    def chat_with_confidence(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        token_choices: Optional[List[str]] = None,
        logits_temperature: Optional[float] = 1.0,
        return_confidence: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Dict[str, Any]:
        # BUGFIX: avoid a mutable default argument (was ["Yes", "No"]);
        # the old default is preserved via this None sentinel.
        if token_choices is None:
            token_choices = ["Yes", "No"]
        # TODO: implement InternVL confidence scoring.
        pass
models/llava_video.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Run with `conda activate llava`
from llava.model.builder import load_pretrained_model
from llava.mm_utils import (
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IGNORE_INDEX,
)
from llava.conversation import conv_templates, SeparatorStyle
from PIL import Image
import requests
import copy
import torch
import sys
from typing import Optional, Union, Dict, List, Any
import warnings
from decord import VideoReader, cpu
import numpy as np

# Handle both relative and absolute imports
try:
    from .base import BaseVideoModel
except ImportError:
    from base import BaseVideoModel

warnings.filterwarnings("ignore")


class LLaVAVideoModel(BaseVideoModel):
    """Wrapper around lmms-lab LLaVA-Video checkpoints loaded via the
    `llava` package."""

    def __init__(
        self,
        model_name: str = "lmms-lab/LLaVA-Video-7B-Qwen2",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
    ):
        super().__init__(model_name)
        base_model = "llava_qwen"
        self.dtype = dtype
        # LLaVA-Video's loader only accepts torch_dtype as a string, so map
        # the torch dtype to its name. BUGFIX: the original left torch_dtype
        # unbound (NameError) for any dtype other than bfloat16/float16;
        # now strings pass through and anything else fails loudly.
        if dtype == torch.bfloat16:
            torch_dtype = "bfloat16"
        elif dtype == torch.float16:
            torch_dtype = "float16"
        elif isinstance(dtype, str):
            torch_dtype = dtype
        else:
            raise ValueError(f"Unsupported dtype for LLaVA-Video: {dtype}")

        # NOTE(review): `attn_implementation` is accepted for interface
        # parity with the Qwen wrappers but is not forwarded to the loader
        # (the original did the same) — confirm whether it should be passed.
        self.tokenizer, self.model, self.image_processor, max_length = (
            load_pretrained_model(
                model_name,
                None,
                base_model,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
        )  # Add any other thing you want to pass in llava_model_args
        self.model.eval()

        # Ensure all model components are on the same device.
        # The vision tower and mm_projector may not land on the correct
        # device when `load_pretrained_model` uses device_map, so move them
        # explicitly to the model's device.
        if hasattr(self.model, "get_vision_tower"):
            vision_tower = self.model.get_vision_tower()
            if vision_tower is not None:
                vision_tower.to(self.model.device)

        if hasattr(self.model, "get_model"):
            model_inner = self.model.get_model()
            if hasattr(model_inner, "mm_projector"):
                model_inner.mm_projector.to(self.model.device)

    def load_video(
        self,
        video_path: str,
        fps: float = 1.0,
        max_frames_num: int = -1,
        force_sample: bool = False,
    ):
        """Decode frames from *video_path* at roughly *fps* frames/second.

        Returns a 3-tuple ``(frames, frame_time, video_time)`` where
        ``frames`` is an (N, H, W, 3) uint8 array, ``frame_time`` is a
        comma-joined string of per-frame timestamps and ``video_time`` is
        the clip duration in seconds.
        """
        if max_frames_num == 0:
            # BUGFIX: the original returned a bare ndarray here while every
            # caller unpacks a 3-tuple; return the same tuple shape instead.
            return np.zeros((1, 336, 336, 3)), "0.00s", 0.0
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        video_time = total_frame_num / vr.get_avg_fps()
        # Number of source frames to skip between samples to hit target fps.
        frame_interval = round(vr.get_avg_fps() / fps)
        frame_idx = [i for i in range(0, len(vr), frame_interval)]
        # NOTE(review): i / frame_interval is the sample ordinal scaled by
        # the target fps, not seconds (the uniform-sampling branch below
        # divides by avg_fps instead) — confirm which unit is intended.
        frame_time = [i / frame_interval for i in frame_idx]
        if (max_frames_num > 0 and len(frame_idx) > max_frames_num) or force_sample:
            # Fall back to uniformly sampling exactly max_frames_num frames.
            # NOTE(review): force_sample=True with max_frames_num <= 0 would
            # make linspace fail, as in the original — confirm callers never
            # do this.
            sample_fps = max_frames_num
            uniform_sampled_frames = np.linspace(
                0, total_frame_num - 1, sample_fps, dtype=int
            )
            frame_idx = uniform_sampled_frames.tolist()
            frame_time = [i / vr.get_avg_fps() for i in frame_idx]
        frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        return spare_frames, frame_time, video_time

    def chat(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Answer *prompt* about the video at *video_path* and return the text."""
        video, _, _ = self.load_video(video_path, fps)
        video = self.image_processor.preprocess(video, return_tensors="pt")[
            "pixel_values"
        ].to(device=self.model.device, dtype=self.dtype)
        video = [video]
        conv_template = (
            "qwen_1_5"  # Make sure you use correct chat template for different models
        )
        question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
        conv = copy.deepcopy(conv_templates[conv_template])
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()
        input_ids = (
            tokenizer_image_token(
                prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(self.model.device)
        )
        # NOTE(review): do_sample=False means greedy decoding, so
        # `temperature` has no effect here — confirm whether sampling was
        # intended.
        cont = self.model.generate(
            input_ids,
            images=video,
            modalities=["video"],
            do_sample=False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[
            0
        ].strip()
        return text_outputs

    def chat_with_confidence(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        token_choices: Optional[List[str]] = None,
        logits_temperature: Optional[float] = 1.0,
        return_confidence: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Dict[str, Any]:
        # BUGFIX: avoid a mutable default argument (was ["Yes", "No"]);
        # the old default is preserved via this None sentinel.
        if token_choices is None:
            token_choices = ["Yes", "No"]
        # TODO: not implemented for LLaVA-Video yet.
        pass
models/qwen2_5.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# This script requires transformers==4.57.0

import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
)
from typing import Optional, Dict, Any, Union, List
from qwen_vl_utils import process_vision_info

# Handle both relative and absolute imports
try:
    from .base import BaseVideoModel
except ImportError:
    from base import BaseVideoModel


class Qwen2_5VLModel(BaseVideoModel):
    """Wrapper around Qwen2.5-VL for video question answering."""

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
    ):
        super().__init__(model_name)
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            dtype=dtype,
            device_map=device_map,
            attn_implementation=attn_implementation,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def _prepare_inputs(self, prompt: str, video_path: str, fps: float):
        # Build processor inputs for one (video, text) user turn and move
        # them to the model device. Shared by chat() and
        # chat_with_confidence(), which previously duplicated this code.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                        # "max_pixels": 360 * 420,
                        "fps": fps,
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs, video_kwargs = process_vision_info(
            messages, return_video_kwargs=True
        )
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
            **video_kwargs,
        )
        return inputs.to(self.model.device)

    def chat(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        temperature: float = 0.7,
        max_new_tokens: int = 512,
    ) -> str:
        """Answer *prompt* about the local video at *video_path*."""
        inputs = self._prepare_inputs(prompt, video_path, fps)

        # Inference
        generated_ids = self.model.generate(
            **inputs,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_response = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return output_response

    def chat_with_confidence(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        token_choices: Optional[List[str]] = None,
        logits_temperature: Optional[float] = 1.0,
        return_confidence: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Dict[str, Any]:
        """
        Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.

        Args:
            prompt (str): The text prompt to generate a response for.
            video_path (str): The path to the video file.
            fps (float, optional): The frames per second of the video. Defaults to 1.0.
            max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
            temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
            logits_temperature (float, optional): The logits temperature to use for generation. Defaults to 1.0.
            token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
            return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
            debug (bool, optional): Whether to run in debug mode. Defaults to False.

        Returns:
            Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.

            e.g., return_confidence: False
            Output:
                {
                    "response": "Yes",
                    "logits": {
                        "Yes": 12.0,
                        "No": 9.0
                    }
                }

            e.g., return_confidence: True
            Output:
                {
                    "response": "Yes",
                    "confidence": 0.9999
                }
        """
        # BUGFIX: avoid a mutable default argument (was ["Yes", "No"]);
        # the old default is preserved via this None sentinel.
        if token_choices is None:
            token_choices = ["Yes", "No"]

        inputs = self._prepare_inputs(prompt, video_path, fps)

        # Inference with scores
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
            )

        generated_ids = outputs.sequences
        scores = outputs.scores  # Tuple of tensors, one per generated token
        scores = tuple(
            s / logits_temperature for s in scores
        )  # Scales the logits by a factor for normalization during reporting

        if debug:
            # BUGFIX: these diagnostics printed unconditionally before;
            # they are now gated behind the debug flag.
            print("****Running inference in debug mode****")
            print(f"Number of generated tokens: {len(scores)}")
            print(f"Vocabulary size: {scores[0].shape[1]}")
            # Print first token scores shape and max/min scores in debug mode
            print(f"Single token scores shape: {scores[0].shape}")
            print(
                f"First token max/min scores: {scores[0].max().item()}, {scores[0].min().item()}"
            )
            # Print details about top 3 tokens at the first position
            top_3_tokens = torch.topk(scores[0], k=3, dim=-1)
            for i in range(3):
                print(
                    f"Pos 0 | {i+1}th Token: {self.processor.decode(top_3_tokens.indices[0, i].item())}"
                )
                print(
                    f"Pos 0 | {i+1}th Token logit: {top_3_tokens.values[0, i].item()}"
                )

        # Trim the prompt tokens from generated sequences
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the text
        output_response = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Convert first-position scores to probabilities.
        # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
        selected_token_probs = []
        selected_token_logits = []
        first_token_probs = torch.softmax(scores[0], dim=-1)

        # Look up the probability/logit of each candidate answer token.
        for token_choice in token_choices:
            # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
            token_index = self.processor.tokenizer.encode(
                token_choice, add_special_tokens=False
            )[0]
            selected_token_probs.append(first_token_probs[0, token_index].item())
            selected_token_logits.append(scores[0][0, token_index].item())

        # Compute confidence as the ratio of the first generated token's
        # probability to the total probability mass of the candidate tokens.
        if return_confidence:
            first_token_id = generated_ids_trimmed[0][
                0
            ].item()  # First token of the first sequence
            confidence = (
                first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
                if sum(selected_token_probs) > 0
                else 0.0
            )
            return {
                "response": output_response,
                "confidence": confidence,
            }

        # Return token logits
        else:
            token_logits = dict(zip(token_choices, selected_token_logits))
            return {
                "response": output_response,
                "logits": token_logits,
            }
if __name__ == "__main__":
    # Smoke test: ask a binary-choice question about a local SSv2 clip and
    # inspect the logits (or confidence) of the answer tokens.
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"  # "Qwen/Qwen2.5-VL-7B-Instruct"
    model = Qwen2_5VLModel(model_path)

    prompt = (
        "Which of the following exist in the video? Answer in A or B.\nA: Hand\nB: Face"
    )
    token_choices = ["A", "B"]
    video_path = (
        "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917" + ".webm"
    )

    generation_config = dict(
        max_new_tokens=128,
        temperature=0.7,
        logits_temperature=5.0,
        fps=3.0,
        return_confidence=False,
        debug=True,
    )
    output = model.chat_with_confidence(
        prompt, video_path, token_choices=token_choices, **generation_config
    )
    response = output["response"]
    print(f"Response: {response}")

    if generation_config["return_confidence"]:
        confidence = output["confidence"]
        print(f"Confidence: {confidence}")
    else:
        selected_token_logits = output["logits"]
        print(f"Selected token logits: {selected_token_logits}")
        print(f"Logits temperature: {generation_config['logits_temperature']}")
models/qwen3vl.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script requires transformers>=4.57.0 (NOTE: requirements.txt pins transformers==5.1.0 — verify compatibility)
2
+
3
+ import torch
4
+ from transformers import (
5
+ Qwen3VLForConditionalGeneration,
6
+ AutoProcessor,
7
+ )
8
+ from typing import Optional, Dict, Any, Union, List
9
+ from qwen_vl_utils import process_vision_info
10
+
11
+ # Handle both relative and absolute imports
12
+ try:
13
+ from .base import BaseVideoModel
14
+ except ImportError:
15
+ from base import BaseVideoModel
16
+
17
+
18
class Qwen3VLModel(BaseVideoModel):
    """Qwen3-VL wrapper for video question answering.

    Provides plain free-form chat (``chat``) plus a scoring variant
    (``chat_with_confidence``) that inspects the logits of the *first*
    generated token for a fixed set of candidate answer tokens.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
        dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
        device_map: Optional[Union[str, Dict]] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
    ):
        """Load the Qwen3-VL checkpoint and its processor.

        Args:
            model_name: HF hub id or local path of a Qwen3-VL checkpoint.
            dtype: Torch dtype for the weights (bfloat16 by default).
            device_map: Device placement passed to ``from_pretrained``.
            attn_implementation: Attention backend (flash-attn 2 by default).
        """
        super().__init__(model_name)
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            model_name,
            dtype=dtype,
            device_map=device_map,
            attn_implementation=attn_implementation,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def chat(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        temperature: float = 0.7,
        max_new_tokens: int = 512,
    ) -> str:
        """Generate a free-form answer to ``prompt`` grounded in a local video.

        Args:
            prompt: Text query about the video.
            video_path: Path to a local video file.
            fps: Frame-sampling rate forwarded to the processor.
            temperature: Generation temperature.
            max_new_tokens: Cap on newly generated tokens.

        Returns:
            The decoded model response with the prompt tokens stripped.
        """
        # Single-turn user message: one video plus one text query.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                        # "max_pixels": 360 * 420,
                        "fps": fps,
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # NOTE(review): ``do_sample`` is not set, so HF generation defaults to
        # greedy decoding and ``temperature`` is likely ignored — confirm intent.
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        return self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

    def chat_with_confidence(
        self,
        prompt: str,
        video_path: str,
        fps: float = 1.0,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        token_choices: Optional[List[str]] = None,
        logits_temperature: Optional[float] = 1.0,
        return_confidence: Optional[bool] = False,
        debug: Optional[bool] = False,
    ) -> Dict[str, Any]:
        """
        Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.

        Args:
            prompt (str): The text prompt to generate a response for.
            video_path (str): The path to the video file.
            fps (float, optional): Frame-sampling rate forwarded to the processor. Defaults to 1.0.
            temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
            max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
            token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"] when None.
            logits_temperature (float, optional): Divisor applied to reported logits/probabilities. Defaults to 1.0.
            return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
            debug (bool, optional): Whether to print diagnostic information. Defaults to False.

        Returns:
            Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.

            e.g., return_confidence: False
            Output:
            {
                "response": "Yes",
                "logits": {
                    "Yes": 12.0,
                    "No": 9.0
                }
            }

            e.g., return_confidence: True
            Output:
            {
                "response": "Yes",
                "confidence": 0.9999
            }
        """
        # Avoid a shared mutable default argument; None selects the
        # historical default choices.
        if token_choices is None:
            token_choices = ["Yes", "No"]

        # Messages containing a local video path and a text query
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                        # "max_pixels": 360 * 420,
                        "fps": fps,
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, videos, video_kwargs = process_vision_info(
            messages,
            image_patch_size=16,
            return_video_kwargs=True,
            return_video_metadata=True,
        )
        # process_vision_info returns (frames, metadata) pairs; split them so
        # the processor can receive each list separately.
        if videos is not None:
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)
        else:
            video_metadatas = None

        inputs = self.processor(
            text=text,
            images=image_inputs,
            videos=videos,
            video_metadata=video_metadatas,
            return_tensors="pt",
            do_resize=False,
            **video_kwargs,
        )
        inputs = inputs.to(self.model.device)

        # Inference with per-step scores so first-token logits can be inspected.
        # NOTE(review): ``do_sample`` is not set, so ``temperature`` is likely
        # ignored (greedy decoding) — confirm intent.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
            )

        generated_ids = outputs.sequences
        scores = outputs.scores  # Tuple of tensors, one per generated token
        scores = tuple(
            s / logits_temperature for s in scores
        )  # Scales the logits by a factor for normalization during reporting

        if debug:
            print("****Running inference in debug mode****")
            # Diagnostics moved under the debug flag so library callers are not
            # spammed on every invocation.
            print(f"Number of generated tokens: {len(scores)}")
            print(f"Vocabulary size: {scores[0].shape[1]}")
            # Print first token scores shape and max/min scores in debug mode
            print(f"Single token scores shape: {scores[0].shape}")
            print(
                f"First token max/min scores: {scores[0].max().item()}, {scores[0].min().item()}"
            )
            # Print details about top 3 tokens at the first generated position
            top_3_tokens = torch.topk(scores[0], k=3, dim=-1)
            for i in range(3):
                print(
                    f"Pos 0 | {i+1}th Token: {self.processor.decode(top_3_tokens.indices[0, i].item())}"
                )
                print(
                    f"Pos 0 | {i+1}th Token logit: {top_3_tokens.values[0, i].item()}"
                )

        # Trim the prompt tokens from generated sequences
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the text
        output_response = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Convert scores to probabilities
        # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
        selected_token_probs = []
        selected_token_logits = []
        first_token_probs = torch.softmax(scores[0], dim=-1)

        # Now, find indices of tokens in token_choices and get their probabilities
        for token_choice in token_choices:
            # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
            token_index = self.processor.tokenizer.encode(
                token_choice, add_special_tokens=False
            )[0]
            selected_token_probs.append(first_token_probs[0, token_index].item())
            selected_token_logits.append(scores[0][0, token_index].item())

        # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
        if return_confidence:
            first_token_id = generated_ids_trimmed[0][
                0
            ].item()  # First token of the first sequence
            # NOTE(review): if the generated first token is not one of
            # token_choices, this ratio can exceed 1.0 — confirm intended.
            confidence = (
                first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
                if sum(selected_token_probs) > 0
                else 0.0
            )
            return {
                "response": output_response,
                "confidence": confidence,
            }

        # Return token logits
        else:
            token_logits = dict(zip(token_choices, selected_token_logits))
            return {
                "response": output_response,
                "logits": token_logits,
            }
262
+
263
+
264
if __name__ == "__main__":
    # Smoke test: free-form description, then choice-token logit scoring.
    model = Qwen3VLModel(
        "Qwen/Qwen3-VL-4B-Instruct"  # alternatives: "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
    )
    prompt = "Describe this video."
    video_path = (
        "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/1586.mp4"
    )
    print("Response: ", model.chat(prompt, video_path))

    # Second clip: score "A"/"B" answer tokens with debug diagnostics enabled.
    video_path = "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917.webm"
    gen_cfg = {
        "max_new_tokens": 128,
        "temperature": 0.7,
        "logits_temperature": 5.0,
        "fps": 3.0,
        "return_confidence": False,
        "debug": True,
    }
    output = model.chat_with_confidence(
        prompt, video_path, token_choices=["A", "B"], **gen_cfg
    )
    print(f"Response: {output['response']}")

    if gen_cfg["return_confidence"]:
        print(f"Confidence: {output['confidence']}")
    else:
        print(f"Selected token logits: {output['logits']}")
        print(f"Logits temperature: {gen_cfg['logits_temperature']}")
requirements.txt ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.33.0
2
+ aiohappyeyeballs==2.4.0
3
+ aiohttp==3.10.5
4
+ aiosignal==1.3.1
5
+ anyio==4.12.1
6
+ asttokens==3.0.1
7
+ async-timeout==4.0.3
8
+ attrs==24.2.0
9
+ av==12.3.0
10
+ beautifulsoup4==4.14.3
11
+ bitsandbytes==0.41.0
12
+ black==25.12.0
13
+ cachetools==6.2.4
14
+ certifi==2024.8.30
15
+ cfgv==3.5.0
16
+ charset-normalizer==3.3.2
17
+ click==8.1.7
18
+ contourpy==1.3.2
19
+ cuda-bindings==12.9.4
20
+ cuda-pathfinder==1.3.3
21
+ cycler==0.12.1
22
+ datasets==2.16.1
23
+ decorator==5.2.1
24
+ decord==0.6.0
25
+ deepspeed==0.14.2
26
+ dill==0.3.7
27
+ distlib==0.4.0
28
+ distro==1.9.0
29
+ docker-pycreds==0.4.0
30
+ docstring_parser==0.16
31
+ einops==0.6.1
32
+ einops-exts==0.0.4
33
+ exceptiongroup==1.3.1
34
+ executing==2.2.1
35
+ filelock==3.20.3
36
+ flash-attn==2.5.7
37
+ fonttools==4.61.1
38
+ frozenlist==1.4.1
39
+ fsspec==2023.10.0
40
+ ftfy==6.2.3
41
+ gdown==5.2.1
42
+ gitdb==4.0.11
43
+ GitPython==3.1.43
44
+ gradio==6.2.0
45
+ gradio_client==2.0.2
46
+ h11==0.16.0
47
+ hf-xet==1.2.0
48
+ hf_transfer==0.1.8
49
+ hjson==3.1.0
50
+ httpcore==1.0.9
51
+ httpx==0.28.1
52
+ huggingface_hub==1.4.1
53
+ identify==2.6.16
54
+ ipython==8.38.0
55
+ jedi==0.19.2
56
+ Jinja2==3.1.4
57
+ jiter==0.6.1
58
+ joblib==1.5.3
59
+ kiwisolver==1.4.9
60
+ latex2mathml==3.77.0
61
+ llava @ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git@e9835311c6f515a13702eb7a7750fcd936f65ed8
62
+ markdown-it-py==3.0.0
63
+ markdown2==2.5.0
64
+ MarkupSafe==2.1.5
65
+ matplotlib==3.10.8
66
+ matplotlib-inline==0.2.1
67
+ mpmath==1.3.0
68
+ multidict==6.0.5
69
+ multiprocess==0.70.15
70
+ mypy_extensions==1.1.0
71
+ networkx==3.4.2
72
+ ninja==1.11.1.1
73
+ nltk==3.9.2
74
+ nodeenv==1.10.0
75
+ numpy==1.26.4
76
+ open_clip_torch==2.26.1
77
+ openai==1.52.2
78
+ opencv-python==4.10.0.84
79
+ packaging==26.0
80
+ pandas==2.3.3
81
+ parso==0.8.5
82
+ pathspec==1.0.3
83
+ peft==0.4.0
84
+ pexpect==4.9.0
85
+ pillow==12.1.0
86
+ platformdirs==4.2.2
87
+ pre_commit==4.5.1
88
+ prompt_toolkit==3.0.52
89
+ protobuf==5.28.0
90
+ psutil==7.2.1
91
+ ptyprocess==0.7.0
92
+ pure_eval==0.2.3
93
+ py-cpuinfo==9.0.0
94
+ pyarrow==17.0.0
95
+ pyarrow-hotfix==0.6
96
+ pydantic_core==2.41.5
97
+ Pygments==2.18.0
98
+ pynvml==13.0.1
99
+ pyparsing==3.3.2
100
+ PySocks==1.7.1
101
+ python-dateutil==2.9.0.post0
102
+ pytokens==0.3.0
103
+ pytz==2024.1
104
+ PyYAML
105
+ regex==2026.1.15
106
+ requests==2.32.3
107
+ rich==13.8.0
108
+ safetensors==0.7.0
109
+ scikit-learn==1.7.2
110
+ scipy==1.15.3
111
+ seaborn==0.13.2
112
+ sentence-transformers==5.2.2
113
+ sentencepiece==0.1.99
114
+ sentry-sdk==2.13.0
115
+ setproctitle==1.3.3
116
+ shellingham==1.5.4
117
+ shortuuid==1.0.13
118
+ shtab==1.7.1
119
+ six==1.16.0
120
+ smmap==5.0.1
121
+ soupsieve==2.8.3
122
+ stack-data==0.6.3
123
+ svgwrite==1.4.3
124
+ sympy==1.14.0
125
+ termcolor==3.3.0
126
+ threadpoolctl==3.6.0
127
+ timm==1.0.9
128
+ tokenizers==0.22.2
129
+ tomli==2.4.0
130
+ torch==2.2.1
131
+ torchvision==0.17.1
132
+ tqdm==4.67.3
133
+ traitlets==5.14.3
134
+ transformers==5.1.0
135
+ triton==2.2.0
136
+ typer==0.20.0
137
+ typer-slim==0.21.1
138
+ typing_extensions==4.15.0
139
+ tyro==0.8.10
140
+ tzdata==2025.3
141
+ urllib3==1.26.20
142
+ uvicorn==0.30.6
143
+ virtualenv==20.36.1
144
+ wandb==0.17.8
145
+ wavedrom==2.0.3.post3
146
+ wcwidth==0.2.13
147
+ websockets==13.0.1
148
+ xxhash==3.5.0
149
+ yarl==1.9.7