Commit Β·
ef7643d
1
Parent(s): cf5f08b
Add transformers v5 integration to models
Browse files- app.py +2 -2
- models/__init__.py +25 -92
- models/base.py +8 -8
- models/llava_video.py +222 -243
- models/qwen2_5vl.py +232 -222
- models/qwen3vl.py +431 -477
app.py
CHANGED
|
@@ -12,7 +12,7 @@ from models.base import BaseVideoModel
|
|
| 12 |
# ----------------------
|
| 13 |
# CONFIG
|
| 14 |
# ----------------------
|
| 15 |
-
MODEL_PATH = "
|
| 16 |
DEVICE_MAP = "cuda:0"
|
| 17 |
|
| 18 |
VIDEO_DIR = str(Path(__file__).parent / "videos")
|
|
@@ -130,7 +130,7 @@ with gr.Blocks(title="Video QA β LLaVa-Video-7B-Qwen2", theme=gr.themes.Soft()
|
|
| 130 |
|
| 131 |
fps_slider = gr.Slider(
|
| 132 |
minimum=0.5,
|
| 133 |
-
maximum=
|
| 134 |
step=0.5,
|
| 135 |
value=FPS,
|
| 136 |
label="ποΈ Frames Per Second (FPS)",
|
|
|
|
| 12 |
# ----------------------
|
| 13 |
# CONFIG
|
| 14 |
# ----------------------
|
| 15 |
+
MODEL_PATH = "Isotr0py/LLaVA-Video-7B-Qwen2-hf"
|
| 16 |
DEVICE_MAP = "cuda:0"
|
| 17 |
|
| 18 |
VIDEO_DIR = str(Path(__file__).parent / "videos")
|
|
|
|
| 130 |
|
| 131 |
fps_slider = gr.Slider(
|
| 132 |
minimum=0.5,
|
| 133 |
+
maximum=10.0,
|
| 134 |
step=0.5,
|
| 135 |
value=FPS,
|
| 136 |
label="ποΈ Frames Per Second (FPS)",
|
models/__init__.py
CHANGED
|
@@ -3,66 +3,27 @@ from packaging import version
|
|
| 3 |
import torch
|
| 4 |
from typing import Optional, Union, Dict
|
| 5 |
|
|
|
|
| 6 |
# IMP: Add required versions here
|
| 7 |
-
|
| 8 |
-
internvl_required_version = version.parse("4.45.0")
|
| 9 |
-
llava_required_version = version.parse("4.40.0")
|
| 10 |
|
| 11 |
# Conditional imports based on transformers version
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
from transformers.generation.logits_process import LogitsProcessor
|
| 15 |
-
|
| 16 |
-
# Check transformers version
|
| 17 |
-
transformers_version = version.parse(transformers.__version__)
|
| 18 |
-
|
| 19 |
-
QWEN_MODELS_AVAILABLE = False
|
| 20 |
-
INTERNVL_MODELS_AVAILABLE = False
|
| 21 |
-
LLAVA_MODELS_AVAILABLE = False
|
| 22 |
-
|
| 23 |
-
# Qwen condition
|
| 24 |
-
if transformers_version >= qwen_required_version:
|
| 25 |
-
from .qwen2_5vl import Qwen2_5VLModel
|
| 26 |
-
from .qwen3vl import Qwen3VLModel
|
| 27 |
-
|
| 28 |
-
QWEN_MODELS_AVAILABLE = True
|
| 29 |
-
else:
|
| 30 |
-
print(
|
| 31 |
-
f"Warning: Qwen models require transformers>=4.57.0, but found {transformers.__version__}. Qwen models will not be available. Please upgrade to transformers>=4.57.0 or switch conda environments to use Qwen models."
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
# InternVL condition
|
| 35 |
-
if transformers_version >= internvl_required_version:
|
| 36 |
-
from .internvl import InternVLModel
|
| 37 |
-
|
| 38 |
-
INTERNVL_MODELS_AVAILABLE = True
|
| 39 |
-
else:
|
| 40 |
-
print(
|
| 41 |
-
f"Warning: InternVL models require transformers>=4.45.0, but found {transformers.__version__}. InternVL models will not be available. Please downgrade to transformers<=4.45.0 or switch conda environments to use InternVL models."
|
| 42 |
-
)
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
from .llava_video import LLaVAVideoModel
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
print(
|
| 55 |
-
"Warning: Could not check transformers version. Please re-check transformers installation."
|
| 56 |
-
)
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
__all__.extend(["Qwen2_5VLModel", "Qwen3VLModel"])
|
| 62 |
-
if INTERNVL_MODELS_AVAILABLE:
|
| 63 |
-
__all__.append("InternVLModel")
|
| 64 |
-
if LLAVA_MODELS_AVAILABLE:
|
| 65 |
-
__all__.append("LLaVAVideoModel")
|
| 66 |
|
| 67 |
|
| 68 |
# Function to get the model by mapping model ID to the correct model class
|
|
@@ -71,31 +32,27 @@ def load_model(
|
|
| 71 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 72 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 73 |
attn_implementation: Optional[str] = "flash_attention_2",
|
|
|
|
|
|
|
| 74 |
) -> BaseVideoModel:
|
| 75 |
if "LLaVA-Video" in model_path:
|
| 76 |
-
if not LLAVA_MODELS_AVAILABLE:
|
| 77 |
-
raise ImportError(
|
| 78 |
-
f"LLaVA models require transformers<=4.40.0."
|
| 79 |
-
f"Please downgrade transformers: pip install transformers<=4.40.0"
|
| 80 |
-
)
|
| 81 |
return LLaVAVideoModel(
|
| 82 |
model_path,
|
| 83 |
dtype=dtype,
|
| 84 |
device_map=device_map,
|
| 85 |
attn_implementation=attn_implementation,
|
|
|
|
|
|
|
| 86 |
)
|
| 87 |
elif "Qwen" in model_path:
|
| 88 |
-
if not QWEN_MODELS_AVAILABLE:
|
| 89 |
-
raise ImportError(
|
| 90 |
-
f"Qwen models require transformers>=4.57.0."
|
| 91 |
-
f"Please upgrade transformers: pip install transformers>=4.57.0"
|
| 92 |
-
)
|
| 93 |
if "Qwen3" in model_path:
|
| 94 |
return Qwen3VLModel(
|
| 95 |
model_path,
|
| 96 |
dtype=dtype,
|
| 97 |
device_map=device_map,
|
| 98 |
attn_implementation=attn_implementation,
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
else:
|
| 101 |
return Qwen2_5VLModel(
|
|
@@ -103,39 +60,15 @@ def load_model(
|
|
| 103 |
dtype=dtype,
|
| 104 |
device_map=device_map,
|
| 105 |
attn_implementation=attn_implementation,
|
|
|
|
|
|
|
| 106 |
)
|
| 107 |
elif "Intern" in model_path:
|
| 108 |
-
if not INTERNVL_MODELS_AVAILABLE:
|
| 109 |
-
raise ImportError(
|
| 110 |
-
f"InternVL models require transformers>=4.45.0."
|
| 111 |
-
f"Please upgrade transformers: pip install transformers>=4.45.0"
|
| 112 |
-
)
|
| 113 |
return InternVLModel(
|
| 114 |
model_path,
|
| 115 |
dtype=dtype,
|
| 116 |
device_map=device_map,
|
| 117 |
attn_implementation=attn_implementation,
|
|
|
|
|
|
|
| 118 |
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
class LogitsCaptureProcessor(LogitsProcessor):
|
| 122 |
-
"""
|
| 123 |
-
Custom LogitsProcessor that captures the processed logits right before sampling.
|
| 124 |
-
This allows us to see what the actual distribution looks like after all other
|
| 125 |
-
processors have been applied.
|
| 126 |
-
"""
|
| 127 |
-
|
| 128 |
-
def __init__(self):
|
| 129 |
-
self.captured_logits = []
|
| 130 |
-
|
| 131 |
-
def __call__(
|
| 132 |
-
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
| 133 |
-
) -> torch.FloatTensor:
|
| 134 |
-
# Store a copy of the logits at this point in generation
|
| 135 |
-
self.captured_logits.append(scores.detach().clone().cpu())
|
| 136 |
-
# Return scores unchanged - we're just observing
|
| 137 |
-
return scores
|
| 138 |
-
|
| 139 |
-
def reset(self):
|
| 140 |
-
"""Clear captured logits for a new generation"""
|
| 141 |
-
self.captured_logits = []
|
|
|
|
| 3 |
import torch
|
| 4 |
from typing import Optional, Union, Dict
|
| 5 |
|
| 6 |
+
|
| 7 |
# IMP: Add required versions here
|
| 8 |
+
transformers_required_version = version.parse("5.0.0")
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# Conditional imports based on transformers version
|
| 11 |
+
import transformers
|
| 12 |
+
from transformers import BitsAndBytesConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
# Check transformers version
|
| 15 |
+
transformers_version = version.parse(transformers.__version__)
|
|
|
|
| 16 |
|
| 17 |
+
# transformers v5 condition
|
| 18 |
+
if transformers_version >= transformers_required_version:
|
| 19 |
+
from .qwen2_5vl import Qwen2_5VLModel
|
| 20 |
+
from .qwen3vl import Qwen3VLModel
|
| 21 |
+
from .internvl import InternVLModel
|
| 22 |
+
from .llava_video import LLaVAVideoModel
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
TRANSFORMERS_MODELS_AVAILABLE = True
|
| 25 |
+
else:
|
| 26 |
+
raise ValueError(f"Transformers v5 models require transformers>=5.0.0, but found {transformers.__version__}. Transformers v5 models will not be available. Please upgrade to transformers>=5.0.0 or switch conda environments to use Transformers v5 models.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
# Function to get the model by mapping model ID to the correct model class
|
|
|
|
| 32 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 33 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 34 |
attn_implementation: Optional[str] = "flash_attention_2",
|
| 35 |
+
load_8bit: Optional[bool] = False,
|
| 36 |
+
load_4bit: Optional[bool] = False,
|
| 37 |
) -> BaseVideoModel:
|
| 38 |
if "LLaVA-Video" in model_path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
return LLaVAVideoModel(
|
| 40 |
model_path,
|
| 41 |
dtype=dtype,
|
| 42 |
device_map=device_map,
|
| 43 |
attn_implementation=attn_implementation,
|
| 44 |
+
load_8bit=load_8bit,
|
| 45 |
+
load_4bit=load_4bit,
|
| 46 |
)
|
| 47 |
elif "Qwen" in model_path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
if "Qwen3" in model_path:
|
| 49 |
return Qwen3VLModel(
|
| 50 |
model_path,
|
| 51 |
dtype=dtype,
|
| 52 |
device_map=device_map,
|
| 53 |
attn_implementation=attn_implementation,
|
| 54 |
+
load_8bit=load_8bit,
|
| 55 |
+
load_4bit=load_4bit,
|
| 56 |
)
|
| 57 |
else:
|
| 58 |
return Qwen2_5VLModel(
|
|
|
|
| 60 |
dtype=dtype,
|
| 61 |
device_map=device_map,
|
| 62 |
attn_implementation=attn_implementation,
|
| 63 |
+
load_8bit=load_8bit,
|
| 64 |
+
load_4bit=load_4bit,
|
| 65 |
)
|
| 66 |
elif "Intern" in model_path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
return InternVLModel(
|
| 68 |
model_path,
|
| 69 |
dtype=dtype,
|
| 70 |
device_map=device_map,
|
| 71 |
attn_implementation=attn_implementation,
|
| 72 |
+
load_8bit=load_8bit,
|
| 73 |
+
load_4bit=load_4bit,
|
| 74 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/base.py
CHANGED
|
@@ -17,11 +17,11 @@ class BaseVideoModel(ABC):
|
|
| 17 |
) -> str:
|
| 18 |
pass
|
| 19 |
|
| 20 |
-
@abstractmethod
|
| 21 |
-
def chat_with_confidence(
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
) -> Dict[str, Union[str, float]]:
|
| 27 |
-
|
|
|
|
| 17 |
) -> str:
|
| 18 |
pass
|
| 19 |
|
| 20 |
+
# @abstractmethod
|
| 21 |
+
# def chat_with_confidence(
|
| 22 |
+
# self,
|
| 23 |
+
# prompt: str,
|
| 24 |
+
# video_path: str,
|
| 25 |
+
# generation_config: Optional[Dict[str, Any]] = None,
|
| 26 |
+
# ) -> Dict[str, Union[str, float]]:
|
| 27 |
+
# pass
|
models/llava_video.py
CHANGED
|
@@ -1,26 +1,11 @@
|
|
| 1 |
# Run with `conda activate llava`
|
| 2 |
-
|
| 3 |
-
from llava.mm_utils import (
|
| 4 |
-
get_model_name_from_path,
|
| 5 |
-
process_images,
|
| 6 |
-
tokenizer_image_token,
|
| 7 |
-
)
|
| 8 |
-
from llava.constants import (
|
| 9 |
-
IMAGE_TOKEN_INDEX,
|
| 10 |
-
DEFAULT_IMAGE_TOKEN,
|
| 11 |
-
DEFAULT_IM_START_TOKEN,
|
| 12 |
-
DEFAULT_IM_END_TOKEN,
|
| 13 |
-
IGNORE_INDEX,
|
| 14 |
-
)
|
| 15 |
-
from llava.conversation import conv_templates, SeparatorStyle
|
| 16 |
-
from PIL import Image
|
| 17 |
-
import requests
|
| 18 |
import copy
|
| 19 |
import torch
|
| 20 |
-
from typing import Optional, Union, Dict, List, Tuple, Any
|
| 21 |
-
import warnings
|
| 22 |
-
from decord import VideoReader, cpu
|
| 23 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Handle both relative and absolute imports
|
| 26 |
try:
|
|
@@ -30,46 +15,37 @@ except ImportError:
|
|
| 30 |
|
| 31 |
warnings.filterwarnings("ignore")
|
| 32 |
|
| 33 |
-
|
| 34 |
class LLaVAVideoModel(BaseVideoModel):
|
| 35 |
def __init__(
|
| 36 |
self,
|
| 37 |
-
model_name: str = "
|
| 38 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 39 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 40 |
attn_implementation: Optional[str] = "flash_attention_2",
|
|
|
|
|
|
|
| 41 |
):
|
| 42 |
super().__init__(model_name)
|
| 43 |
-
base_model = "llava_qwen"
|
| 44 |
self.dtype = dtype
|
| 45 |
-
# Convert torch dtype to string for safety, since LLaVA-Video only accepts torch_dtype as a string
|
| 46 |
-
if dtype == torch.bfloat16:
|
| 47 |
-
torch_dtype = "bfloat16"
|
| 48 |
-
elif dtype == torch.float16:
|
| 49 |
-
torch_dtype = "float16"
|
| 50 |
-
|
| 51 |
-
self.tokenizer, self.model, self.image_processor, max_length = (
|
| 52 |
-
load_pretrained_model(
|
| 53 |
-
model_name,
|
| 54 |
-
None,
|
| 55 |
-
base_model,
|
| 56 |
-
torch_dtype=torch_dtype,
|
| 57 |
-
device_map=device_map,
|
| 58 |
-
)
|
| 59 |
-
) # Add any other thing you want to pass in llava_model_args
|
| 60 |
-
self.model.eval()
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
if
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
if hasattr(self.model, "get_model"):
|
| 70 |
-
model_inner = self.model.get_model()
|
| 71 |
-
if hasattr(model_inner, "mm_projector"):
|
| 72 |
-
model_inner.mm_projector.to(self.model.device)
|
| 73 |
|
| 74 |
def load_video(
|
| 75 |
self,
|
|
@@ -101,224 +77,227 @@ class LLaVAVideoModel(BaseVideoModel):
|
|
| 101 |
self,
|
| 102 |
prompt: str,
|
| 103 |
video_path: str,
|
| 104 |
-
fps: float = 1.0,
|
| 105 |
max_new_tokens: int = 512,
|
| 106 |
do_sample: Optional[
|
| 107 |
bool
|
| 108 |
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 109 |
temperature: float = 0.7,
|
| 110 |
video_mode: Optional[str] = "video",
|
| 111 |
-
|
|
|
|
| 112 |
**kwargs: Any,
|
| 113 |
) -> str:
|
|
|
|
| 114 |
if video_mode == "frames":
|
| 115 |
-
|
| 116 |
elif video_mode == "video":
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
input_ids = (
|
| 131 |
-
tokenizer_image_token(
|
| 132 |
-
prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 133 |
-
)
|
| 134 |
-
.unsqueeze(0)
|
| 135 |
-
.to(self.model.device)
|
| 136 |
-
)
|
| 137 |
-
cont = self.model.generate(
|
| 138 |
-
input_ids,
|
| 139 |
-
images=video,
|
| 140 |
-
modalities=["video"],
|
| 141 |
-
do_sample=do_sample,
|
| 142 |
-
temperature=temperature,
|
| 143 |
-
max_new_tokens=max_new_tokens,
|
| 144 |
-
**kwargs,
|
| 145 |
-
)
|
| 146 |
-
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[
|
| 147 |
-
0
|
| 148 |
-
].strip()
|
| 149 |
-
return text_outputs
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 162 |
-
logits_temperature: Optional[float] = 1.0,
|
| 163 |
-
return_confidence: Optional[bool] = False,
|
| 164 |
-
top_k_tokens: Optional[int] = 10,
|
| 165 |
-
debug: Optional[bool] = False,
|
| 166 |
-
) -> Dict[str, Any]:
|
| 167 |
-
video, _, _ = self.load_video(video_path, fps)
|
| 168 |
-
video = self.image_processor.preprocess(video, return_tensors="pt")[
|
| 169 |
-
"pixel_values"
|
| 170 |
-
].to(device=self.model.device, dtype=self.dtype)
|
| 171 |
-
video = [video]
|
| 172 |
-
conv_template = (
|
| 173 |
-
"qwen_1_5" # Make sure you use correct chat template for different models
|
| 174 |
-
)
|
| 175 |
-
question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
|
| 176 |
-
conv = copy.deepcopy(conv_templates[conv_template])
|
| 177 |
-
conv.append_message(conv.roles[0], question)
|
| 178 |
-
conv.append_message(conv.roles[1], None)
|
| 179 |
-
prompt_question = conv.get_prompt()
|
| 180 |
-
input_ids = (
|
| 181 |
-
tokenizer_image_token(
|
| 182 |
-
prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 183 |
-
)
|
| 184 |
-
.unsqueeze(0)
|
| 185 |
-
.to(self.model.device)
|
| 186 |
-
)
|
| 187 |
with torch.no_grad():
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
modalities=["video"],
|
| 192 |
-
do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 193 |
temperature=temperature,
|
| 194 |
max_new_tokens=max_new_tokens,
|
| 195 |
-
|
| 196 |
-
return_dict_in_generate=True,
|
| 197 |
)
|
| 198 |
-
|
| 199 |
-
|
|
|
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
score_index = top_k_tokens_scores.indices[0, i].item()
|
| 223 |
-
token = self.tokenizer.decode(score_index)
|
| 224 |
-
print(f"#{i+1}th Token: {token}")
|
| 225 |
-
print(f"#{i+1}th Token index: {score_index}")
|
| 226 |
-
print(f"#{i+1}th Token score: {score}")
|
| 227 |
-
print("--------------------------------")
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
]
|
| 248 |
-
selected_token_probs.append(first_token_probs[0, token_index].item())
|
| 249 |
-
selected_token_logits.append(scores[0][0, token_index].item())
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
|
| 282 |
|
| 283 |
-
if __name__ == "__main__":
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
| 1 |
# Run with `conda activate llava`
|
| 2 |
+
import warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import copy
|
| 4 |
import torch
|
|
|
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
+
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
|
| 7 |
+
from typing import Optional, Dict, Any, Union, List
|
| 8 |
+
from decord import VideoReader, cpu
|
| 9 |
|
| 10 |
# Handle both relative and absolute imports
|
| 11 |
try:
|
|
|
|
| 15 |
|
| 16 |
warnings.filterwarnings("ignore")
|
| 17 |
|
|
|
|
| 18 |
class LLaVAVideoModel(BaseVideoModel):
|
| 19 |
def __init__(
|
| 20 |
self,
|
| 21 |
+
model_name: str = "Isotr0py/LLaVA-Video-7B-Qwen2-hf",
|
| 22 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 23 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 24 |
attn_implementation: Optional[str] = "flash_attention_2",
|
| 25 |
+
load_8bit: Optional[bool] = False,
|
| 26 |
+
load_4bit: Optional[bool] = False,
|
| 27 |
):
|
| 28 |
super().__init__(model_name)
|
|
|
|
| 29 |
self.dtype = dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
# For quantized models (8-bit or 4-bit), device_map must be "auto" or a dict, not a device string
|
| 32 |
+
quantization_config = None
|
| 33 |
+
if load_8bit or load_4bit:
|
| 34 |
+
quantization_config = BitsAndBytesConfig(
|
| 35 |
+
load_in_8bit=load_8bit,
|
| 36 |
+
load_in_4bit=load_4bit,
|
| 37 |
+
bnb_4bit_quant_type="nf4",
|
| 38 |
+
bnb_4bit_compute_dtype=torch.float16
|
| 39 |
+
)
|
| 40 |
+
self.model = AutoModelForImageTextToText.from_pretrained(
|
| 41 |
+
model_name,
|
| 42 |
+
quantization_config=quantization_config,
|
| 43 |
+
device_map=device_map,
|
| 44 |
+
attn_implementation=attn_implementation,
|
| 45 |
+
dtype=dtype,
|
| 46 |
+
)
|
| 47 |
+
self.processor = AutoProcessor.from_pretrained(model_name)
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def load_video(
|
| 51 |
self,
|
|
|
|
| 77 |
self,
|
| 78 |
prompt: str,
|
| 79 |
video_path: str,
|
|
|
|
| 80 |
max_new_tokens: int = 512,
|
| 81 |
do_sample: Optional[
|
| 82 |
bool
|
| 83 |
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 84 |
temperature: float = 0.7,
|
| 85 |
video_mode: Optional[str] = "video",
|
| 86 |
+
fps: Optional[float] = 1.0,
|
| 87 |
+
num_frames: Optional[int] = 10,
|
| 88 |
**kwargs: Any,
|
| 89 |
) -> str:
|
| 90 |
+
# Ensure only one of fps or num_frames is provided
|
| 91 |
if video_mode == "frames":
|
| 92 |
+
fps = None
|
| 93 |
elif video_mode == "video":
|
| 94 |
+
num_frames = None
|
| 95 |
+
conversation = [
|
| 96 |
+
{
|
| 97 |
+
"role": "user",
|
| 98 |
+
"content": [
|
| 99 |
+
{
|
| 100 |
+
"type": "video",
|
| 101 |
+
"video": video_path,
|
| 102 |
+
},
|
| 103 |
+
{"type": "text", "text": prompt}
|
| 104 |
+
],
|
| 105 |
+
},
|
| 106 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
inputs = self.processor.apply_chat_template(
|
| 109 |
+
conversation,
|
| 110 |
+
add_generation_prompt=True,
|
| 111 |
+
tokenize=True,
|
| 112 |
+
return_dict=True,
|
| 113 |
+
return_tensors="pt",
|
| 114 |
+
do_sample_frames=True,
|
| 115 |
+
fps=fps,
|
| 116 |
+
num_frames=num_frames
|
| 117 |
+
).to(self.model.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
with torch.no_grad():
|
| 119 |
+
out = self.model.generate(
|
| 120 |
+
**inputs,
|
| 121 |
+
do_sample=do_sample,
|
|
|
|
|
|
|
| 122 |
temperature=temperature,
|
| 123 |
max_new_tokens=max_new_tokens,
|
| 124 |
+
**kwargs,
|
|
|
|
| 125 |
)
|
| 126 |
+
raw_response = self.processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
| 127 |
+
response = raw_response.split("assistant")[1].strip()
|
| 128 |
+
return response
|
| 129 |
|
| 130 |
+
# def chat_with_confidence(
|
| 131 |
+
# self,
|
| 132 |
+
# prompt: str,
|
| 133 |
+
# video_path: str,
|
| 134 |
+
# fps: float = 1.0,
|
| 135 |
+
# max_new_tokens: int = 512,
|
| 136 |
+
# temperature: float = 0.7,
|
| 137 |
+
# do_sample: Optional[
|
| 138 |
+
# bool
|
| 139 |
+
# ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 140 |
+
# token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 141 |
+
# logits_temperature: Optional[float] = 1.0,
|
| 142 |
+
# return_confidence: Optional[bool] = False,
|
| 143 |
+
# top_k_tokens: Optional[int] = 10,
|
| 144 |
+
# debug: Optional[bool] = False,
|
| 145 |
+
# ) -> Dict[str, Any]:
|
| 146 |
+
# video, _, _ = self.load_video(video_path, fps)
|
| 147 |
+
# video = self.image_processor.preprocess(video, return_tensors="pt")[
|
| 148 |
+
# "pixel_values"
|
| 149 |
+
# ].to(device=self.model.device, dtype=self.dtype)
|
| 150 |
+
# video = [video]
|
| 151 |
+
# conv_template = (
|
| 152 |
+
# "qwen_1_5" # Make sure you use correct chat template for different models
|
| 153 |
+
# )
|
| 154 |
+
# question = DEFAULT_IMAGE_TOKEN + f"\n{prompt}"
|
| 155 |
+
# conv = copy.deepcopy(conv_templates[conv_template])
|
| 156 |
+
# conv.append_message(conv.roles[0], question)
|
| 157 |
+
# conv.append_message(conv.roles[1], None)
|
| 158 |
+
# prompt_question = conv.get_prompt()
|
| 159 |
+
# input_ids = (
|
| 160 |
+
# tokenizer_image_token(
|
| 161 |
+
# prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 162 |
+
# )
|
| 163 |
+
# .unsqueeze(0)
|
| 164 |
+
# .to(self.model.device)
|
| 165 |
+
# )
|
| 166 |
+
# with torch.no_grad():
|
| 167 |
+
# outputs = self.model.generate(
|
| 168 |
+
# input_ids,
|
| 169 |
+
# images=video,
|
| 170 |
+
# modalities=["video"],
|
| 171 |
+
# do_sample=do_sample, # Was set to False, i.e., greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 172 |
+
# temperature=temperature,
|
| 173 |
+
# max_new_tokens=max_new_tokens,
|
| 174 |
+
# output_scores=True,
|
| 175 |
+
# return_dict_in_generate=True,
|
| 176 |
+
# )
|
| 177 |
+
# generated_ids = outputs.sequences
|
| 178 |
+
# scores = outputs.scores # Tuple of tensors, one per generated token
|
| 179 |
|
| 180 |
+
# print(f"Number of generated tokens: {len(scores)}")
|
| 181 |
+
# print(f"Vocabulary size: {scores[0].shape[1]}")
|
| 182 |
+
# # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
|
| 183 |
+
# if debug:
|
| 184 |
+
# print("****Running inference in debug mode****")
|
| 185 |
+
# # Print first token scores shape and max/min scores in debug mode
|
| 186 |
+
# print(f"Single token scores shape: {scores[0].shape}")
|
| 187 |
+
# print(
|
| 188 |
+
# f"Max score: {scores[0].max().item():.4f} | Min score: {scores[0].min().item():.4f}"
|
| 189 |
+
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
+
# # Print details about top 10 tokens based on logits
|
| 192 |
+
# logits_type = "POST-PROCESSED" if do_sample is True else "RAW"
|
| 193 |
+
# print(f"\n{'β'*80}")
|
| 194 |
+
# print(
|
| 195 |
+
# f"TOP {top_k_tokens} TOKENS FROM {logits_type} LOGITS (outputs.scores):"
|
| 196 |
+
# )
|
| 197 |
+
# print(f"{'β'*80}")
|
| 198 |
+
# top_k_tokens_scores = torch.topk(scores[0], k=top_k_tokens, dim=-1)
|
| 199 |
+
# for i in range(top_k_tokens):
|
| 200 |
+
# score = top_k_tokens_scores.values[0, i].item()
|
| 201 |
+
# score_index = top_k_tokens_scores.indices[0, i].item()
|
| 202 |
+
# token = self.tokenizer.decode(score_index)
|
| 203 |
+
# print(f"#{i+1}th Token: {token}")
|
| 204 |
+
# print(f"#{i+1}th Token index: {score_index}")
|
| 205 |
+
# print(f"#{i+1}th Token score: {score}")
|
| 206 |
+
# print("--------------------------------")
|
| 207 |
|
| 208 |
+
# # Decode the text
|
| 209 |
+
# output_response = self.tokenizer.batch_decode(
|
| 210 |
+
# generated_ids,
|
| 211 |
+
# skip_special_tokens=True,
|
| 212 |
+
# clean_up_tokenization_spaces=False,
|
| 213 |
+
# )[0]
|
| 214 |
|
| 215 |
+
# # Convert scores to probabilities
|
| 216 |
+
# # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
|
| 217 |
+
# selected_token_probs = []
|
| 218 |
+
# selected_token_logits = []
|
| 219 |
+
# first_token_probs = torch.softmax(scores[0], dim=-1)
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
+
# # Now, find indices of tokens in token_choices and get their probabilities
|
| 222 |
+
# for token_choice in token_choices:
|
| 223 |
+
# # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
|
| 224 |
+
# token_index = self.tokenizer.encode(token_choice, add_special_tokens=False)[
|
| 225 |
+
# 0
|
| 226 |
+
# ]
|
| 227 |
+
# selected_token_probs.append(first_token_probs[0, token_index].item())
|
| 228 |
+
# selected_token_logits.append(scores[0][0, token_index].item())
|
| 229 |
+
|
| 230 |
+
# # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
|
| 231 |
+
# if return_confidence:
|
| 232 |
+
# first_token_id = generated_ids[0][
|
| 233 |
+
# 0
|
| 234 |
+
# ].item() # First token of the first sequence
|
| 235 |
+
# confidence = (
|
| 236 |
+
# first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
|
| 237 |
+
# if sum(selected_token_probs) > 0
|
| 238 |
+
# else 0.0
|
| 239 |
+
# )
|
| 240 |
+
# return {
|
| 241 |
+
# "response": output_response,
|
| 242 |
+
# "confidence": confidence,
|
| 243 |
+
# }
|
| 244 |
|
| 245 |
+
# # Return token logits
|
| 246 |
+
# else:
|
| 247 |
+
# token_logits = dict(zip(token_choices, selected_token_logits))
|
| 248 |
+
# top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
|
| 249 |
+
# top_k_tokens_list: List[Tuple[str, int, float]] = []
|
| 250 |
+
# for i in range(top_k_tokens):
|
| 251 |
+
# logit_index = top_k_logits_indices.indices[0, i].item()
|
| 252 |
+
# token = self.tokenizer.decode(logit_index)
|
| 253 |
+
# logit = top_k_logits_indices.values[0, i].item()
|
| 254 |
+
# top_k_tokens_list.append((token, logit_index, logit))
|
| 255 |
+
# return {
|
| 256 |
+
# "response": output_response,
|
| 257 |
+
# "top_k_tokens": top_k_tokens_list,
|
| 258 |
+
# "token_logits": token_logits,
|
| 259 |
+
# }
|
| 260 |
|
| 261 |
|
| 262 |
+
# if __name__ == "__main__":
|
| 263 |
+
# model_path = "lmms-lab/LLaVA-Video-7B-Qwen2" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 264 |
+
# device_map = "cuda:0"
|
| 265 |
+
# model = LLaVAVideoModel(model_path, device_map=device_map)
|
| 266 |
+
# prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying but failing to attach clip to ring because it doesn\'t stick\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Hand is holding the clip\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Clip is physically separate from the ring\n- Clip is not attached to the ring\n\nAnswer:'
|
| 267 |
+
# token_choices = ["Yes", "No"]
|
| 268 |
+
# video_path = (
|
| 269 |
+
# "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/101917.mp4"
|
| 270 |
+
# )
|
| 271 |
|
| 272 |
+
# generation_config = {
|
| 273 |
+
# "max_new_tokens": 128,
|
| 274 |
+
# "do_sample": False, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
|
| 275 |
+
# "temperature": 0.7,
|
| 276 |
+
# "logits_temperature": 1.0,
|
| 277 |
+
# "fps": 1.0,
|
| 278 |
+
# "return_confidence": False,
|
| 279 |
+
# "top_k_tokens": 10,
|
| 280 |
+
# "debug": False,
|
| 281 |
+
# }
|
| 282 |
+
# output = model.chat_with_confidence(
|
| 283 |
+
# prompt, video_path, token_choices=token_choices, **generation_config
|
| 284 |
+
# )
|
| 285 |
+
# response = output["response"]
|
| 286 |
+
# print(f"Response: {response}")
|
| 287 |
|
| 288 |
+
# if generation_config["return_confidence"]:
|
| 289 |
+
# confidence = output["confidence"]
|
| 290 |
+
# print(f"Confidence: {confidence}")
|
| 291 |
+
# else:
|
| 292 |
+
# # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
|
| 293 |
+
# # otherwise, the raw logits are used, which are not filtered.
|
| 294 |
+
# logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
|
| 295 |
+
# print(f"\n{'β'*80}")
|
| 296 |
+
# print(f"TOP 10 TOKENS FROM {logits_type} LOGITS (outputs.scores):")
|
| 297 |
+
# print(f"{'β'*80}")
|
| 298 |
+
# top_k_tokens = output["top_k_tokens"]
|
| 299 |
+
# for i in range(len(top_k_tokens)):
|
| 300 |
+
# print(f"Top {i+1} token: {top_k_tokens[i][0]}")
|
| 301 |
+
# print(f"Top {i+1} token index: {top_k_tokens[i][1]}")
|
| 302 |
+
# print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
|
| 303 |
+
# print("--------------------------------")
|
models/qwen2_5vl.py
CHANGED
|
@@ -2,11 +2,12 @@
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from transformers import (
|
| 5 |
-
|
| 6 |
AutoProcessor,
|
|
|
|
| 7 |
)
|
| 8 |
from typing import Optional, Dict, Any, Union, List
|
| 9 |
-
from qwen_vl_utils import process_vision_info
|
| 10 |
|
| 11 |
# Handle both relative and absolute imports
|
| 12 |
try:
|
|
@@ -22,10 +23,22 @@ class Qwen2_5VLModel(BaseVideoModel):
|
|
| 22 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 23 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 24 |
attn_implementation: Optional[str] = "flash_attention_2",
|
|
|
|
|
|
|
| 25 |
):
|
| 26 |
super().__init__(model_name)
|
| 27 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
model_name,
|
|
|
|
| 29 |
dtype=dtype,
|
| 30 |
device_map=device_map,
|
| 31 |
attn_implementation=attn_implementation,
|
|
@@ -36,257 +49,254 @@ class Qwen2_5VLModel(BaseVideoModel):
|
|
| 36 |
self,
|
| 37 |
prompt: str,
|
| 38 |
video_path: str,
|
| 39 |
-
fps: float = 1.0,
|
| 40 |
temperature: float = 0.7,
|
| 41 |
max_new_tokens: int = 512,
|
| 42 |
do_sample: Optional[bool] = True,
|
|
|
|
|
|
|
|
|
|
| 43 |
**kwargs: Any,
|
| 44 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# Messages containing a local video path and a text query
|
| 46 |
-
|
| 47 |
{
|
| 48 |
"role": "user",
|
| 49 |
"content": [
|
| 50 |
{
|
| 51 |
-
"type": "video",
|
| 52 |
"video": video_path,
|
| 53 |
-
# "max_pixels": 360 * 420,
|
| 54 |
-
"fps": fps,
|
| 55 |
},
|
| 56 |
-
{"type": "text", "text": prompt}
|
| 57 |
],
|
| 58 |
-
}
|
| 59 |
]
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
)
|
| 67 |
-
inputs = self.processor(
|
| 68 |
-
text=[text],
|
| 69 |
-
images=image_inputs,
|
| 70 |
-
videos=video_inputs,
|
| 71 |
-
padding=True,
|
| 72 |
return_tensors="pt",
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
debug: Optional[bool] = False,
|
| 107 |
-
) -> Dict[str, Any]:
|
| 108 |
-
"""
|
| 109 |
-
Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
video_path (str): The path to the video file.
|
| 114 |
-
fps (float, optional): The frames per second of the video. Defaults to 1.0.
|
| 115 |
-
max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 128.
|
| 116 |
-
temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
|
| 117 |
-
logits_temperature (float, optional): The logits temperature to use for generation. Defaults to 1.0.
|
| 118 |
-
token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
|
| 119 |
-
return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
|
| 120 |
-
debug (bool, optional): Whether to run in debug mode. Defaults to False.
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
# "max_pixels": 360 * 420,
|
| 151 |
-
"fps": fps,
|
| 152 |
-
},
|
| 153 |
-
{"type": "text", "text": prompt},
|
| 154 |
-
],
|
| 155 |
-
}
|
| 156 |
-
]
|
| 157 |
-
|
| 158 |
-
text = self.processor.apply_chat_template(
|
| 159 |
-
messages, tokenize=False, add_generation_prompt=True
|
| 160 |
-
)
|
| 161 |
-
image_inputs, video_inputs, video_kwargs = process_vision_info(
|
| 162 |
-
messages, return_video_kwargs=True
|
| 163 |
-
)
|
| 164 |
-
inputs = self.processor(
|
| 165 |
-
text=[text],
|
| 166 |
-
images=image_inputs,
|
| 167 |
-
videos=video_inputs,
|
| 168 |
-
padding=True,
|
| 169 |
-
return_tensors="pt",
|
| 170 |
-
**video_kwargs,
|
| 171 |
-
)
|
| 172 |
-
inputs = inputs.to(self.model.device)
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
|
| 261 |
|
| 262 |
-
if __name__ == "__main__":
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from transformers import (
|
| 5 |
+
AutoModelForImageTextToText,
|
| 6 |
AutoProcessor,
|
| 7 |
+
BitsAndBytesConfig,
|
| 8 |
)
|
| 9 |
from typing import Optional, Dict, Any, Union, List
|
| 10 |
+
# from qwen_vl_utils import process_vision_info
|
| 11 |
|
| 12 |
# Handle both relative and absolute imports
|
| 13 |
try:
|
|
|
|
| 23 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 24 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 25 |
attn_implementation: Optional[str] = "flash_attention_2",
|
| 26 |
+
load_8bit: Optional[bool] = False,
|
| 27 |
+
load_4bit: Optional[bool] = False,
|
| 28 |
):
|
| 29 |
super().__init__(model_name)
|
| 30 |
+
self.dtype = dtype
|
| 31 |
+
quantization_config = None
|
| 32 |
+
if load_8bit or load_4bit:
|
| 33 |
+
quantization_config = BitsAndBytesConfig(
|
| 34 |
+
load_in_8bit=load_8bit,
|
| 35 |
+
load_in_4bit=load_4bit,
|
| 36 |
+
bnb_4bit_quant_type="nf4",
|
| 37 |
+
bnb_4bit_compute_dtype=torch.float16
|
| 38 |
+
)
|
| 39 |
+
self.model = AutoModelForImageTextToText.from_pretrained(
|
| 40 |
model_name,
|
| 41 |
+
quantization_config=quantization_config,
|
| 42 |
dtype=dtype,
|
| 43 |
device_map=device_map,
|
| 44 |
attn_implementation=attn_implementation,
|
|
|
|
| 49 |
self,
|
| 50 |
prompt: str,
|
| 51 |
video_path: str,
|
|
|
|
| 52 |
temperature: float = 0.7,
|
| 53 |
max_new_tokens: int = 512,
|
| 54 |
do_sample: Optional[bool] = True,
|
| 55 |
+
fps: Optional[float] = 1.0,
|
| 56 |
+
num_frames: Optional[int] = 10,
|
| 57 |
+
video_mode: Optional[str] = "video",
|
| 58 |
**kwargs: Any,
|
| 59 |
) -> str:
|
| 60 |
+
# Ensure only one of fps or num_frames is provided
|
| 61 |
+
if video_mode == "frames":
|
| 62 |
+
fps = None
|
| 63 |
+
elif video_mode == "video":
|
| 64 |
+
num_frames = None
|
| 65 |
# Messages containing a local video path and a text query
|
| 66 |
+
conversation = [
|
| 67 |
{
|
| 68 |
"role": "user",
|
| 69 |
"content": [
|
| 70 |
{
|
| 71 |
+
"type": "video",
|
| 72 |
"video": video_path,
|
|
|
|
|
|
|
| 73 |
},
|
| 74 |
+
{"type": "text", "text": prompt}
|
| 75 |
],
|
| 76 |
+
},
|
| 77 |
]
|
| 78 |
|
| 79 |
+
inputs = self.processor.apply_chat_template(
|
| 80 |
+
conversation,
|
| 81 |
+
add_generation_prompt=True,
|
| 82 |
+
tokenize=True,
|
| 83 |
+
return_dict=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
return_tensors="pt",
|
| 85 |
+
do_sample_frames=True,
|
| 86 |
+
fps=fps,
|
| 87 |
+
num_frames=num_frames
|
| 88 |
+
).to(self.model.device)
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
out = self.model.generate(
|
| 91 |
+
**inputs,
|
| 92 |
+
do_sample=do_sample,
|
| 93 |
+
temperature=temperature,
|
| 94 |
+
max_new_tokens=max_new_tokens,
|
| 95 |
+
**kwargs,
|
| 96 |
+
)
|
| 97 |
+
raw_response = self.processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
| 98 |
+
response = raw_response.split("assistant")[1].strip()
|
| 99 |
+
return response
|
| 100 |
+
|
| 101 |
|
| 102 |
+
# def chat_with_confidence(
|
| 103 |
+
# self,
|
| 104 |
+
# prompt: str,
|
| 105 |
+
# video_path: str,
|
| 106 |
+
# fps: Optional[float] = 1.0,
|
| 107 |
+
# num_frames: Optional[int] = 10,
|
| 108 |
+
# max_new_tokens: int = 512,
|
| 109 |
+
# temperature: float = 0.7,
|
| 110 |
+
# do_sample: Optional[bool] = True,
|
| 111 |
+
# video_mode: Optional[str] = "video",
|
| 112 |
+
# token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 113 |
+
# logits_temperature: Optional[float] = 1.0,
|
| 114 |
+
# return_confidence: Optional[bool] = False,
|
| 115 |
+
# debug: Optional[bool] = False,
|
| 116 |
+
# **kwargs: Any,
|
| 117 |
+
# ) -> Dict[str, Any]:
|
| 118 |
+
# """
|
| 119 |
+
# Returns the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
|
| 120 |
|
| 121 |
+
# Args:
|
| 122 |
+
# prompt (str): The text prompt to generate a response for.
|
| 123 |
+
# video_path (str): The path to the video file.
|
| 124 |
+
# fps (float, optional): The frames per second of the video. Defaults to 1.0.
|
| 125 |
+
# max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 128.
|
| 126 |
+
# temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
|
| 127 |
+
# logits_temperature (float, optional): The logits temperature to use for generation. Defaults to 1.0.
|
| 128 |
+
# token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
|
| 129 |
+
# return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
|
| 130 |
+
# debug (bool, optional): Whether to run in debug mode. Defaults to False.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
# Returns:
|
| 133 |
+
# Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the token logits for token_choices.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
# e.g., return_confidence: False
|
| 136 |
+
# Output:
|
| 137 |
+
# {
|
| 138 |
+
# "response": "Yes",
|
| 139 |
+
# "logits": {
|
| 140 |
+
# "Yes": 12.0,
|
| 141 |
+
# "No": 9.0
|
| 142 |
+
# }
|
| 143 |
+
# }
|
| 144 |
|
| 145 |
+
# e.g., return_confidence: True
|
| 146 |
+
# Output:
|
| 147 |
+
# {
|
| 148 |
+
# "response": "Yes",
|
| 149 |
+
# "confidence": 0.9999
|
| 150 |
+
# }
|
| 151 |
+
# """
|
| 152 |
+
# # Messages containing a local video path and a text query
|
| 153 |
+
# messages = [
|
| 154 |
+
# {
|
| 155 |
+
# "role": "user",
|
| 156 |
+
# "content": [
|
| 157 |
+
# {
|
| 158 |
+
# "type": "video",
|
| 159 |
+
# "video": video_path,
|
| 160 |
+
# # "max_pixels": 360 * 420,
|
| 161 |
+
# "fps": fps,
|
| 162 |
+
# },
|
| 163 |
+
# {"type": "text", "text": prompt},
|
| 164 |
+
# ],
|
| 165 |
+
# }
|
| 166 |
+
# ]
|
| 167 |
|
| 168 |
+
# text = self.processor.apply_chat_template(
|
| 169 |
+
# messages, tokenize=False, add_generation_prompt=True
|
| 170 |
+
# )
|
| 171 |
+
# image_inputs, video_inputs, video_kwargs = process_vision_info(
|
| 172 |
+
# messages, return_video_kwargs=True
|
| 173 |
+
# )
|
| 174 |
+
# inputs = self.processor(
|
| 175 |
+
# text=[text],
|
| 176 |
+
# images=image_inputs,
|
| 177 |
+
# videos=video_inputs,
|
| 178 |
+
# padding=True,
|
| 179 |
+
# return_tensors="pt",
|
| 180 |
+
# **video_kwargs,
|
| 181 |
+
# )
|
| 182 |
+
# inputs = inputs.to(self.model.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
+
# # Inference with scores
|
| 185 |
+
# with torch.no_grad():
|
| 186 |
+
# outputs = self.model.generate(
|
| 187 |
+
# **inputs,
|
| 188 |
+
# temperature=temperature,
|
| 189 |
+
# max_new_tokens=max_new_tokens,
|
| 190 |
+
# output_scores=True,
|
| 191 |
+
# return_dict_in_generate=True,
|
| 192 |
+
# )
|
| 193 |
|
| 194 |
+
# generated_ids = outputs.sequences
|
| 195 |
+
# scores = outputs.scores # Tuple of tensors, one per generated token
|
| 196 |
+
# scores = tuple(
|
| 197 |
+
# s / logits_temperature for s in scores
|
| 198 |
+
# ) # Scales the logits by a factor for normalization during reporting
|
| 199 |
|
| 200 |
+
# print(f"Number of generated tokens: {len(scores)}")
|
| 201 |
+
# print(f"Vocabulary size: {scores[0].shape[1]}")
|
| 202 |
+
# # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
|
| 203 |
+
# if debug:
|
| 204 |
+
# print("****Running inference in debug mode****")
|
| 205 |
+
# # Print first token scores shape and max/min scores in debug mode
|
| 206 |
+
# print(f"Single token scores shape: {scores[0].shape}")
|
| 207 |
+
# print(
|
| 208 |
+
# f"First token max/min scores: {scores[0].max().item()}, {scores[0].min().item()}"
|
| 209 |
+
# )
|
| 210 |
+
# # Print details about top 3 tokens
|
| 211 |
+
# top_3_tokens = torch.topk(scores[0], k=3, dim=-1)
|
| 212 |
+
# for i in range(3):
|
| 213 |
+
# print(
|
| 214 |
+
# f"Pos 0 | {i+1}th Token: {self.processor.decode(top_3_tokens.indices[0, i].item())}"
|
| 215 |
+
# )
|
| 216 |
+
# print(
|
| 217 |
+
# f"Pos 0 | {i+1}th Token logit: {top_3_tokens.values[0, i].item()}"
|
| 218 |
+
# )
|
| 219 |
|
| 220 |
+
# # Trim the prompt tokens from generated sequences
|
| 221 |
+
# generated_ids_trimmed = [
|
| 222 |
+
# out_ids[len(in_ids) :]
|
| 223 |
+
# for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 224 |
+
# ]
|
| 225 |
|
| 226 |
+
# # Decode the text
|
| 227 |
+
# output_response = self.processor.batch_decode(
|
| 228 |
+
# generated_ids_trimmed,
|
| 229 |
+
# skip_special_tokens=True,
|
| 230 |
+
# clean_up_tokenization_spaces=False,
|
| 231 |
+
# )[0]
|
| 232 |
|
| 233 |
+
# # Convert scores to probabilities
|
| 234 |
+
# # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
|
| 235 |
+
# selected_token_probs = []
|
| 236 |
+
# selected_token_logits = []
|
| 237 |
+
# first_token_probs = torch.softmax(scores[0], dim=-1)
|
| 238 |
|
| 239 |
+
# # Now, find indices of tokens in token_choices and get their probabilities
|
| 240 |
+
# for token_choice in token_choices:
|
| 241 |
+
# # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
|
| 242 |
+
# token_index = self.processor.tokenizer.encode(
|
| 243 |
+
# token_choice, add_special_tokens=False
|
| 244 |
+
# )[0]
|
| 245 |
+
# selected_token_probs.append(first_token_probs[0, token_index].item())
|
| 246 |
+
# selected_token_logits.append(scores[0][0, token_index].item())
|
| 247 |
|
| 248 |
+
# # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
|
| 249 |
+
# if return_confidence:
|
| 250 |
+
# first_token_id = generated_ids_trimmed[0][
|
| 251 |
+
# 0
|
| 252 |
+
# ].item() # First token of the first sequence
|
| 253 |
+
# confidence = (
|
| 254 |
+
# first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
|
| 255 |
+
# if sum(selected_token_probs) > 0
|
| 256 |
+
# else 0.0
|
| 257 |
+
# )
|
| 258 |
+
# return {
|
| 259 |
+
# "response": output_response,
|
| 260 |
+
# "confidence": confidence,
|
| 261 |
+
# }
|
| 262 |
|
| 263 |
+
# # Retrn token logits
|
| 264 |
+
# else:
|
| 265 |
+
# token_logits = dict(zip(token_choices, selected_token_logits))
|
| 266 |
+
# return {
|
| 267 |
+
# "response": output_response,
|
| 268 |
+
# "logits": token_logits,
|
| 269 |
+
# }
|
| 270 |
|
| 271 |
|
| 272 |
+
# if __name__ == "__main__":
|
| 273 |
+
# model_path = "Qwen/Qwen2.5-VL-7B-Instruct" # "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 274 |
+
# model = Qwen2_5VLModel(model_path)
|
| 275 |
+
# prompt = (
|
| 276 |
+
# "Which of the following exist in the video? Answer in A or B.\nA: Hand\nB: Face"
|
| 277 |
+
# )
|
| 278 |
+
# token_choices = ["A", "B"]
|
| 279 |
+
# ext = ".webm"
|
| 280 |
+
# video_path = "/home/shreyasj/Syed/data/Something-Something-V2/videos/101917" + ext
|
| 281 |
|
| 282 |
+
# generation_config = {
|
| 283 |
+
# "max_new_tokens": 128,
|
| 284 |
+
# "temperature": 0.7,
|
| 285 |
+
# "logits_temperature": 5.0,
|
| 286 |
+
# "fps": 3.0,
|
| 287 |
+
# "return_confidence": False,
|
| 288 |
+
# "debug": True,
|
| 289 |
+
# }
|
| 290 |
+
# output = model.chat_with_confidence(
|
| 291 |
+
# prompt, video_path, token_choices=token_choices, **generation_config
|
| 292 |
+
# )
|
| 293 |
+
# response = output["response"]
|
| 294 |
+
# print(f"Response: {response}")
|
| 295 |
|
| 296 |
+
# if generation_config["return_confidence"]:
|
| 297 |
+
# confidence = output["confidence"]
|
| 298 |
+
# print(f"Confidence: {confidence}")
|
| 299 |
+
# else:
|
| 300 |
+
# selected_token_logits = output["logits"]
|
| 301 |
+
# print(f"Selected token logits: {selected_token_logits}")
|
| 302 |
+
# print(f"Logits temperature: {generation_config['logits_temperature']}")
|
models/qwen3vl.py
CHANGED
|
@@ -2,14 +2,11 @@
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from transformers import (
|
| 5 |
-
|
| 6 |
AutoProcessor,
|
|
|
|
| 7 |
)
|
| 8 |
from typing import Optional, Dict, Any, Union, List, Tuple
|
| 9 |
-
from qwen_vl_utils import process_vision_info
|
| 10 |
-
import cv2
|
| 11 |
-
import numpy as np
|
| 12 |
-
from PIL import Image
|
| 13 |
|
| 14 |
# Handle both relative and absolute imports
|
| 15 |
try:
|
|
@@ -18,50 +15,31 @@ except ImportError:
|
|
| 18 |
from base import BaseVideoModel
|
| 19 |
|
| 20 |
|
| 21 |
-
def downsample_video(video_path, max_dim=720, num_frames=10):
|
| 22 |
-
vidcap = cv2.VideoCapture(video_path)
|
| 23 |
-
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 24 |
-
frames = []
|
| 25 |
-
frame_indices = np.linspace(
|
| 26 |
-
0, total_frames - 1, min(total_frames, num_frames), dtype=int
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
for i in frame_indices:
|
| 30 |
-
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
| 31 |
-
success, image = vidcap.read()
|
| 32 |
-
if success:
|
| 33 |
-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 34 |
-
|
| 35 |
-
h, w = image.shape[:2]
|
| 36 |
-
scale = max_dim / max(h, w)
|
| 37 |
-
if scale < 1:
|
| 38 |
-
image = cv2.resize(
|
| 39 |
-
image,
|
| 40 |
-
(int(w * scale), int(h * scale)),
|
| 41 |
-
interpolation=cv2.INTER_AREA,
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
pil_image = Image.fromarray(image)
|
| 45 |
-
frames.append(pil_image)
|
| 46 |
-
|
| 47 |
-
vidcap.release()
|
| 48 |
-
return frames
|
| 49 |
-
|
| 50 |
-
|
| 51 |
class Qwen3VLModel(BaseVideoModel):
|
| 52 |
def __init__(
|
| 53 |
self,
|
| 54 |
-
model_name: str = "Qwen/Qwen3-VL-
|
| 55 |
dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
|
| 56 |
device_map: Optional[Union[str, Dict]] = "auto",
|
| 57 |
attn_implementation: Optional[str] = "flash_attention_2",
|
|
|
|
|
|
|
| 58 |
):
|
| 59 |
super().__init__(model_name)
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
model_name,
|
| 62 |
-
|
| 63 |
device_map=device_map,
|
| 64 |
attn_implementation=attn_implementation,
|
|
|
|
| 65 |
)
|
| 66 |
self.processor = AutoProcessor.from_pretrained(model_name)
|
| 67 |
|
|
@@ -69,467 +47,443 @@ class Qwen3VLModel(BaseVideoModel):
|
|
| 69 |
self,
|
| 70 |
prompt: str,
|
| 71 |
video_path: str,
|
| 72 |
-
fps: float = 1.0,
|
| 73 |
temperature: float = 0.7,
|
| 74 |
do_sample: Optional[
|
| 75 |
bool
|
| 76 |
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 77 |
max_new_tokens: int = 512,
|
| 78 |
video_mode: Optional[str] = "video", # Choose from "video" or "frames"
|
| 79 |
-
|
|
|
|
| 80 |
**kwargs: Any,
|
| 81 |
) -> str:
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
],
|
| 89 |
-
}
|
| 90 |
-
]
|
| 91 |
-
if video_mode == "video":
|
| 92 |
-
messages[0]["content"].append(
|
| 93 |
-
{
|
| 94 |
-
"type": "video",
|
| 95 |
-
"video": video_path,
|
| 96 |
-
# "max_pixels": 360 * 420,
|
| 97 |
-
"fps": fps,
|
| 98 |
-
}
|
| 99 |
-
)
|
| 100 |
-
inputs = self.processor.apply_chat_template(
|
| 101 |
-
messages,
|
| 102 |
-
tokenize=True,
|
| 103 |
-
add_generation_prompt=True,
|
| 104 |
-
return_dict=True,
|
| 105 |
-
return_tensors="pt",
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
elif video_mode == "frames":
|
| 109 |
-
frames = downsample_video(video_path, max_dim=720, num_frames=video_frames)
|
| 110 |
-
images_for_processor = []
|
| 111 |
-
for frame in frames:
|
| 112 |
-
messages[0]["content"].append({"type": "image"})
|
| 113 |
-
images_for_processor.append(frame)
|
| 114 |
-
prompt_full = self.processor.apply_chat_template(
|
| 115 |
-
messages, tokenize=False, add_generation_prompt=True
|
| 116 |
-
)
|
| 117 |
-
inputs = self.processor(
|
| 118 |
-
text=[prompt_full],
|
| 119 |
-
images=images_for_processor,
|
| 120 |
-
return_tensors="pt",
|
| 121 |
-
padding=True,
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
inputs = inputs.to(self.model.device)
|
| 125 |
-
|
| 126 |
-
generated_ids = self.model.generate(
|
| 127 |
-
**inputs,
|
| 128 |
-
max_new_tokens=max_new_tokens,
|
| 129 |
-
temperature=temperature,
|
| 130 |
-
do_sample=do_sample,
|
| 131 |
-
**kwargs,
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
generated_ids_trimmed = [
|
| 135 |
-
out_ids[len(in_ids) :]
|
| 136 |
-
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 137 |
-
]
|
| 138 |
-
|
| 139 |
-
output_response = self.processor.batch_decode(
|
| 140 |
-
generated_ids_trimmed,
|
| 141 |
-
skip_special_tokens=True,
|
| 142 |
-
clean_up_tokenization_spaces=False,
|
| 143 |
-
)[0]
|
| 144 |
-
|
| 145 |
-
return output_response
|
| 146 |
-
|
| 147 |
-
def chat_with_confidence(
|
| 148 |
-
self,
|
| 149 |
-
prompt: str,
|
| 150 |
-
video_path: str,
|
| 151 |
-
fps: float = 1.0,
|
| 152 |
-
max_new_tokens: int = 512,
|
| 153 |
-
temperature: float = 0.7,
|
| 154 |
-
do_sample: Optional[
|
| 155 |
-
bool
|
| 156 |
-
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 157 |
-
token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 158 |
-
logits_temperature: Optional[float] = 1.0,
|
| 159 |
-
return_confidence: Optional[bool] = False,
|
| 160 |
-
top_k_tokens: Optional[int] = 10,
|
| 161 |
-
debug: Optional[bool] = False,
|
| 162 |
-
**kwargs: Any,
|
| 163 |
-
) -> Dict[str, Any]:
|
| 164 |
-
"""
|
| 165 |
-
Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
|
| 166 |
-
|
| 167 |
-
Args:
|
| 168 |
-
prompt (str): The text prompt to generate a response for.
|
| 169 |
-
video_path (str): The path to the video file.
|
| 170 |
-
temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
|
| 171 |
-
max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
|
| 172 |
-
token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
|
| 173 |
-
generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
|
| 174 |
-
return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
|
| 175 |
-
top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
|
| 176 |
-
debug (bool, optional): Whether to run in debug mode. Defaults to False.
|
| 177 |
-
|
| 178 |
-
Returns:
|
| 179 |
-
Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
|
| 180 |
-
|
| 181 |
-
e.g., return_confidence: False
|
| 182 |
-
Output:
|
| 183 |
-
{
|
| 184 |
-
"response": "Yes",
|
| 185 |
-
"top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)],
|
| 186 |
-
}
|
| 187 |
-
|
| 188 |
-
e.g., return_confidence: True
|
| 189 |
-
Output:
|
| 190 |
-
{
|
| 191 |
-
"response": "Yes",
|
| 192 |
-
"confidence": 0.9999
|
| 193 |
-
}
|
| 194 |
-
"""
|
| 195 |
-
# Messages containing a local video path and a text query
|
| 196 |
-
messages = [
|
| 197 |
{
|
| 198 |
"role": "user",
|
| 199 |
"content": [
|
| 200 |
{
|
| 201 |
-
"type": "video",
|
| 202 |
"video": video_path,
|
| 203 |
-
# "max_pixels": 360 * 420,
|
| 204 |
-
"fps": fps,
|
| 205 |
},
|
| 206 |
-
{"type": "text", "text": prompt}
|
| 207 |
],
|
| 208 |
-
}
|
| 209 |
]
|
| 210 |
|
| 211 |
inputs = self.processor.apply_chat_template(
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
return_dict=True,
|
| 216 |
return_tensors="pt",
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
# In debug mode, inspect what logits processors will be used
|
| 222 |
-
if debug:
|
| 223 |
-
print("\n" + "=" * 80)
|
| 224 |
-
print("INSPECTING GENERATION CONFIG & WARPERS")
|
| 225 |
-
print("=" * 80)
|
| 226 |
-
# Get the generation config to see what processors will be added
|
| 227 |
-
gen_config = self.model.generation_config
|
| 228 |
-
print(f"Generation config attributes:")
|
| 229 |
-
print(f" Processor-related:")
|
| 230 |
-
print(
|
| 231 |
-
f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
|
| 232 |
-
)
|
| 233 |
-
print(
|
| 234 |
-
f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
|
| 235 |
-
)
|
| 236 |
-
print(
|
| 237 |
-
f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
|
| 238 |
-
)
|
| 239 |
-
print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
|
| 240 |
-
print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
|
| 241 |
-
print(
|
| 242 |
-
f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
|
| 243 |
-
)
|
| 244 |
-
print(
|
| 245 |
-
f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
|
| 246 |
-
)
|
| 247 |
-
print(f" Warper-related (THESE MASK TOKENS TO -INF):")
|
| 248 |
-
print(f" - temperature: {temperature} (passed as arg)")
|
| 249 |
-
print(
|
| 250 |
-
f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
|
| 251 |
-
)
|
| 252 |
-
print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
|
| 253 |
-
print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
|
| 254 |
-
print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
|
| 255 |
-
print(
|
| 256 |
-
f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
|
| 257 |
-
)
|
| 258 |
-
print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
|
| 259 |
-
print(
|
| 260 |
-
f"\n β οΈ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
|
| 261 |
-
)
|
| 262 |
-
print("=" * 80 + "\n")
|
| 263 |
-
|
| 264 |
-
# Inference with scores
|
| 265 |
with torch.no_grad():
|
| 266 |
-
|
| 267 |
**inputs,
|
|
|
|
| 268 |
temperature=temperature,
|
| 269 |
max_new_tokens=max_new_tokens,
|
| 270 |
-
do_sample=do_sample,
|
| 271 |
-
output_scores=True,
|
| 272 |
-
output_logits=True, # Get TRUE raw logits before any processing
|
| 273 |
-
return_dict_in_generate=True,
|
| 274 |
**kwargs,
|
| 275 |
)
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from transformers import (
|
| 5 |
+
AutoModelForImageTextToText,
|
| 6 |
AutoProcessor,
|
| 7 |
+
BitsAndBytesConfig,
|
| 8 |
)
|
| 9 |
from typing import Optional, Dict, Any, Union, List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Handle both relative and absolute imports
|
| 12 |
try:
|
|
|
|
| 15 |
from base import BaseVideoModel
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class Qwen3VLModel(BaseVideoModel):
|
| 19 |
def __init__(
    self,
    model_name: str = "Qwen/Qwen3-VL-4B-Instruct",
    dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16,
    device_map: Optional[Union[str, Dict]] = "auto",
    attn_implementation: Optional[str] = "flash_attention_2",
    load_8bit: Optional[bool] = False,
    load_4bit: Optional[bool] = False,
):
    """Load a Qwen3-VL checkpoint together with its processor.

    Args:
        model_name: Hugging Face model id to load.
        dtype: Torch dtype (or "auto") for the model weights.
        device_map: Device placement forwarded to ``from_pretrained``.
        attn_implementation: Attention backend, e.g. "flash_attention_2".
        load_8bit: If True, quantize weights to 8-bit via bitsandbytes.
        load_4bit: If True, quantize weights to 4-bit (NF4) via bitsandbytes.
    """
    super().__init__(model_name)

    # Build a quantization config only when quantization was requested;
    # leaving it None keeps the ordinary full-precision load path.
    bnb_config = None
    if load_8bit or load_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=load_8bit,
            load_in_4bit=load_4bit,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    self.model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        attn_implementation=attn_implementation,
        dtype=dtype,
    )
    self.processor = AutoProcessor.from_pretrained(model_name)
|
| 45 |
|
|
|
|
| 47 |
self,
|
| 48 |
prompt: str,
|
| 49 |
video_path: str,
|
|
|
|
| 50 |
temperature: float = 0.7,
|
| 51 |
do_sample: Optional[
|
| 52 |
bool
|
| 53 |
] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 54 |
max_new_tokens: int = 512,
|
| 55 |
video_mode: Optional[str] = "video", # Choose from "video" or "frames"
|
| 56 |
+
fps: Optional[float] = 1.0,
|
| 57 |
+
num_frames: Optional[int] = 10,
|
| 58 |
**kwargs: Any,
|
| 59 |
) -> str:
|
| 60 |
+
# Ensure only one of fps or num_frames is provided
|
| 61 |
+
if video_mode == "frames":
|
| 62 |
+
fps = None
|
| 63 |
+
elif video_mode == "video":
|
| 64 |
+
num_frames = None
|
| 65 |
+
conversation = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
{
|
| 67 |
"role": "user",
|
| 68 |
"content": [
|
| 69 |
{
|
| 70 |
+
"type": "video",
|
| 71 |
"video": video_path,
|
|
|
|
|
|
|
| 72 |
},
|
| 73 |
+
{"type": "text", "text": prompt}
|
| 74 |
],
|
| 75 |
+
},
|
| 76 |
]
|
| 77 |
|
| 78 |
inputs = self.processor.apply_chat_template(
|
| 79 |
+
conversation,
|
| 80 |
+
add_generation_prompt=True,
|
| 81 |
+
tokenize=True,
|
| 82 |
+
return_dict=True,
|
| 83 |
return_tensors="pt",
|
| 84 |
+
do_sample_frames=True,
|
| 85 |
+
fps=fps,
|
| 86 |
+
num_frames=num_frames
|
| 87 |
+
).to(self.model.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
with torch.no_grad():
|
| 89 |
+
out = self.model.generate(
|
| 90 |
**inputs,
|
| 91 |
+
do_sample=do_sample,
|
| 92 |
temperature=temperature,
|
| 93 |
max_new_tokens=max_new_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
**kwargs,
|
| 95 |
)
|
| 96 |
+
raw_response = self.processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
| 97 |
+
response = raw_response.split("assistant")[1].strip()
|
| 98 |
+
return response
|
| 99 |
+
|
| 100 |
+
# def chat_with_confidence(
|
| 101 |
+
# self,
|
| 102 |
+
# prompt: str,
|
| 103 |
+
# video_path: str,
|
| 104 |
+
# max_new_tokens: int = 512,
|
| 105 |
+
# temperature: float = 0.7,
|
| 106 |
+
# do_sample: Optional[
|
| 107 |
+
# bool
|
| 108 |
+
# ] = True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P!
|
| 109 |
+
# fps: Optional[float] = 1.0,
|
| 110 |
+
# num_frames: Optional[int] = 10,
|
| 111 |
+
# token_choices: Optional[List[str]] = ["Yes", "No"],
|
| 112 |
+
# logits_temperature: Optional[float] = 1.0,
|
| 113 |
+
# return_confidence: Optional[bool] = False,
|
| 114 |
+
# top_k_tokens: Optional[int] = 10,
|
| 115 |
+
# debug: Optional[bool] = False,
|
| 116 |
+
# **kwargs: Any,
|
| 117 |
+
# ) -> Dict[str, Any]:
|
| 118 |
+
# """
|
| 119 |
+
# Returns the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
|
| 120 |
+
|
| 121 |
+
# Args:
|
| 122 |
+
# prompt (str): The text prompt to generate a response for.
|
| 123 |
+
# video_path (str): The path to the video file.
|
| 124 |
+
# temperature (float, optional): The temperature to use for generation. Defaults to 0.7.
|
| 125 |
+
# max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 512.
|
| 126 |
+
# token_choices (List[str], optional): The list of token choices to return logits for. Defaults to ["Yes", "No"].
|
| 127 |
+
# generation_config (Dict[str, Any], optional): The generation configuration. Defaults to None.
|
| 128 |
+
# return_confidence (bool, optional): Whether to return the confidence of the response. Defaults to False.
|
| 129 |
+
# top_k_tokens (int, optional): The number of top tokens to return. Defaults to 10. Only applicable if return_confidence is False.
|
| 130 |
+
# debug (bool, optional): Whether to run in debug mode. Defaults to False.
|
| 131 |
+
|
| 132 |
+
# Returns:
|
| 133 |
+
# Dict[str, Any]: A dictionary containing the response and confidence of the response, if return_confidence is True. Else, returns the top k tokens and their logits.
|
| 134 |
+
|
| 135 |
+
# e.g., return_confidence: False
|
| 136 |
+
# Output:
|
| 137 |
+
# {
|
| 138 |
+
# "response": "Yes",
|
| 139 |
+
# "top_k_tokens": [("Yes", 12.0, 12), ("No", 9.0, 9)],
|
| 140 |
+
# }
|
| 141 |
+
|
| 142 |
+
# e.g., return_confidence: True
|
| 143 |
+
# Output:
|
| 144 |
+
# {
|
| 145 |
+
# "response": "Yes",
|
| 146 |
+
# "confidence": 0.9999
|
| 147 |
+
# }
|
| 148 |
+
# """
|
| 149 |
+
# # Messages containing a local video path and a text query
|
| 150 |
+
# messages = [
|
| 151 |
+
# {
|
| 152 |
+
# "role": "user",
|
| 153 |
+
# "content": [
|
| 154 |
+
# {
|
| 155 |
+
# "type": "video",
|
| 156 |
+
# "video": video_path,
|
| 157 |
+
# # "max_pixels": 360 * 420,
|
| 158 |
+
# "fps": fps,
|
| 159 |
+
# },
|
| 160 |
+
# {"type": "text", "text": prompt},
|
| 161 |
+
# ],
|
| 162 |
+
# }
|
| 163 |
+
# ]
|
| 164 |
+
|
| 165 |
+
# inputs = self.processor.apply_chat_template(
|
| 166 |
+
# messages,
|
| 167 |
+
# tokenize=True,
|
| 168 |
+
# add_generation_prompt=True,
|
| 169 |
+
# return_dict=True,
|
| 170 |
+
# return_tensors="pt",
|
| 171 |
+
# )
|
| 172 |
+
|
| 173 |
+
# inputs = inputs.to(self.model.device)
|
| 174 |
+
|
| 175 |
+
# # In debug mode, inspect what logits processors will be used
|
| 176 |
+
# if debug:
|
| 177 |
+
# print("\n" + "=" * 80)
|
| 178 |
+
# print("INSPECTING GENERATION CONFIG & WARPERS")
|
| 179 |
+
# print("=" * 80)
|
| 180 |
+
# # Get the generation config to see what processors will be added
|
| 181 |
+
# gen_config = self.model.generation_config
|
| 182 |
+
# print(f"Generation config attributes:")
|
| 183 |
+
# print(f" Processor-related:")
|
| 184 |
+
# print(
|
| 185 |
+
# f" - repetition_penalty: {getattr(gen_config, 'repetition_penalty', None)}"
|
| 186 |
+
# )
|
| 187 |
+
# print(
|
| 188 |
+
# f" - no_repeat_ngram_size: {getattr(gen_config, 'no_repeat_ngram_size', None)}"
|
| 189 |
+
# )
|
| 190 |
+
# print(
|
| 191 |
+
# f" - encoder_no_repeat_ngram_size: {getattr(gen_config, 'encoder_no_repeat_ngram_size', None)}"
|
| 192 |
+
# )
|
| 193 |
+
# print(f" - bad_words_ids: {getattr(gen_config, 'bad_words_ids', None)}")
|
| 194 |
+
# print(f" - min_length: {getattr(gen_config, 'min_length', None)}")
|
| 195 |
+
# print(
|
| 196 |
+
# f" - forced_bos_token_id: {getattr(gen_config, 'forced_bos_token_id', None)}"
|
| 197 |
+
# )
|
| 198 |
+
# print(
|
| 199 |
+
# f" - forced_eos_token_id: {getattr(gen_config, 'forced_eos_token_id', None)}"
|
| 200 |
+
# )
|
| 201 |
+
# print(f" Warper-related (THESE MASK TOKENS TO -INF):")
|
| 202 |
+
# print(f" - temperature: {temperature} (passed as arg)")
|
| 203 |
+
# print(
|
| 204 |
+
# f" - do_sample: {getattr(gen_config, 'do_sample', 'Not set (will be inferred)')}"
|
| 205 |
+
# )
|
| 206 |
+
# print(f" - top_k: {getattr(gen_config, 'top_k', None)}")
|
| 207 |
+
# print(f" - top_p: {getattr(gen_config, 'top_p', None)}")
|
| 208 |
+
# print(f" - typical_p: {getattr(gen_config, 'typical_p', None)}")
|
| 209 |
+
# print(
|
| 210 |
+
# f" - epsilon_cutoff: {getattr(gen_config, 'epsilon_cutoff', None)}"
|
| 211 |
+
# )
|
| 212 |
+
# print(f" - eta_cutoff: {getattr(gen_config, 'eta_cutoff', None)}")
|
| 213 |
+
# print(
|
| 214 |
+
# f"\n β οΈ If top_k or top_p are set, they will mask non-selected tokens to -inf!"
|
| 215 |
+
# )
|
| 216 |
+
# print("=" * 80 + "\n")
|
| 217 |
+
|
| 218 |
+
# # Inference with scores
|
| 219 |
+
# with torch.no_grad():
|
| 220 |
+
# outputs = self.model.generate(
|
| 221 |
+
# **inputs,
|
| 222 |
+
# temperature=temperature,
|
| 223 |
+
# max_new_tokens=max_new_tokens,
|
| 224 |
+
# do_sample=do_sample,
|
| 225 |
+
# output_scores=True,
|
| 226 |
+
# output_logits=True, # Get TRUE raw logits before any processing
|
| 227 |
+
# return_dict_in_generate=True,
|
| 228 |
+
# **kwargs,
|
| 229 |
+
# )
|
| 230 |
+
|
| 231 |
+
# generated_ids = outputs.sequences
|
| 232 |
+
# scores = outputs.scores # Tuple of tensors - PROCESSED logits used for sampling
|
| 233 |
+
# logits = (
|
| 234 |
+
# outputs.logits if hasattr(outputs, "logits") else None
|
| 235 |
+
# ) # TRUE raw logits from model
|
| 236 |
+
|
| 237 |
+
# scores = tuple(
|
| 238 |
+
# s / logits_temperature for s in scores
|
| 239 |
+
# ) # Scales the logits by a factor for normalization during reporting
|
| 240 |
+
|
| 241 |
+
# print(f"Number of generated tokens: {len(scores)}")
|
| 242 |
+
# print(f"Vocabulary size: {scores[0].shape[1]}")
|
| 243 |
+
|
| 244 |
+
# # Check if logits differ from scores
|
| 245 |
+
# if debug and logits is not None:
|
| 246 |
+
# print(f"\n[IMPORTANT] output_logits available: True")
|
| 247 |
+
# print(
|
| 248 |
+
# f"[IMPORTANT] Comparing outputs.logits (raw) vs outputs.scores (processed):"
|
| 249 |
+
# )
|
| 250 |
+
# logits_raw = logits[0] / logits_temperature # First token's raw logits
|
| 251 |
+
# scores_first = scores[0] # First token's processed scores
|
| 252 |
+
|
| 253 |
+
# logits_diff = (logits_raw.cpu() - scores_first.cpu()).abs()
|
| 254 |
+
# max_diff = logits_diff.max().item()
|
| 255 |
+
# if max_diff > 0.001:
|
| 256 |
+
# print(
|
| 257 |
+
# f"[IMPORTANT] β οΈ outputs.scores ARE DIFFERENT from outputs.logits!"
|
| 258 |
+
# )
|
| 259 |
+
# print(f"[IMPORTANT] Max difference: {max_diff:.6f}")
|
| 260 |
+
# print(
|
| 261 |
+
# f"[IMPORTANT] This means outputs.scores are PROCESSED, not raw!"
|
| 262 |
+
# )
|
| 263 |
+
# else:
|
| 264 |
+
# print(f"[IMPORTANT] β outputs.scores == outputs.logits (both are raw)")
|
| 265 |
+
# elif debug:
|
| 266 |
+
# print(
|
| 267 |
+
# f"\n[IMPORTANT] output_logits not available in this transformers version"
|
| 268 |
+
# )
|
| 269 |
+
|
| 270 |
+
# # Print top 3 tokens at 1st position (i.e., scores[0]) along with their probabilities in debug mode
|
| 271 |
+
# if debug:
|
| 272 |
+
# print("\n" + "=" * 80)
|
| 273 |
+
# print("****Running inference in debug mode****")
|
| 274 |
+
# print("=" * 80)
|
| 275 |
+
|
| 276 |
+
# # Use truly raw logits if available, otherwise use scores
|
| 277 |
+
# raw_logits_to_show = (
|
| 278 |
+
# logits[0] / logits_temperature if logits is not None else scores[0]
|
| 279 |
+
# )
|
| 280 |
+
# logits_label = (
|
| 281 |
+
# "TRUE RAW LOGITS (from outputs.logits)"
|
| 282 |
+
# if logits is not None
|
| 283 |
+
# else "LOGITS (from outputs.scores)"
|
| 284 |
+
# )
|
| 285 |
+
|
| 286 |
+
# # Print first token scores shape and max/min scores in debug mode
|
| 287 |
+
# print(
|
| 288 |
+
# f"\n[{logits_label}] Single token scores shape: {raw_logits_to_show.shape}"
|
| 289 |
+
# )
|
| 290 |
+
# print(
|
| 291 |
+
# f"[{logits_label}] First token max/min: {raw_logits_to_show.max().item():.4f}, {raw_logits_to_show.min().item():.4f}"
|
| 292 |
+
# )
|
| 293 |
+
|
| 294 |
+
# # Print details about top 3 tokens from RAW logits
|
| 295 |
+
# print(f"\n{'β'*80}")
|
| 296 |
+
# print(f"TOP 3 TOKENS FROM {logits_label}:")
|
| 297 |
+
# print(f"{'β'*80}")
|
| 298 |
+
# top_3_tokens = torch.topk(raw_logits_to_show, k=3, dim=-1)
|
| 299 |
+
# for i in range(3):
|
| 300 |
+
# token_id = top_3_tokens.indices[0, i].item()
|
| 301 |
+
# token_text = self.processor.decode(token_id)
|
| 302 |
+
# token_logit = top_3_tokens.values[0, i].item()
|
| 303 |
+
# print(
|
| 304 |
+
# f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
|
| 305 |
+
# )
|
| 306 |
+
|
| 307 |
+
# # Now compare with POST-PROCESSED logits (outputs.scores)
|
| 308 |
+
# scores_first = scores[0] / logits_temperature
|
| 309 |
+
# print(f"\n{'β'*80}")
|
| 310 |
+
# print("TOP 3 TOKENS FROM LOGITS CAPTURE (after all processors):")
|
| 311 |
+
# print(f"{'β'*80}")
|
| 312 |
+
# print(
|
| 313 |
+
# f"[POST-PROCESSED] Max/min logits: {scores_first.max().item():.4f}, {scores_first.min().item():.4f}"
|
| 314 |
+
# )
|
| 315 |
+
|
| 316 |
+
# top_3_processed = torch.topk(scores_first, k=3, dim=-1)
|
| 317 |
+
# for i in range(3):
|
| 318 |
+
# token_id = top_3_processed.indices[0, i].item()
|
| 319 |
+
# token_text = self.processor.decode(token_id)
|
| 320 |
+
# token_logit = top_3_processed.values[0, i].item()
|
| 321 |
+
# print(
|
| 322 |
+
# f" #{i+1}: Token='{token_text}' | ID={token_id} | Logit={token_logit:.4f}"
|
| 323 |
+
# )
|
| 324 |
+
|
| 325 |
+
# # Check if the distributions differ (compare against truly raw logits if available)
|
| 326 |
+
# print(f"\n{'β'*80}")
|
| 327 |
+
# print("DIFFERENCE ANALYSIS (Raw β Post-Processed):")
|
| 328 |
+
# print(f"{'β'*80}")
|
| 329 |
+
# logit_diff = (scores_first.cpu() - raw_logits_to_show.cpu()).abs()
|
| 330 |
+
# max_diff = logit_diff.max().item()
|
| 331 |
+
# num_changed = (logit_diff > 0.001).sum().item()
|
| 332 |
+
|
| 333 |
+
# print(f" Max logit difference: {max_diff:.6f}")
|
| 334 |
+
# print(
|
| 335 |
+
# f" Number of tokens with changed logits: {num_changed}/{raw_logits_to_show.shape[1]}"
|
| 336 |
+
# )
|
| 337 |
+
|
| 338 |
+
# if max_diff > 0.001:
|
| 339 |
+
# print(f"\n β οΈ LOGITS WERE MODIFIED BY PROCESSORS!")
|
| 340 |
+
# # Show which tokens changed the most
|
| 341 |
+
# top_changes = torch.topk(logit_diff[0], k=min(5, num_changed))
|
| 342 |
+
# print(f"\n Top 5 most changed tokens:")
|
| 343 |
+
# for i in range(min(5, len(top_changes.indices))):
|
| 344 |
+
# token_id = top_changes.indices[i].item()
|
| 345 |
+
# token_text = self.processor.decode(token_id)
|
| 346 |
+
# raw_logit = raw_logits_to_show[0, token_id].item()
|
| 347 |
+
# processed_logit = scores_first[0, token_id].item()
|
| 348 |
+
# diff = top_changes.values[i].item()
|
| 349 |
+
# print(f" Token='{token_text}' | ID={token_id}")
|
| 350 |
+
# print(
|
| 351 |
+
# f" Raw: {raw_logit:.4f} β Processed: {processed_logit:.4f} (Ξ={diff:.4f})"
|
| 352 |
+
# )
|
| 353 |
+
# else:
|
| 354 |
+
# print(f" β No significant modifications detected")
|
| 355 |
+
|
| 356 |
+
# # Show what token was actually selected
|
| 357 |
+
# print(f"\n{'β'*80}")
|
| 358 |
+
# print("ACTUALLY GENERATED TOKEN:")
|
| 359 |
+
# print(f"{'β'*80}")
|
| 360 |
+
# first_generated_id = generated_ids[0, len(inputs.input_ids[0])].item()
|
| 361 |
+
# first_generated_token = self.processor.decode(first_generated_id)
|
| 362 |
+
# raw_logit_for_generated = raw_logits_to_show[0, first_generated_id].item()
|
| 363 |
+
|
| 364 |
+
# print(f" Token: '{first_generated_token}' | ID={first_generated_id}")
|
| 365 |
+
# print(f" Raw logit: {raw_logit_for_generated:.4f}")
|
| 366 |
+
|
| 367 |
+
# processed_logit_for_generated = scores_first[0, first_generated_id].item()
|
| 368 |
+
# print(f" Post-processed logit: {processed_logit_for_generated:.4f}")
|
| 369 |
+
|
| 370 |
+
# # Check if this token is in top-k of raw logits
|
| 371 |
+
# top_k_raw_indices = torch.topk(
|
| 372 |
+
# raw_logits_to_show, k=min(10, raw_logits_to_show.shape[1]), dim=-1
|
| 373 |
+
# ).indices[0]
|
| 374 |
+
# is_in_top10_raw = first_generated_id in top_k_raw_indices
|
| 375 |
+
# print(f" In top-10 of RAW logits: {is_in_top10_raw}")
|
| 376 |
+
|
| 377 |
+
# if not is_in_top10_raw:
|
| 378 |
+
# print(
|
| 379 |
+
# f"\n π¨ CRITICAL: Generated token was NOT in top-10 of raw logits!"
|
| 380 |
+
# )
|
| 381 |
+
# print(
|
| 382 |
+
# f" This proves that logits processors modified the distribution."
|
| 383 |
+
# )
|
| 384 |
+
# # Find the rank of the generated token in raw logits
|
| 385 |
+
# sorted_raw = torch.argsort(raw_logits_to_show[0], descending=True)
|
| 386 |
+
# raw_rank = (sorted_raw == first_generated_id).nonzero(as_tuple=True)[
|
| 387 |
+
# 0
|
| 388 |
+
# ].item() + 1
|
| 389 |
+
# print(f" Raw logits rank: {raw_rank}")
|
| 390 |
+
|
| 391 |
+
# print("=" * 80 + "\n")
|
| 392 |
+
|
| 393 |
+
# # Trim the prompt tokens from generated sequences
|
| 394 |
+
# generated_ids_trimmed = [
|
| 395 |
+
# out_ids[len(in_ids) :]
|
| 396 |
+
# for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 397 |
+
# ]
|
| 398 |
+
|
| 399 |
+
# # Decode the text
|
| 400 |
+
# output_response = self.processor.batch_decode(
|
| 401 |
+
# generated_ids_trimmed,
|
| 402 |
+
# skip_special_tokens=True,
|
| 403 |
+
# clean_up_tokenization_spaces=False,
|
| 404 |
+
# )[0]
|
| 405 |
+
|
| 406 |
+
# # Convert scores to probabilities
|
| 407 |
+
# # scores is a tuple of (batch_size, vocab_size) tensors, one per generated token
|
| 408 |
+
# selected_token_probs = []
|
| 409 |
+
# selected_token_logits = []
|
| 410 |
+
# first_token_probs = torch.softmax(scores[0], dim=-1)
|
| 411 |
+
|
| 412 |
+
# # Now, find indices of tokens in token_choices and get their probabilities
|
| 413 |
+
# for token_choice in token_choices:
|
| 414 |
+
# # Tokenize the choice - encode returns a list, we want the first actual token (skip special tokens)
|
| 415 |
+
# token_index = self.processor.tokenizer.encode(
|
| 416 |
+
# token_choice, add_special_tokens=False
|
| 417 |
+
# )[0]
|
| 418 |
+
# selected_token_probs.append(first_token_probs[0, token_index].item())
|
| 419 |
+
# selected_token_logits.append(scores[0][0, token_index].item())
|
| 420 |
+
|
| 421 |
+
# # Compute confidence as the ratio of first token's probability to the sum of all probabilities in selected_token_probs
|
| 422 |
+
# if return_confidence:
|
| 423 |
+
# first_token_id = generated_ids_trimmed[0][
|
| 424 |
+
# 0
|
| 425 |
+
# ].item() # First token of the first sequence
|
| 426 |
+
# confidence = (
|
| 427 |
+
# first_token_probs[0, first_token_id].item() / sum(selected_token_probs)
|
| 428 |
+
# if sum(selected_token_probs) > 0
|
| 429 |
+
# else 0.0
|
| 430 |
+
# )
|
| 431 |
+
# return {
|
| 432 |
+
# "response": output_response,
|
| 433 |
+
# "confidence": confidence,
|
| 434 |
+
# }
|
| 435 |
+
|
| 436 |
+
# # Return token logits
|
| 437 |
+
# else:
|
| 438 |
+
# token_logits = dict(zip(token_choices, selected_token_logits))
|
| 439 |
+
# top_k_logits_indices = torch.topk(scores[0], k=top_k_tokens, dim=-1)
|
| 440 |
+
# top_k_tokens_list: List[Tuple[str, int, float]] = []
|
| 441 |
+
# for i in range(top_k_tokens):
|
| 442 |
+
# logit_index = top_k_logits_indices.indices[0, i].item()
|
| 443 |
+
# token = self.processor.decode(logit_index)
|
| 444 |
+
# logit = top_k_logits_indices.values[0, i].item()
|
| 445 |
+
# top_k_tokens_list.append((token, logit_index, logit))
|
| 446 |
+
# return {
|
| 447 |
+
# "response": output_response,
|
| 448 |
+
# "top_k_tokens": top_k_tokens_list,
|
| 449 |
+
# "token_logits": token_logits,
|
| 450 |
+
# }
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
# if __name__ == "__main__":
|
| 454 |
+
# model_path = "Qwen/Qwen3-VL-4B-Instruct" # "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 455 |
+
# model = Qwen3VLModel(model_path)
|
| 456 |
+
# prompt = 'Does the following action accurately describe the one shown in the video? \nAnswer with "Yes" or "No".\n\nAction: Trying to bend stick so nothing happens\n\nConditions which may/may not be true BEFORE the aforementioned action occurs:\n- Stick is held by hands at two distinct points\n- Stick is intact\n\nConditions which may/may not be true AFTER the aforementioned action occurs:\n- Stick retains its original geometric shape\n- Stick remains intact\n\nAnswer:'
|
| 457 |
+
# token_choices = ["Yes", "No"]
|
| 458 |
+
# video_path = (
|
| 459 |
+
# "/home/shreyasj/Syed/data/Something-Something-V2/pre-post/videos/188064.mp4"
|
| 460 |
+
# )
|
| 461 |
+
|
| 462 |
+
# generation_config = {
|
| 463 |
+
# "max_new_tokens": 128,
|
| 464 |
+
# "do_sample": True, # False enables greedy sampling, which invalidates things like temperature, top-K, top-P. Allows return of raw logits
|
| 465 |
+
# "temperature": 0.7,
|
| 466 |
+
# "logits_temperature": 1.0,
|
| 467 |
+
# "fps": 1.0,
|
| 468 |
+
# "return_confidence": False,
|
| 469 |
+
# "top_k_tokens": 10,
|
| 470 |
+
# "debug": False,
|
| 471 |
+
# }
|
| 472 |
+
# output = model.chat_with_confidence(
|
| 473 |
+
# prompt, video_path, token_choices=token_choices, **generation_config
|
| 474 |
+
# )
|
| 475 |
+
# response = output["response"]
|
| 476 |
+
# print(f"Response: {response}")
|
| 477 |
+
|
| 478 |
+
# if generation_config["return_confidence"]:
|
| 479 |
+
# confidence = output["confidence"]
|
| 480 |
+
# print(f"Confidence: {confidence}")
|
| 481 |
+
# else:
|
| 482 |
+
# # If do_sample is True, logits pass through logit warpers which filter out un-important tokens (based on logits) to -inf,
|
| 483 |
+
# # otherwise, the raw logits are used, which are not filtered.
|
| 484 |
+
# logits_type = "POST-PROCESSED" if generation_config["do_sample"] else "RAW"
|
| 485 |
+
# top_k_tokens = output["top_k_tokens"]
|
| 486 |
+
# for i in range(len(top_k_tokens)):
|
| 487 |
+
# print(f"Top {i+1} token: {top_k_tokens[i][0]}")
|
| 488 |
+
# print(f"Top {i+1} token logit: {top_k_tokens[i][2]}")
|
| 489 |
+
# print("--------------------------------")
|