# DW-KhotTaeVL-2B-QueryFrames / dw_queryframes.py
# Page residue: "commandeaw's picture"
# Last commit: Fix transformers 5.x API change: get_text_features now returns BaseModelOutputWithPooling
# Commit 5e31798 (verified)
"""DW-KhotTaeVL-2B-QueryFrames — query-aware frame selection for video MCQ.

Single-file inference module. Wraps stock Qwen3-VL-2B-Instruct with a
CLIP-ViT-L/14 query-aware frame selector and an optional task-type-aware
uniform-fallback policy.

Usage::

    from dw_queryframes import QueryFrames

    fv = QueryFrames(device="mps")
    answer = fv.answer_mcq(
        video_path="cooking.mp4",
        question="What does the chef do after pouring the oil?",
        options=["Stirs the oil", "Adds salt", "Pours broth", "Chops herbs"],
        task_type=None,  # or "Action Recognition" etc. for hybrid mode
    )

License: Apache 2.0 (this code)
Copyright 2026 Deaw (HF: @commandeaw)
Base model: Qwen3-VL-2B-Instruct (Apache 2.0)
Frame scorer: openai/clip-vit-large-patch14 (MIT)

Always credit Qwen3-VL-Instruct as the base when using this work.
"""
from __future__ import annotations
import re
import os
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
from PIL import Image
# Task types for which more frames did not help on Video-MME mini:
# stock-64f does NOT outperform stock-8f there (measured deltas:
# Object Reasoning -0.083, Temporal Reasoning +0.000). Frame coverage is
# therefore not the bottleneck for these tasks and uniform sampling is at
# least as good as query-aware selection, so the hybrid policy switches to
# uniform selection whenever one of these labels is supplied.
NO_FRAME_GAIN_TASKS = frozenset(("Object Reasoning", "Temporal Reasoning"))

# Text prompt sent to the VLM; ``{question}`` and ``{options}`` are filled
# in with ``str.format``.
PROMPT_TEMPLATE = (
    "Select the best answer based on the video.\n"
    "\n"
    "Question: {question}\n"
    "Options:\n"
    "{options}\n"
    "Answer with only the letter."
)

# A stand-alone option letter (A-D, any case) anywhere in the text.
LETTER_RE = re.compile(r"\b([ABCD])\b", flags=re.IGNORECASE)
# An explicit "Answer: X" line; tried before the bare-letter pattern.
ANSWER_LINE_RE = re.compile(r"Answer:\s*([ABCD])\b", flags=re.IGNORECASE)
class QueryFrames:
    """Query-aware frame selection over stock Qwen3-VL-2B-Instruct.

    A fixed number of candidate frames is sampled uniformly from the video.
    The candidates are then either ranked against the question with CLIP
    (keeping the top ``n_frames`` in temporal order) or sub-sampled
    uniformly, and the chosen frames plus a multiple-choice prompt are fed
    to the Qwen vision-language model.

    Attributes:
        device: Resolved torch device string ("mps", "cuda", "cpu", or
            whatever the caller passed explicitly).
        n_frames: Number of frames handed to the VLM per question.
        n_candidates: Number of uniformly sampled candidate frames.
        max_new_tokens: Generation budget for the answer.
        qwen_processor / qwen_model: Processor and model of ``base_model``.
        clip_processor / clip_model: Processor and model of ``clip_model``.
    """

    def __init__(
        self,
        base_model: str = "Qwen/Qwen3-VL-2B-Instruct",
        clip_model: str = "openai/clip-vit-large-patch14",
        device: str = "auto",
        max_pixels: int = 262_144,
        max_new_tokens: int = 8,
        n_frames: int = 8,
        n_candidates: int = 32,
    ) -> None:
        """Load both models and store the sampling configuration.

        Args:
            base_model: Hub id or path of the Qwen3-VL checkpoint.
            clip_model: Hub id or path of the CLIP checkpoint used to
                score frames against the question.
            device: Torch device string, or "auto" to pick one.
            max_pixels: Forwarded to the Qwen processor.
            max_new_tokens: Maximum number of generated tokens per answer.
            n_frames: Frames passed to the VLM; must be >= 1.
            n_candidates: Candidate frames sampled per video; must be >= 1.

        Raises:
            ValueError: If ``n_frames`` or ``n_candidates`` is below 1.
        """
        # Fail fast, before the (slow) model downloads/loads.
        if n_frames < 1:
            raise ValueError(f"n_frames must be >= 1, got {n_frames}")
        if n_candidates < 1:
            raise ValueError(f"n_candidates must be >= 1, got {n_candidates}")

        # NOTE(review): torch is already imported when this runs, so the
        # variable may be set too late for PyTorch to honor it; export it
        # in the shell before starting Python to be sure.
        os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

        self.device = self._resolve_device(device)
        self.n_frames = n_frames
        self.n_candidates = n_candidates
        self.max_new_tokens = max_new_tokens

        # Imported lazily so that importing this module stays cheap.
        from transformers import (
            AutoProcessor,
            CLIPModel,
            CLIPProcessor,
            Qwen3VLForConditionalGeneration,
        )

        self.qwen_processor = AutoProcessor.from_pretrained(
            base_model, max_pixels=max_pixels,
        )
        self.qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
            base_model, dtype=torch.bfloat16,
        ).to(self.device).eval()

        # ``dtype`` (not the deprecated ``torch_dtype`` alias) to match the
        # Qwen load above.
        self.clip_model = CLIPModel.from_pretrained(
            clip_model, dtype=torch.float32,
        ).to(self.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model)

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Return ``device`` unchanged unless it is "auto".

        For "auto" the first available backend is chosen in the order
        MPS, CUDA, CPU.
        """
        if device != "auto":
            return device
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def sample_uniform_candidates(self, video_path: str | Path) -> list[Image.Image]:
        """Sample ``n_candidates`` uniformly-spaced frames as PIL images.

        The frame indices are interior points of the video: the range is
        split into ``n_candidates + 1`` equal steps and the frame at each
        step boundary is taken. Short videos yield repeated frames.

        Raises:
            ValueError: If the video contains no frames.
        """
        import decord

        reader = decord.VideoReader(str(video_path))
        total = len(reader)
        if total == 0:
            raise ValueError(f"video has no decodable frames: {video_path}")

        step = total / (self.n_candidates + 1)
        # Clamp defensively so float rounding can never index past the end.
        indices = [
            min(int((i + 1) * step), total - 1)
            for i in range(self.n_candidates)
        ]
        return [Image.fromarray(reader[i].asnumpy()) for i in indices]

    def select_frames(
        self,
        candidates: list[Image.Image],
        question: str,
    ) -> list[Image.Image]:
        """Return up to ``n_frames`` candidates most similar to the question.

        Frames are ranked by CLIP cosine similarity between the question
        text and each image; the winners are returned in their original
        temporal order. When there are no more candidates than
        ``n_frames``, all of them are returned without running CLIP.
        """
        if not candidates or self.n_frames <= 0:
            return []
        if len(candidates) <= self.n_frames:
            # top-k over everything is just "everything, in order";
            # ``Tensor.topk`` would raise for k > number of candidates.
            return list(candidates)

        inputs = self.clip_processor(
            text=[question], images=candidates,
            return_tensors="pt", padding=True, truncation=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.inference_mode():
            # transformers <= 4.x returns a tensor directly; >= 5.x returns
            # a BaseModelOutputWithPooling whose .pooler_output is the
            # projected embedding. Handle both.
            text_out = self.clip_model.get_text_features(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
            text_emb = getattr(text_out, "pooler_output", text_out)
            image_out = self.clip_model.get_image_features(
                pixel_values=inputs["pixel_values"],
            )
            image_embs = getattr(image_out, "pooler_output", image_out)

            # Unit-normalise so the dot product is cosine similarity.
            text_emb = F.normalize(text_emb, dim=-1)
            image_embs = F.normalize(image_embs, dim=-1)
            sims = (text_emb @ image_embs.T).squeeze(0).float().cpu()

        top_indices = sims.topk(self.n_frames).indices.tolist()
        # Restore temporal order so the VLM sees a coherent sequence.
        return [candidates[i] for i in sorted(top_indices)]

    def select_uniform(self, candidates: list[Image.Image]) -> list[Image.Image]:
        """Return ``n_frames`` images sampled uniformly from candidates.

        The candidate list is divided into ``n_frames`` equal bins and the
        element at the centre of each bin is taken, so a list shorter than
        ``n_frames`` yields repeated entries. An empty list (or a
        non-positive ``n_frames``) yields an empty result.
        """
        if not candidates or self.n_frames <= 0:
            return []
        last = len(candidates) - 1
        step = len(candidates) / self.n_frames
        return [
            candidates[min(int((k + 0.5) * step), last)]
            for k in range(self.n_frames)
        ]

    def answer_mcq(
        self,
        video_path: str | Path,
        question: str,
        options: list[str],
        task_type: Optional[str] = None,
    ) -> dict:
        """Answer one MCQ question on a video.

        Args:
            video_path: path to .mp4 (or any decord-readable video)
            question: string question (no options)
            options: list of 4 option strings (will be lettered A-D)
            task_type: optional task category. If it is one of
                ``NO_FRAME_GAIN_TASKS``, frames are chosen by uniform
                sampling instead of query-aware CLIP selection.

        Returns:
            dict with keys: pred (letter, or None if no letter could be
            parsed), raw (model output),
            frames_used ("query_aware" | "uniform_fallback"),
            n_candidates, latency_clip_s, latency_gen_s.
            ``latency_clip_s`` is the frame-selection time whichever
            policy ran.

        Raises:
            ValueError: If the video contains no frames.
        """
        import time

        candidates = self.sample_uniform_candidates(video_path)

        # Decide policy.
        use_uniform = task_type in NO_FRAME_GAIN_TASKS

        # perf_counter is monotonic, unlike wall-clock time.time().
        select_start = time.perf_counter()
        if use_uniform:
            frames = self.select_uniform(candidates)
        else:
            frames = self.select_frames(candidates, question)
        clip_dt = time.perf_counter() - select_start

        # Build Qwen prompt and run inference.
        opts_text = "\n".join(
            f"{chr(65 + i)}. {str(o).strip()}" for i, o in enumerate(options)
        )
        prompt = PROMPT_TEMPLATE.format(question=question, options=opts_text)
        # One image placeholder per selected frame, followed by the text.
        content = [{"type": "image"} for _ in frames]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]

        text_in = self.qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        inputs = self.qwen_processor(
            text=[text_in], images=frames,
            return_tensors="pt", padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        gen_start = time.perf_counter()
        with torch.inference_mode():
            out_ids = self.qwen_model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                temperature=1.0,
            )
        gen_dt = time.perf_counter() - gen_start

        # Drop the prompt tokens; keep only what the model generated.
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = out_ids[0, prompt_len:]
        raw = self.qwen_processor.tokenizer.decode(
            new_tokens, skip_special_tokens=True,
        )

        return {
            "pred": self._extract_letter(raw),
            "raw": raw,
            "frames_used": "uniform_fallback" if use_uniform else "query_aware",
            "n_candidates": self.n_candidates,
            "latency_clip_s": round(clip_dt, 3),
            "latency_gen_s": round(gen_dt, 3),
        }

    @staticmethod
    def _extract_letter(text: str) -> Optional[str]:
        """Parse the predicted option letter out of the model output.

        An explicit ``Answer: X`` match wins; otherwise the first
        stand-alone A-D letter is used. Returns the letter upper-cased,
        or ``None`` when nothing matches (including empty/None input).
        """
        s = text or ""
        match = ANSWER_LINE_RE.search(s) or LETTER_RE.search(s)
        return match.group(1).upper() if match else None
# Public names exported by ``from dw_queryframes import *``.
__all__ = ["QueryFrames", "NO_FRAME_GAIN_TASKS"]