Spaces:
Sleeping
Sleeping
Commit
Β·
cf5f08b
1
Parent(s):
40a4325
Add correct models/ repo
Browse files- models/.gitkeep +0 -0
- models/__init__.py +64 -23
- models/llava_video.py +173 -6
- models/{qwen2_5.py β qwen2_5vl.py} +4 -0
- models/qwen3vl.py +302 -66
models/.gitkeep
ADDED
|
File without changes
|
models/__init__.py
CHANGED
|
@@ -3,51 +3,64 @@ from packaging import version
|
|
| 3 |
import torch
|
| 4 |
from typing import Optional, Union, Dict
|
| 5 |
|
| 6 |
-
#
|
| 7 |
qwen_required_version = version.parse("4.57.0")
|
|
|
|
| 8 |
llava_required_version = version.parse("4.40.0")
|
| 9 |
|
| 10 |
# Conditional imports based on transformers version
|
| 11 |
try:
|
| 12 |
import transformers
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
transformers_version = version.parse(transformers.__version__)
|
| 15 |
|
| 16 |
QWEN_MODELS_AVAILABLE = False
|
|
|
|
| 17 |
LLAVA_MODELS_AVAILABLE = False
|
| 18 |
|
| 19 |
# Qwen condition
|
| 20 |
if transformers_version >= qwen_required_version:
|
| 21 |
-
from .
|
| 22 |
from .qwen3vl import Qwen3VLModel
|
|
|
|
| 23 |
QWEN_MODELS_AVAILABLE = True
|
| 24 |
else:
|
| 25 |
print(
|
| 26 |
-
f"Warning: Qwen models require transformers>=4.57.0, "
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
# LLaVA condition
|
| 32 |
if transformers_version <= llava_required_version:
|
| 33 |
from .llava_video import LLaVAVideoModel
|
|
|
|
| 34 |
LLAVA_MODELS_AVAILABLE = True
|
| 35 |
else:
|
| 36 |
print(
|
| 37 |
-
f"Warning: LLaVA models require transformers<=4.40.0, "
|
| 38 |
-
f"but found {transformers.__version__}. "
|
| 39 |
-
f"LLaVA models will not be available."
|
| 40 |
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
|
| 47 |
# Build __all__ list dynamically
|
| 48 |
__all__ = []
|
| 49 |
if QWEN_MODELS_AVAILABLE:
|
| 50 |
__all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
|
|
|
|
|
|
|
| 51 |
if LLAVA_MODELS_AVAILABLE:
|
| 52 |
__all__.append("LLaVAVideoModel")
|
| 53 |
|
|
@@ -59,12 +72,11 @@ def load_model(
|
|
| 59 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 60 |
attn_implementation: Optional[str] = "flash_attention_2",
|
| 61 |
) -> BaseVideoModel:
|
| 62 |
-
|
| 63 |
if "LLaVA-Video" in model_path:
|
| 64 |
if not LLAVA_MODELS_AVAILABLE:
|
| 65 |
raise ImportError(
|
| 66 |
-
"LLaVA models require transformers<=4.40.0.
|
| 67 |
-
"Please downgrade transformers."
|
| 68 |
)
|
| 69 |
return LLaVAVideoModel(
|
| 70 |
model_path,
|
|
@@ -72,14 +84,12 @@ def load_model(
|
|
| 72 |
device_map=device_map,
|
| 73 |
attn_implementation=attn_implementation,
|
| 74 |
)
|
| 75 |
-
|
| 76 |
elif "Qwen" in model_path:
|
| 77 |
if not QWEN_MODELS_AVAILABLE:
|
| 78 |
raise ImportError(
|
| 79 |
-
"Qwen models require transformers>=4.57.0.
|
| 80 |
-
"Please upgrade transformers."
|
| 81 |
)
|
| 82 |
-
|
| 83 |
if "Qwen3" in model_path:
|
| 84 |
return Qwen3VLModel(
|
| 85 |
model_path,
|
|
@@ -94,7 +104,38 @@ def load_model(
|
|
| 94 |
device_map=device_map,
|
| 95 |
attn_implementation=attn_implementation,
|
| 96 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
else:
|
| 99 |
-
raise ValueError(f"Unsupported model path: {model_path}")
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
from typing import Optional, Union, Dict
|
| 5 |
|
| 6 |
+
# IMP: Add required versions here
|
| 7 |
qwen_required_version = version.parse("4.57.0")
|
| 8 |
+
internvl_required_version = version.parse("4.45.0")
|
| 9 |
llava_required_version = version.parse("4.40.0")
|
| 10 |
|
| 11 |
# Conditional imports based on transformers version
|
| 12 |
try:
|
| 13 |
import transformers
|
| 14 |
+
from transformers.generation.logits_process import LogitsProcessor
|
| 15 |
+
|
| 16 |
+
# Check transformers version
|
| 17 |
transformers_version = version.parse(transformers.__version__)
|
| 18 |
|
| 19 |
QWEN_MODELS_AVAILABLE = False
|
| 20 |
+
INTERNVL_MODELS_AVAILABLE = False
|
| 21 |
LLAVA_MODELS_AVAILABLE = False
|
| 22 |
|
| 23 |
# Qwen condition
|
| 24 |
if transformers_version >= qwen_required_version:
|
| 25 |
+
from .qwen2_5vl import Qwen2_5VLModel
|
| 26 |
from .qwen3vl import Qwen3VLModel
|
| 27 |
+
|
| 28 |
QWEN_MODELS_AVAILABLE = True
|
| 29 |
else:
|
| 30 |
print(
|
| 31 |
+
f"Warning: Qwen models require transformers>=4.57.0, but found {transformers.__version__}. Qwen models will not be available. Please upgrade to transformers>=4.57.0 or switch conda environments to use Qwen models."
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# InternVL condition
|
| 35 |
+
if transformers_version >= internvl_required_version:
|
| 36 |
+
from .internvl import InternVLModel
|
| 37 |
+
|
| 38 |
+
INTERNVL_MODELS_AVAILABLE = True
|
| 39 |
+
else:
|
| 40 |
+
print(
|
| 41 |
+
f"Warning: InternVL models require transformers>=4.45.0, but found {transformers.__version__}. InternVL models will not be available. Please downgrade to transformers<=4.45.0 or switch conda environments to use InternVL models."
|
| 42 |
)
|
| 43 |
|
| 44 |
# LLaVA condition
|
| 45 |
if transformers_version <= llava_required_version:
|
| 46 |
from .llava_video import LLaVAVideoModel
|
| 47 |
+
|
| 48 |
LLAVA_MODELS_AVAILABLE = True
|
| 49 |
else:
|
| 50 |
print(
|
| 51 |
+
f"Warning: LLaVA models require transformers<=4.40.0, but found {transformers.__version__}. LLaVA models will not be available. Please downgrade to transformers<=4.40.0 or switch conda environments to use LLaVA models."
|
|
|
|
|
|
|
| 52 |
)
|
| 53 |
+
except ImportError:
|
| 54 |
+
print(
|
| 55 |
+
"Warning: Could not check transformers version. Please re-check transformers installation."
|
| 56 |
+
)
|
|
|
|
| 57 |
|
| 58 |
# Build __all__ list dynamically
|
| 59 |
__all__ = []
|
| 60 |
if QWEN_MODELS_AVAILABLE:
|
| 61 |
__all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
|
| 62 |
+
if INTERNVL_MODELS_AVAILABLE:
|
| 63 |
+
__all__.append("InternVLModel")
|
| 64 |
if LLAVA_MODELS_AVAILABLE:
|
| 65 |
__all__.append("LLaVAVideoModel")
|
| 66 |
|
|
|
|
| 72 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 73 |
attn_implementation: Optional[str] = "flash_attention_2",
|
| 74 |
) -> BaseVideoModel:
|
|
|
|
| 75 |
if "LLaVA-Video" in model_path:
|
| 76 |
if not LLAVA_MODELS_AVAILABLE:
|
| 77 |
raise ImportError(
|
| 78 |
+
f"LLaVA models require transformers<=4.40.0."
|
| 79 |
+
f"Please downgrade transformers: pip install transformers<=4.40.0"
|
| 80 |
)
|
| 81 |
return LLaVAVideoModel(
|
| 82 |
model_path,
|
|
|
|
| 84 |
device_map=device_map,
|
| 85 |
attn_implementation=attn_implementation,
|
| 86 |
)
|
|
|
|
| 87 |
elif "Qwen" in model_path:
|
| 88 |
if not QWEN_MODELS_AVAILABLE:
|
| 89 |
raise ImportError(
|
| 90 |
+
f"Qwen models require transformers>=4.57.0."
|
| 91 |
+
f"Please upgrade transformers: pip install transformers>=4.57.0"
|
| 92 |
)
|
|
|
|
| 93 |
if "Qwen3" in model_path:
|
| 94 |
return Qwen3VLModel(
|
| 95 |
model_path,
|
|
|
|
| 104 |
device_map=device_map,
|
| 105 |
attn_implementation=attn_implementation,
|
| 106 |
)
|
| 107 |
+
elif "Intern" in model_path:
|
| 108 |
+
if not INTERNVL_MODELS_AVAILABLE:
|
| 109 |
+
raise ImportError(
|
| 110 |
+
f"InternVL models require transformers>=4.45.0."
|
| 111 |
+
f"Please upgrade transformers: pip install transformers>=4.45.0"
|
| 112 |
+
)
|
| 113 |
+
return InternVLModel(
|
| 114 |
+
model_path,
|
| 115 |
+
dtype=dtype,
|
| 116 |
+
device_map=device_map,
|
| 117 |
+
attn_implementation=attn_implementation,
|
| 118 |
+
)
|
| 119 |
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
class LogitsCaptureProcessor(LogitsProcessor):
|
| 122 |
+
"""
|
| 123 |
+
Custom LogitsProcessor that captures the processed logits right before sampling.
|
| 124 |
+
This allows us to see what the actual distribution looks like after all other
|
| 125 |
+
processors have been applied.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def __init__(self):
|
| 129 |
+
self.captured_logits = []
|
| 130 |
+
|
| 131 |
+
def __call__(
|
| 132 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
| 133 |
+
) -> torch.FloatTensor:
|
| 134 |
+
# Store a copy of the logits at this point in generation
|
| 135 |
+
self.captured_logits.append(scores.detach().clone().cpu())
|
| 136 |
+
# Return scores unchanged - we're just observing
|
| 137 |
+
return scores
|
| 138 |
+
|
| 139 |
+
def reset(self):
|
| 140 |
+
"""Clear captured logits for a new generation"""
|
| 141 |
+
self.captured_logits = []
|
models/llava_video.py
CHANGED
|
@@ -17,8 +17,7 @@ from PIL import Image
|
|
| 17 |
import requests
|
| 18 |
import copy
|
| 19 |
import torch
|
| 20 |
-
import
|
| 21 |
-
from typing import Optional, Union, Dict, List, Any
|
| 22 |
import warnings
|
| 23 |
from decord import VideoReader, cpu
|
| 24 |
import numpy as np
|
|
@@ -56,7 +55,6 @@ class LLaVAVideoModel(BaseVideoModel):
|
|
| 56 |
base_model,
|
| 57 |
torch_dtype=torch_dtype,
|
| 58 |
device_map=device_map,
|
| 59 |
-
attn_implementation=attn_implementation,
|
| 60 |
)
|
| 61 |
) # Add any other thing you want to pass in llava_model_args
|
| 62 |
self.model.eval()
|
|
@@ -105,10 +103,18 @@ class LLaVAVideoModel(BaseVideoModel):
|
|
| 105 |
video_path: str,
|
| 106 |
fps: float = 1.0,
|
| 107 |
max_new_tokens: int = 512,
|
|
|
|
|
|
|
|
|
|
| 108 |
temperature: float = 0.7,
|
|
|
|
|
|
|
| 109 |
**kwargs: Any,
|
| 110 |
) -> str:
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
| 112 |
video = self.image_processor.preprocess(video, return_tensors="pt")[
|
| 113 |
"pixel_values"
|
| 114 |
].to(device=self.model.device, dtype=self.dtype)
|
|
@@ -132,7 +138,7 @@ class LLaVAVideoModel(BaseVideoModel):
|
|
| 132 |
input_ids,
|
| 133 |
images=video,
|
| 134 |
modalities=["video"],
|
| 135 |
-
do_sample=
|
| 136 |
temperature=temperature,
|
| 137 |
max_new_tokens=max_new_tokens,
|
| 138 |
**kwargs,
|
|
@@ -149,9 +155,170 @@ class LLaVAVideoModel(BaseVideoModel):
|
|
| 149 |
fps: float = 1.0,
|
| 150 |
max_new_tokens: int = 512,
|
| 151 |
temperature: float = 0.7,
|
|
|
|
|
|
|
|
|
|
| 152 |
token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 153 |
logits_temperature: Optional[float] = 1.0,
|
| 154 |
return_confidence: Optional[bool] = False,
|
|
|
|
| 155 |
debug: Optional[bool] = False,
|
| 156 |
) -> Dict[str, Any]:
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import requests
|
| 18 |
import copy
|
| 19 |
import torch
|
| 20 |
+
from typing import Optional, Union, Dict, List, Tuple, Any
|
|
|
|
| 21 |
import warnings
|
| 22 |
from decord import VideoReader, cpu
|
| 23 |
import numpy as np
|
|
|
|
| 55 |
base_model,
|
| 56 |
torch_dtype=torch_dtype,
|
| 57 |
device_map=device_map,
|
|
|
|
| 58 |
)
|
| 59 |
) # Add any other thing you want to pass in llava_model_args
|
| 60 |
self.model.eval()
|
|
|
|
| 103 |
video_path: str,
|
| 104 |
fps: float = 1.0,
|
| 105 |
max_new_tokens: int = 512,
|
| 106 |
+
do_sample: Optional[
|
| 107 |
+
bool
|
| 108 |
+
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 109 |
temperature: float = 0.7,
|
| 110 |
+
video_mode: Optional[str] = "video",
|
| 111 |
+
video_frames: Optional[int] = 10,
|
| 112 |
**kwargs: Any,
|
| 113 |
) -> str:
|
| 114 |
+
if video_mode == "frames":
|
| 115 |
+
video, _, _ = self.load_video(video_path, max_frames_num=video_frames)
|
| 116 |
+
elif video_mode == "video":
|
| 117 |
+
video, _, _ = self.load_video(video_path, fps)
|
| 118 |
video = self.image_processor.preprocess(video, return_tensors="pt")[
|
| 119 |
"pixel_values"
|
| 120 |
].to(device=self.model.device, dtype=self.dtype)
|
|
|
|
| 138 |
input_ids,
|
| 139 |
images=video,
|
| 140 |
modalities=["video"],
|
| 141 |
+
do_sample=do_sample,
|
| 142 |
temperature=temperature,
|
| 143 |
max_new_tokens=max_new_tokens,
|
| 144 |
**kwargs,
|
|
|
|
| 155 |
fps: float = 1.0,
|
| 156 |
max_new_tokens: int = 512,
|
| 157 |
temperature: float = 0.7,
|
| 158 |
+
do_sample: Optional[
|
| 159 |
+
bool
|
| 160 |
+
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 161 |
token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 162 |
logits_temperature: Optional[float] = 1.0,
|
| 163 |
return_confidence: Optional[bool] = False,
|
| 164 |
+
top_k_tokens: Optional[int] = 10,
|
| 165 |
debug: Optional[bool] = False,
|
| 166 |
) -> Dict[str, Any]:
|
| 167 |
+
video, _, _ = self.load_video(video_path, fps)
|
| 168 |
+
video = self.image_processor.preprocess(video, return_tensors="pt")[
|
| 169 |
+
"pixel_values"
|
| 170 |
+
].to(device=self.model.device, dtype=self.dtype)
|
| 171 |
+
video = [video]
|
| 172 |
+
conv_template = (
|
| 173 |
+
"qwen_1_5" # Make sure you use correct chat template for different models
|
| 174 |
+
)
|
| 175 |
+
question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
|
| 176 |
+
conv = copy.deepcopy(conv_templates[conv_template])
|
| 177 |
+
conv.append_message(conv.roles[0], question)
|
| 178 |
+
conv.append_message(conv.roles[1], None)
|
| 179 |
+
prompt_question = conv.get_prompt()
|
| 180 |
+
input_ids = (
|
| 181 |
+
tokenizer_image_token(
|
| 182 |
+
prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 183 |
+
)
|
| 184 |
+
.unsqueeze(0)
|
| 185 |
+
.to(self.model.device)
|
| 186 |
+
)
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
outputs = self.model.generate(
|
| 189 |
+
input_ids,
|
| 190 |
+
images=video,
|
| 191 |
+
modalities=["video"],
|
| 192 |
+
do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 193 |
+
temperature=temperature,
|
| 194 |
+
max_new_tokens=max_new_tokens,
|
| 195 |
+
output_scores=True,
|
| 196 |
+
return_dict_in_generate=True,
|
| 197 |
+
)
|
| 198 |
+
generated_ids = outputs.sequences
|
| 199 |
+
scores = outputs.scores # Tuple of tensors, one per generated token
|
| 200 |
+
|
| 201 |
+
print(f"Number of generated tokens: {len(scores)}")
|
| 202 |
+
print(f"Vocabulary size: {scores[0].shape[1]}")
|
| 203 |
+
# Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
|
| 204 |
+
if debug:
|
| 205 |
+
print("****Running inference in debug mode****")
|
| 206 |
+
# Print first token scores shape and max/min scores in debug mode
|
| 207 |
+
print(f"Single token scores shape: {scores[0].shape}")
|
| 208 |
+
print(
|
| 209 |
+
f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Print details about top 10 tokens based on logits
|
| 213 |
+
logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
|
| 214 |
+
print(f"\n{'β'*80}")
|
| 215 |
+
print(
|
| 216 |
+
f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
|
| 217 |
+
)
|
| 218 |
+
print(f"{'β'*80}")
|
| 219 |
+
top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
|
| 220 |
+
for i in range(top_k_tokens):
|
| 221 |
+
score = top_k_tokens_scores.values[0, i].item()
|
| 222 |
+
score_index = top_k_tokens_scores.indices[0, i].item()
|
| 223 |
+
token = self.tokenizer.decode(score_index)
|
| 224 |
+
print(f"#{i+1}th Token: {token}")
|
| 225 |
+
print(f"#{i+1}th Token index: {score_index}")
|
| 226 |
+
print(f"#{i+1}th Token score: {score}")
|
| 227 |
+
print("--------------------------------")
|
| 228 |
+
|
| 229 |
+
# Decode the text
|
| 230 |
+
output_response = self.tokenizer.batch_decode(
|
| 231 |
+
generated_ids,
|
| 232 |
+
skip_special_tokens=True,
|
| 233 |
+
clean_up_tokenization_spaces=False,
|
| 234 |
+
)[0]
|
| 235 |
+
|
| 236 |
+
# Convert scores to probabilities
|
| 237 |
+
# scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
|
| 238 |
+
selected_token_probs = []
|
| 239 |
+
selected_token_logits = []
|
| 240 |
+
first_token_probs = torch.softmax(scores[0], dim=-1)
|
| 241 |
+
|
| 242 |
+
# Now, find indices of tokens in token_choices and get their probabilities
|
| 243 |
+
for token_choice in token_choices:
|
| 244 |
+
# Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
|
| 245 |
+
token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
|
| 246 |
+
0
|
| 247 |
+
]
|
| 248 |
+
selected_token_probs.append(first_token_probs[0, token_index].item())
|
| 249 |
+
selected_token_logits.append(scores[0][0, token_index].item())
|
| 250 |
+
|
| 251 |
+
# Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
|
| 252 |
+
if return_confidence:
|
| 253 |
+
first_token_id = generated_ids[0][
|
| 254 |
+
0
|
| 255 |
+
].item() # First token of the first sequence
|
| 256 |
+
confidence = (
|
| 257 |
+
first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
|
| 258 |
+
if sum(selected_token_probs) > 0
|
| 259 |
+
else 0.0
|
| 260 |
+
)
|
| 261 |
+
return {
|
| 262 |
+
"response": output_response,
|
| 263 |
+
"confidence": confidence,
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
# Return token logits
|
| 267 |
+
else:
|
| 268 |
+
token_logits = dict(zip(token_choices, selected_token_logits))
|
| 269 |
+
top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
|
| 270 |
+
top_k_tokens_list: List[Tuple[str, int, float]] = []
|
| 271 |
+
for i in range(top_k_tokens):
|
| 272 |
+
logit_index = top_k_logits_indices.indices[0, i].item()
|
| 273 |
+
token = self.tokenizer.decode(logit_index)
|
| 274 |
+
logit = top_k_logits_indices.values[0, i].item()
|
| 275 |
+
top_k_tokens_list.append((token, logit_index, logit))
|
| 276 |
+
return {
|
| 277 |
+
"response": output_response,
|
| 278 |
+
"top_k_tokens": top_k_tokens_list,
|
| 279 |
+
"token_logits": token_logits,
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
if __name__ == "__main__":
|
| 284 |
+
model_path = "lmms-lab/LLaVA-Video-7B-Qwen2" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 285 |
+
device_map = "cuda:0"
|
| 286 |
+
model = LLaVAVideoModel(model_path, device_map=device_map)
|
| 287 |
+
prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
|
| 288 |
+
token_choices = ["Yes", "No"]
|
| 289 |
+
video_path = (
|
| 290 |
+
"/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
generation_config = {
|
| 294 |
+
"max_new_tokens": 128,
|
| 295 |
+
"do_sample": False, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
|
| 296 |
+
"temperature": 0.7,
|
| 297 |
+
"logits_temperature": 1.0,
|
| 298 |
+
"fps": 1.0,
|
| 299 |
+
"return_confidence": False,
|
| 300 |
+
"top_k_tokens": 10,
|
| 301 |
+
"debug": False,
|
| 302 |
+
}
|
| 303 |
+
output = model.chat_with_confidence(
|
| 304 |
+
prompt, video_path, token_choices=token_choices, **generation_config
|
| 305 |
+
)
|
| 306 |
+
response = output["response"]
|
| 307 |
+
print(f"Response: {response}")
|
| 308 |
+
|
| 309 |
+
if generation_config["return_confidence"]:
|
| 310 |
+
confidence = output["confidence"]
|
| 311 |
+
print(f"Confidence: {confidence}")
|
| 312 |
+
else:
|
| 313 |
+
# If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
|
| 314 |
+
# otherwise, the raw logits are used, which are not filtered.
|
| 315 |
+
logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
|
| 316 |
+
print(f"\n{'β'*80}")
|
| 317 |
+
print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
|
| 318 |
+
print(f"{'β'*80}")
|
| 319 |
+
top_k_tokens = output["top_k_tokens"]
|
| 320 |
+
for i in range(len(top_k_tokens)):
|
| 321 |
+
print(f"Top {i+1} token: {top_k_tokens[i][0]}")
|
| 322 |
+
print(f"Top {i+1} token index: {top_k_tokens[i][1]}")
|
| 323 |
+
print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
|
| 324 |
+
print("--------------------------------")
|
models/{qwen2_5.py β qwen2_5vl.py}
RENAMED
|
@@ -39,6 +39,8 @@ class Qwen2_5VLModel(BaseVideoModel):
|
|
| 39 |
fps: float = 1.0,
|
| 40 |
temperature: float = 0.7,
|
| 41 |
max_new_tokens: int = 512,
|
|
|
|
|
|
|
| 42 |
) -> str:
|
| 43 |
# Messages containing a local video path and a text query
|
| 44 |
messages = [
|
|
@@ -75,8 +77,10 @@ class Qwen2_5VLModel(BaseVideoModel):
|
|
| 75 |
# Inference
|
| 76 |
generated_ids = self.model.generate(
|
| 77 |
**inputs,
|
|
|
|
| 78 |
temperature=temperature,
|
| 79 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 80 |
)
|
| 81 |
generated_ids_trimmed = [
|
| 82 |
out_ids[len(in_ids) :]
|
|
|
|
| 39 |
fps: float = 1.0,
|
| 40 |
temperature: float = 0.7,
|
| 41 |
max_new_tokens: int = 512,
|
| 42 |
+
do_sample: Optional[bool] = True,
|
| 43 |
+
**kwargs: Any,
|
| 44 |
) -> str:
|
| 45 |
# Messages containing a local video path and a text query
|
| 46 |
messages = [
|
|
|
|
| 77 |
# Inference
|
| 78 |
generated_ids = self.model.generate(
|
| 79 |
**inputs,
|
| 80 |
+
do_sample=do_sample,
|
| 81 |
temperature=temperature,
|
| 82 |
max_new_tokens=max_new_tokens,
|
| 83 |
+
**kwargs,
|
| 84 |
)
|
| 85 |
generated_ids_trimmed = [
|
| 86 |
out_ids[len(in_ids) :]
|
models/qwen3vl.py
CHANGED
|
@@ -5,8 +5,11 @@ from transformers import (
|
|
| 5 |
Qwen3VLForConditionalGeneration,
|
| 6 |
AutoProcessor,
|
| 7 |
)
|
| 8 |
-
from typing import Optional, Dict, Any, Union, List
|
| 9 |
from qwen_vl_utils import process_vision_info
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Handle both relative and absolute imports
|
| 12 |
try:
|
|
@@ -15,6 +18,36 @@ except ImportError:
|
|
| 15 |
from base import BaseVideoModel
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class Qwen3VLModel(BaseVideoModel):
|
| 19 |
def __init__(
|
| 20 |
self,
|
|
@@ -38,31 +71,55 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 38 |
video_path: str,
|
| 39 |
fps: float = 1.0,
|
| 40 |
temperature: float = 0.7,
|
|
|
|
|
|
|
|
|
|
| 41 |
max_new_tokens: int = 512,
|
|
|
|
|
|
|
|
|
|
| 42 |
) -> str:
|
| 43 |
# Messages containing a local video path and a text query
|
| 44 |
messages = [
|
| 45 |
{
|
| 46 |
"role": "user",
|
| 47 |
"content": [
|
| 48 |
-
{
|
| 49 |
-
"type": "video",
|
| 50 |
-
"video": video_path,
|
| 51 |
-
# "max_pixels": 360 * 420,
|
| 52 |
-
"fps": fps,
|
| 53 |
-
},
|
| 54 |
{"type": "text", "text": prompt},
|
| 55 |
],
|
| 56 |
}
|
| 57 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
inputs = inputs.to(self.model.device)
|
| 68 |
|
|
@@ -70,6 +127,8 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 70 |
**inputs,
|
| 71 |
max_new_tokens=max_new_tokens,
|
| 72 |
temperature=temperature,
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
generated_ids_trimmed = [
|
|
@@ -92,13 +151,18 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 92 |
fps: float = 1.0,
|
| 93 |
max_new_tokens: int = 512,
|
| 94 |
temperature: float = 0.7,
|
|
|
|
|
|
|
|
|
|
| 95 |
token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 96 |
logits_temperature: Optional[float] = 1.0,
|
| 97 |
return_confidence: Optional[bool] = False,
|
|
|
|
| 98 |
debug: Optional[bool] = False,
|
|
|
|
| 99 |
) -> Dict[str, Any]:
|
| 100 |
"""
|
| 101 |
-
Returns the response and confidence of the response, if return_confidence is True. Else, returns the
|
| 102 |
|
| 103 |
Args:
|
| 104 |
prompt (str): The text prompt to generate a response for.
|
|
@@ -108,19 +172,17 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 108 |
token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
|
| 109 |
generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
|
| 110 |
return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
|
|
|
|
| 111 |
debug (bool, optional): Whether to run in debug mode. Defaults to False.
|
| 112 |
|
| 113 |
Returns:
|
| 114 |
-
Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the
|
| 115 |
|
| 116 |
e.g., return_confidence: False
|
| 117 |
Output:
|
| 118 |
{
|
| 119 |
"response": "Yes",
|
| 120 |
-
"
|
| 121 |
-
"Yes": 12.0,
|
| 122 |
-
"No": 9.0
|
| 123 |
-
}
|
| 124 |
}
|
| 125 |
|
| 126 |
e.g., return_confidence: True
|
|
@@ -146,68 +208,233 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 146 |
}
|
| 147 |
]
|
| 148 |
|
| 149 |
-
|
| 150 |
-
messages, tokenize=False, add_generation_prompt=True
|
| 151 |
-
)
|
| 152 |
-
image_inputs, videos, video_kwargs = process_vision_info(
|
| 153 |
messages,
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
)
|
| 158 |
-
# Extract out videos and video metadata
|
| 159 |
-
if videos is not None:
|
| 160 |
-
videos, video_metadatas = zip(*videos)
|
| 161 |
-
videos, video_metadatas = list(videos), list(video_metadatas)
|
| 162 |
-
else:
|
| 163 |
-
video_metadatas = None
|
| 164 |
-
|
| 165 |
-
inputs = self.processor(
|
| 166 |
-
text=text,
|
| 167 |
-
images=image_inputs,
|
| 168 |
-
videos=videos,
|
| 169 |
-
video_metadata=video_metadatas,
|
| 170 |
return_tensors="pt",
|
| 171 |
-
do_resize=False,
|
| 172 |
-
**video_kwargs,
|
| 173 |
)
|
|
|
|
| 174 |
inputs = inputs.to(self.model.device)
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
# Inference with scores
|
| 177 |
with torch.no_grad():
|
| 178 |
outputs = self.model.generate(
|
| 179 |
**inputs,
|
| 180 |
temperature=temperature,
|
| 181 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 182 |
output_scores=True,
|
|
|
|
| 183 |
return_dict_in_generate=True,
|
|
|
|
| 184 |
)
|
| 185 |
|
| 186 |
generated_ids = outputs.sequences
|
| 187 |
-
scores = outputs.scores # Tuple of tensors
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
scores = tuple(
|
| 189 |
s / logits_temperature for s in scores
|
| 190 |
) # Scales the logits by a factor for normalization during reporting
|
| 191 |
|
| 192 |
print(f"Number of generated tokens: {len(scores)}")
|
| 193 |
print(f"Vocabulary size: {scores[0].shape[1]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
# Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
|
| 195 |
if debug:
|
|
|
|
| 196 |
print("****Running inference in debug mode****")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
# Print first token scores shape and max/min scores in debug mode
|
| 198 |
-
print(f"Single token scores shape: {scores[0].shape}")
|
| 199 |
print(
|
| 200 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
for i in range(3):
|
|
|
|
|
|
|
|
|
|
| 205 |
print(
|
| 206 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
)
|
| 208 |
print(
|
| 209 |
-
f"
|
| 210 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
# Trim the prompt tokens from generated sequences
|
| 213 |
generated_ids_trimmed = [
|
|
@@ -252,37 +479,41 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 252 |
"confidence": confidence,
|
| 253 |
}
|
| 254 |
|
| 255 |
-
#
|
| 256 |
else:
|
| 257 |
token_logits = dict(zip(token_choices, selected_token_logits))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
return {
|
| 259 |
"response": output_response,
|
| 260 |
-
"
|
|
|
|
| 261 |
}
|
| 262 |
|
| 263 |
|
| 264 |
if __name__ == "__main__":
|
| 265 |
model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 266 |
model = Qwen3VLModel(model_path)
|
| 267 |
-
prompt =
|
| 268 |
-
|
| 269 |
video_path = (
|
| 270 |
-
"/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/
|
| 271 |
)
|
| 272 |
-
response = model.chat(prompt, video_path)
|
| 273 |
-
print("Response: ", response)
|
| 274 |
-
|
| 275 |
-
token_choices = ["A", "B"]
|
| 276 |
-
ext = ".webm"
|
| 277 |
-
video_path = "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917" + ext
|
| 278 |
|
| 279 |
generation_config = {
|
| 280 |
"max_new_tokens": 128,
|
|
|
|
| 281 |
"temperature": 0.7,
|
| 282 |
-
"logits_temperature":
|
| 283 |
-
"fps":
|
| 284 |
"return_confidence": False,
|
| 285 |
-
"
|
|
|
|
| 286 |
}
|
| 287 |
output = model.chat_with_confidence(
|
| 288 |
prompt, video_path, token_choices=token_choices, **generation_config
|
|
@@ -294,6 +525,11 @@ if __name__ == "__main__":
|
|
| 294 |
confidence = output["confidence"]
|
| 295 |
print(f"Confidence: {confidence}")
|
| 296 |
else:
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
Qwen3VLForConditionalGeneration,
|
| 6 |
AutoProcessor,
|
| 7 |
)
|
| 8 |
+
from typing import Optional, Dict, Any, Union, List, Tuple
|
| 9 |
from qwen_vl_utils import process_vision_info
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
|
| 14 |
# Handle both relative and absolute imports
|
| 15 |
try:
|
|
|
|
| 18 |
from base import BaseVideoModel
|
| 19 |
|
| 20 |
|
| 21 |
+
def downsample_video(video_path, max_dim=720, num_frames=10):
|
| 22 |
+
vidcap = cv2.VideoCapture(video_path)
|
| 23 |
+
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 24 |
+
frames = []
|
| 25 |
+
frame_indices = np.linspace(
|
| 26 |
+
0, total_frames - 1, min(total_frames, num_frames), dtype=int
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
for i in frame_indices:
|
| 30 |
+
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
| 31 |
+
success, image = vidcap.read()
|
| 32 |
+
if success:
|
| 33 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 34 |
+
|
| 35 |
+
h, w = image.shape[:2]
|
| 36 |
+
scale = max_dim / max(h, w)
|
| 37 |
+
if scale < 1:
|
| 38 |
+
image = cv2.resize(
|
| 39 |
+
image,
|
| 40 |
+
(int(w * scale), int(h * scale)),
|
| 41 |
+
interpolation=cv2.INTER_AREA,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
pil_image = Image.fromarray(image)
|
| 45 |
+
frames.append(pil_image)
|
| 46 |
+
|
| 47 |
+
vidcap.release()
|
| 48 |
+
return frames
|
| 49 |
+
|
| 50 |
+
|
| 51 |
class Qwen3VLModel(BaseVideoModel):
|
| 52 |
def __init__(
|
| 53 |
self,
|
|
|
|
| 71 |
video_path: str,
|
| 72 |
fps: float = 1.0,
|
| 73 |
temperature: float = 0.7,
|
| 74 |
+
do_sample: Optional[
|
| 75 |
+
bool
|
| 76 |
+
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 77 |
max_new_tokens: int = 512,
|
| 78 |
+
video_mode: Optional[str] = "video", # Choose from "video" or "frames"
|
| 79 |
+
video_frames: Optional[int] = 10,
|
| 80 |
+
**kwargs: Any,
|
| 81 |
) -> str:
|
| 82 |
# Messages containing a local video path and a text query
|
| 83 |
messages = [
|
| 84 |
{
|
| 85 |
"role": "user",
|
| 86 |
"content": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
{"type": "text", "text": prompt},
|
| 88 |
],
|
| 89 |
}
|
| 90 |
]
|
| 91 |
+
if video_mode == "video":
|
| 92 |
+
messages[0]["content"].append(
|
| 93 |
+
{
|
| 94 |
+
"type": "video",
|
| 95 |
+
"video": video_path,
|
| 96 |
+
# "max_pixels": 360 * 420,
|
| 97 |
+
"fps": fps,
|
| 98 |
+
}
|
| 99 |
+
)
|
| 100 |
+
inputs = self.processor.apply_chat_template(
|
| 101 |
+
messages,
|
| 102 |
+
tokenize=True,
|
| 103 |
+
add_generation_prompt=True,
|
| 104 |
+
return_dict=True,
|
| 105 |
+
return_tensors="pt",
|
| 106 |
+
)
|
| 107 |
|
| 108 |
+
elif video_mode == "frames":
|
| 109 |
+
frames = downsample_video(video_path, max_dim=720, num_frames=video_frames)
|
| 110 |
+
images_for_processor = []
|
| 111 |
+
for frame in frames:
|
| 112 |
+
messages[0]["content"].append({"type": "image"})
|
| 113 |
+
images_for_processor.append(frame)
|
| 114 |
+
prompt_full = self.processor.apply_chat_template(
|
| 115 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 116 |
+
)
|
| 117 |
+
inputs = self.processor(
|
| 118 |
+
text=[prompt_full],
|
| 119 |
+
images=images_for_processor,
|
| 120 |
+
return_tensors="pt",
|
| 121 |
+
padding=True,
|
| 122 |
+
)
|
| 123 |
|
| 124 |
inputs = inputs.to(self.model.device)
|
| 125 |
|
|
|
|
| 127 |
**inputs,
|
| 128 |
max_new_tokens=max_new_tokens,
|
| 129 |
temperature=temperature,
|
| 130 |
+
do_sample=do_sample,
|
| 131 |
+
**kwargs,
|
| 132 |
)
|
| 133 |
|
| 134 |
generated_ids_trimmed = [
|
|
|
|
| 151 |
fps: float = 1.0,
|
| 152 |
max_new_tokens: int = 512,
|
| 153 |
temperature: float = 0.7,
|
| 154 |
+
do_sample: Optional[
|
| 155 |
+
bool
|
| 156 |
+
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 157 |
token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 158 |
logits_temperature: Optional[float] = 1.0,
|
| 159 |
return_confidence: Optional[bool] = False,
|
| 160 |
+
top_k_tokens: Optional[int] = 10,
|
| 161 |
debug: Optional[bool] = False,
|
| 162 |
+
**kwargs: Any,
|
| 163 |
) -> Dict[str, Any]:
|
| 164 |
"""
|
| 165 |
+
Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
|
| 166 |
|
| 167 |
Args:
|
| 168 |
prompt (str): The text prompt to generate a response for.
|
|
|
|
| 172 |
token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
|
| 173 |
generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
|
| 174 |
return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
|
| 175 |
+
top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
|
| 176 |
debug (bool, optional): Whether to run in debug mode. Defaults to False.
|
| 177 |
|
| 178 |
Returns:
|
| 179 |
+
Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
|
| 180 |
|
| 181 |
e.g., return_confidence: False
|
| 182 |
Output:
|
| 183 |
{
|
| 184 |
"response": "Yes",
|
| 185 |
+
"top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)],
|
|
|
|
|
|
|
|
|
|
| 186 |
}
|
| 187 |
|
| 188 |
e.g., return_confidence: True
|
|
|
|
| 208 |
}
|
| 209 |
]
|
| 210 |
|
| 211 |
+
inputs = self.processor.apply_chat_template(
|
|
|
|
|
|
|
|
|
|
| 212 |
messages,
|
| 213 |
+
tokenize=True,
|
| 214 |
+
add_generation_prompt=True,
|
| 215 |
+
return_dict=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
return_tensors="pt",
|
|
|
|
|
|
|
| 217 |
)
|
| 218 |
+
|
| 219 |
inputs = inputs.to(self.model.device)
|
| 220 |
|
| 221 |
+
# In debug mode, inspect what logits processors will be used
|
| 222 |
+
if debug:
|
| 223 |
+
print("\n" + "=" * 80)
|
| 224 |
+
print("INSPECTING GENERATION CONFIG & WARPERS")
|
| 225 |
+
print("=" * 80)
|
| 226 |
+
# Get the generation config to see what processors will be added
|
| 227 |
+
gen_config = self.model.generation_config
|
| 228 |
+
print(f"Generation config attributes:")
|
| 229 |
+
print(f" Processor-related:")
|
| 230 |
+
print(
|
| 231 |
+
f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
|
| 232 |
+
)
|
| 233 |
+
print(
|
| 234 |
+
f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
|
| 235 |
+
)
|
| 236 |
+
print(
|
| 237 |
+
f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
|
| 238 |
+
)
|
| 239 |
+
print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
|
| 240 |
+
print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
|
| 241 |
+
print(
|
| 242 |
+
f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
|
| 243 |
+
)
|
| 244 |
+
print(
|
| 245 |
+
f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
|
| 246 |
+
)
|
| 247 |
+
print(f" Warper-related (THESE MASK TOKENS TO -INF):")
|
| 248 |
+
print(f" - temperature: {temperature} (passed as arg)")
|
| 249 |
+
print(
|
| 250 |
+
f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
|
| 251 |
+
)
|
| 252 |
+
print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
|
| 253 |
+
print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
|
| 254 |
+
print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
|
| 255 |
+
print(
|
| 256 |
+
f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
|
| 257 |
+
)
|
| 258 |
+
print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
|
| 259 |
+
print(
|
| 260 |
+
f"\n β οΈ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
|
| 261 |
+
)
|
| 262 |
+
print("=" * 80 + "\n")
|
| 263 |
+
|
| 264 |
# Inference with scores
|
| 265 |
with torch.no_grad():
|
| 266 |
outputs = self.model.generate(
|
| 267 |
**inputs,
|
| 268 |
temperature=temperature,
|
| 269 |
max_new_tokens=max_new_tokens,
|
| 270 |
+
do_sample=do_sample,
|
| 271 |
output_scores=True,
|
| 272 |
+
output_logits=True, # Get TRUE raw logits before any processing
|
| 273 |
return_dict_in_generate=True,
|
| 274 |
+
**kwargs,
|
| 275 |
)
|
| 276 |
|
| 277 |
generated_ids = outputs.sequences
|
| 278 |
+
scores = outputs.scores # Tuple of tensors - PROCESSED logits used for sampling
|
| 279 |
+
logits = (
|
| 280 |
+
outputs.logits if hasattr(outputs, "logits") else None
|
| 281 |
+
) # TRUE raw logits from model
|
| 282 |
+
|
| 283 |
scores = tuple(
|
| 284 |
s / logits_temperature for s in scores
|
| 285 |
) # Scales the logits by a factor for normalization during reporting
|
| 286 |
|
| 287 |
print(f"Number of generated tokens: {len(scores)}")
|
| 288 |
print(f"Vocabulary size: {scores[0].shape[1]}")
|
| 289 |
+
|
| 290 |
+
# Check if logits differ from scores
|
| 291 |
+
if debug and logits is not None:
|
| 292 |
+
print(f"\n[IMPORTANT] output_logits available: True")
|
| 293 |
+
print(
|
| 294 |
+
f"[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):"
|
| 295 |
+
)
|
| 296 |
+
logits_raw = logits[0] / logits_temperature # First token's raw logits
|
| 297 |
+
scores_first = scores[0] # First token's processed scores
|
| 298 |
+
|
| 299 |
+
logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
|
| 300 |
+
max_diff = logits_diff.max().item()
|
| 301 |
+
if max_diff > 0.001:
|
| 302 |
+
print(
|
| 303 |
+
f"[IMPORTANT] β οΈ outputs.scores ARE DIFFERENT from outputs.logits!"
|
| 304 |
+
)
|
| 305 |
+
print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
|
| 306 |
+
print(
|
| 307 |
+
f"[IMPORTANT] This means outputs.scores are PROCESSED, not raw!"
|
| 308 |
+
)
|
| 309 |
+
else:
|
| 310 |
+
print(f"[IMPORTANT] β outputs.scores == outputs.logits (both are raw)")
|
| 311 |
+
elif debug:
|
| 312 |
+
print(
|
| 313 |
+
f"\n[IMPORTANT] output_logits not available in this transformers version"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
# Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
|
| 317 |
if debug:
|
| 318 |
+
print("\n" + "=" * 80)
|
| 319 |
print("****Running inference in debug mode****")
|
| 320 |
+
print("=" * 80)
|
| 321 |
+
|
| 322 |
+
# Use truly raw logits if available, otherwise use scores
|
| 323 |
+
raw_logits_to_show = (
|
| 324 |
+
logits[0] / logits_temperature if logits is not None else scores[0]
|
| 325 |
+
)
|
| 326 |
+
logits_label = (
|
| 327 |
+
"TRUE RAW LOGITS (from outputs.logits)"
|
| 328 |
+
if logits is not None
|
| 329 |
+
else "LOGITS (from outputs.scores)"
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
# Print first token scores shape and max/min scores in debug mode
|
|
|
|
| 333 |
print(
|
| 334 |
+
f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}"
|
| 335 |
+
)
|
| 336 |
+
print(
|
| 337 |
+
f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Print details about top 3 tokens from RAW logits
|
| 341 |
+
print(f"\n{'β'*80}")
|
| 342 |
+
print(f"TOP 3 TOKENS FROM {logits_label}:")
|
| 343 |
+
print(f"{'β'*80}")
|
| 344 |
+
top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
|
| 345 |
+
for i in range(3):
|
| 346 |
+
token_id = top_3_tokens.indices[0, i].item()
|
| 347 |
+
token_text = self.processor.decode(token_id)
|
| 348 |
+
token_logit = top_3_tokens.values[0, i].item()
|
| 349 |
+
print(
|
| 350 |
+
f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# Now compare with POST-PROCESSED logits (outputs.scores)
|
| 354 |
+
scores_first = scores[0] / logits_temperature
|
| 355 |
+
print(f"\n{'β'*80}")
|
| 356 |
+
print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
|
| 357 |
+
print(f"{'β'*80}")
|
| 358 |
+
print(
|
| 359 |
+
f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}"
|
| 360 |
)
|
| 361 |
+
|
| 362 |
+
top_3_processed = torch.topk(scores_first, k=3, dim=-1)
|
| 363 |
for i in range(3):
|
| 364 |
+
token_id = top_3_processed.indices[0, i].item()
|
| 365 |
+
token_text = self.processor.decode(token_id)
|
| 366 |
+
token_logit = top_3_processed.values[0, i].item()
|
| 367 |
print(
|
| 368 |
+
f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
# Check if the distributions differ (compare against truly raw logits if available)
|
| 372 |
+
print(f"\n{'β'*80}")
|
| 373 |
+
print("DIFFERENCE ANALYSIS (Raw β Post-Processed):")
|
| 374 |
+
print(f"{'β'*80}")
|
| 375 |
+
logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
|
| 376 |
+
max_diff = logit_diff.max().item()
|
| 377 |
+
num_changed = (logit_diff > 0.001).sum().item()
|
| 378 |
+
|
| 379 |
+
print(f" Max logit difference: {max_diff:.6f}")
|
| 380 |
+
print(
|
| 381 |
+
f" Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}"
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
if max_diff > 0.001:
|
| 385 |
+
print(f"\n β οΈ LOGITS WERE MODIFIED BY PROCESSORS!")
|
| 386 |
+
# Show which tokens changed the most
|
| 387 |
+
top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
|
| 388 |
+
print(f"\n Top 5 most changed tokens:")
|
| 389 |
+
for i in range(min(5, len(top_changes.indices))):
|
| 390 |
+
token_id = top_changes.indices[i].item()
|
| 391 |
+
token_text = self.processor.decode(token_id)
|
| 392 |
+
raw_logit = raw_logits_to_show[0, token_id].item()
|
| 393 |
+
processed_logit = scores_first[0, token_id].item()
|
| 394 |
+
diff = top_changes.values[i].item()
|
| 395 |
+
print(f" Token='{token_text}' | ID={token_id}")
|
| 396 |
+
print(
|
| 397 |
+
f" Raw: {raw_logit:.4f} β Processed: {processed_logit:.4f} (Ξ={diff:.4f})"
|
| 398 |
+
)
|
| 399 |
+
else:
|
| 400 |
+
print(f" β No significant modifications detected")
|
| 401 |
+
|
| 402 |
+
# Show what token was actually selected
|
| 403 |
+
print(f"\n{'β'*80}")
|
| 404 |
+
print("ACTUALLY GENERATED TOKEN:")
|
| 405 |
+
print(f"{'β'*80}")
|
| 406 |
+
first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
|
| 407 |
+
first_generated_token = self.processor.decode(first_generated_id)
|
| 408 |
+
raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
|
| 409 |
+
|
| 410 |
+
print(f" Token: '{first_generated_token}' | ID={first_generated_id}")
|
| 411 |
+
print(f" Raw logit: {raw_logit_for_generated:.4f}")
|
| 412 |
+
|
| 413 |
+
processed_logit_for_generated = scores_first[0, first_generated_id].item()
|
| 414 |
+
print(f" Post-processed logit: {processed_logit_for_generated:.4f}")
|
| 415 |
+
|
| 416 |
+
# Check if this token is in top-k of raw logits
|
| 417 |
+
top_k_raw_indices = torch.topk(
|
| 418 |
+
raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
|
| 419 |
+
).indices[0]
|
| 420 |
+
is_in_top10_raw = first_generated_id in top_k_raw_indices
|
| 421 |
+
print(f" In top-10 of RAW logits: {is_in_top10_raw}")
|
| 422 |
+
|
| 423 |
+
if not is_in_top10_raw:
|
| 424 |
+
print(
|
| 425 |
+
f"\n π¨ CRITICAL: Generated token was NOT in top-10 of raw logits!"
|
| 426 |
)
|
| 427 |
print(
|
| 428 |
+
f" This proves that logits processors modified the distribution."
|
| 429 |
)
|
| 430 |
+
# Find the rank of the generated token in raw logits
|
| 431 |
+
sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
|
| 432 |
+
raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[
|
| 433 |
+
0
|
| 434 |
+
].item() + 1
|
| 435 |
+
print(f" Raw logits rank: {raw_rank}")
|
| 436 |
+
|
| 437 |
+
print("=" * 80 + "\n")
|
| 438 |
|
| 439 |
# Trim the prompt tokens from generated sequences
|
| 440 |
generated_ids_trimmed = [
|
|
|
|
| 479 |
"confidence": confidence,
|
| 480 |
}
|
| 481 |
|
| 482 |
+
# Return token logits
|
| 483 |
else:
|
| 484 |
token_logits = dict(zip(token_choices, selected_token_logits))
|
| 485 |
+
top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
|
| 486 |
+
top_k_tokens_list: List[Tuple[str, int, float]] = []
|
| 487 |
+
for i in range(top_k_tokens):
|
| 488 |
+
logit_index = top_k_logits_indices.indices[0, i].item()
|
| 489 |
+
token = self.processor.decode(logit_index)
|
| 490 |
+
logit = top_k_logits_indices.values[0, i].item()
|
| 491 |
+
top_k_tokens_list.append((token, logit_index, logit))
|
| 492 |
return {
|
| 493 |
"response": output_response,
|
| 494 |
+
"top_k_tokens": top_k_tokens_list,
|
| 495 |
+
"token_logits": token_logits,
|
| 496 |
}
|
| 497 |
|
| 498 |
|
| 499 |
if __name__ == "__main__":
|
| 500 |
model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 501 |
model = Qwen3VLModel(model_path)
|
| 502 |
+
prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
|
| 503 |
+
token_choices = ["Yes", "No"]
|
| 504 |
video_path = (
|
| 505 |
+
"/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
|
| 506 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
generation_config = {
|
| 509 |
"max_new_tokens": 128,
|
| 510 |
+
"do_sample": True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
|
| 511 |
"temperature": 0.7,
|
| 512 |
+
"logits_temperature": 1.0,
|
| 513 |
+
"fps": 1.0,
|
| 514 |
"return_confidence": False,
|
| 515 |
+
"top_k_tokens": 10,
|
| 516 |
+
"debug": False,
|
| 517 |
}
|
| 518 |
output = model.chat_with_confidence(
|
| 519 |
prompt, video_path, token_choices=token_choices, **generation_config
|
|
|
|
| 525 |
confidence = output["confidence"]
|
| 526 |
print(f"Confidence: {confidence}")
|
| 527 |
else:
|
| 528 |
+
# If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
|
| 529 |
+
# otherwise, the raw logits are used, which are not filtered.
|
| 530 |
+
logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
|
| 531 |
+
top_k_tokens = output["top_k_tokens"]
|
| 532 |
+
for i in range(len(top_k_tokens)):
|
| 533 |
+
print(f"Top {i+1} token: {top_k_tokens[i][0]}")
|
| 534 |
+
print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
|
| 535 |
+
print("--------------------------------")
|