"""DW-KhotTaeVL-2B-QueryFrames — query-aware frame selection for video MCQ. Single-file inference module. Wraps stock Qwen3-VL-2B-Instruct with a CLIP-ViT-L/14 query-aware frame selector and an optional task-type-aware uniform-fallback policy. Usage:: from dw_queryframes import QueryFrames fv = QueryFrames(device="mps") answer = fv.answer_mcq( video_path="cooking.mp4", question="What does the chef do after pouring the oil?", options=["Stirs the oil", "Adds salt", "Pours broth", "Chops herbs"], task_type=None, # or "Action Recognition" etc. for hybrid mode ) License: Apache 2.0 (this code) Copyright 2026 Deaw (HF: @commandeaw) Base model: Qwen3-VL-2B-Instruct (Apache 2.0) Frame scorer: openai/clip-vit-large-patch14 (MIT) Always credit Qwen3-VL-Instruct as the base when using this work. """ from __future__ import annotations import re import os from pathlib import Path from typing import Optional import torch import torch.nn.functional as F from PIL import Image # Tasks where stock-64f does NOT outperform stock-8f on Video-MME mini # (measured: Object Reasoning Δ -0.083, Temporal Reasoning Δ +0.000). # For these tasks, frame-coverage is not the bottleneck; uniform sampling # is at least as good as query-aware. The hybrid policy uses uniform # selection for these task types when a label is provided. NO_FRAME_GAIN_TASKS = frozenset({"Object Reasoning", "Temporal Reasoning"}) PROMPT_TEMPLATE = ( "Select the best answer based on the video.\n\n" "Question: {question}\n" "Options:\n{options}\n" "Answer with only the letter." ) LETTER_RE = re.compile(r"\b([ABCD])\b", re.IGNORECASE) ANSWER_LINE_RE = re.compile(r"Answer:\s*([ABCD])\b", re.IGNORECASE) class QueryFrames: """Query-aware frame selection over stock Qwen3-VL-2B-Instruct.""" def __init__( self, base_model: str = "Qwen/Qwen3-VL-2B-Instruct", clip_model: str = "openai/clip-vit-large-patch14", device: str = "auto", max_pixels: int = 262_144, max_new_tokens: int = 8, n_frames: int = 8, n_candidates: int = 32, ): os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") self.device = self._resolve_device(device) self.n_frames = n_frames self.n_candidates = n_candidates self.max_new_tokens = max_new_tokens from transformers import ( AutoProcessor, Qwen3VLForConditionalGeneration, CLIPModel, CLIPProcessor, ) self.qwen_processor = AutoProcessor.from_pretrained(base_model, max_pixels=max_pixels) self.qwen_model = Qwen3VLForConditionalGeneration.from_pretrained( base_model, dtype=torch.bfloat16, ).to(self.device).eval() self.clip_model = CLIPModel.from_pretrained( clip_model, torch_dtype=torch.float32, ).to(self.device).eval() self.clip_processor = CLIPProcessor.from_pretrained(clip_model) @staticmethod def _resolve_device(device: str) -> str: if device == "auto": if torch.backends.mps.is_available(): return "mps" if torch.cuda.is_available(): return "cuda" return "cpu" return device def sample_uniform_candidates(self, video_path: str | Path) -> list[Image.Image]: """Sample ``n_candidates`` uniformly-spaced frames as PIL images.""" import decord vid = decord.VideoReader(str(video_path)) total = len(vid) step = total / (self.n_candidates + 1) indices = [int((i + 1) * step) for i in range(self.n_candidates)] return [Image.fromarray(vid[i].asnumpy()) for i in indices] def select_frames( self, candidates: list[Image.Image], question: str, ) -> list[Image.Image]: """Return ``n_frames`` images: top-K by CLIP similarity to question, sorted by original temporal index (preserving sequence).""" inputs = self.clip_processor( text=[question], images=candidates, return_tensors="pt", padding=True, truncation=True, ) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.inference_mode(): # transformers ≤ 4.x returns a tensor directly; ≥ 5.x returns # a BaseModelOutputWithPooling whose .pooler_output is the # projected embedding. Handle both. text_out = self.clip_model.get_text_features( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) text_emb = (text_out.pooler_output if hasattr(text_out, "pooler_output") else text_out) image_out = self.clip_model.get_image_features( pixel_values=inputs["pixel_values"] ) image_embs = (image_out.pooler_output if hasattr(image_out, "pooler_output") else image_out) text_emb = F.normalize(text_emb, dim=-1) image_embs = F.normalize(image_embs, dim=-1) sims = (text_emb @ image_embs.T).squeeze(0).float().cpu() topk = sims.topk(self.n_frames).indices.tolist() topk_sorted = sorted(topk) return [candidates[i] for i in topk_sorted] def select_uniform(self, candidates: list[Image.Image]) -> list[Image.Image]: """Return ``n_frames`` images sampled uniformly from candidates.""" step = len(candidates) / self.n_frames idx = [int((k + 0.5) * step) for k in range(self.n_frames)] idx = [min(i, len(candidates) - 1) for i in idx] return [candidates[i] for i in idx] def answer_mcq( self, video_path: str | Path, question: str, options: list[str], task_type: Optional[str] = None, ) -> dict: """Answer one MCQ question on a video. Args: video_path: path to .mp4 (or any decord-readable video) question: string question (no options) options: list of 4 option strings (will be lettered A-D) task_type: optional task category. If provided and matches a known no-frame-gain task, falls back to uniform sampling for collision-safe behavior. Returns: dict with keys: pred (letter), raw (model output), frames_used ("query_aware" | "uniform_fallback"), n_candidates, latency_clip_s, latency_gen_s. """ import time candidates = self.sample_uniform_candidates(video_path) # Decide policy. use_uniform = task_type in NO_FRAME_GAIN_TASKS t1 = time.time() if use_uniform: frames = self.select_uniform(candidates) else: frames = self.select_frames(candidates, question) clip_dt = time.time() - t1 # Build Qwen prompt and run inference. opts_text = "\n".join(f"{chr(65+i)}. {str(o).strip()}" for i, o in enumerate(options)) prompt = PROMPT_TEMPLATE.format(question=question, options=opts_text) messages = [{"role": "user", "content": [{"type": "image"} for _ in frames] + [{"type": "text", "text": prompt}]}] text_in = self.qwen_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = self.qwen_processor( text=[text_in], images=frames, return_tensors="pt", padding=True, ) inputs = {k: v.to(self.device) for k, v in inputs.items()} t2 = time.time() with torch.inference_mode(): out_ids = self.qwen_model.generate( **inputs, max_new_tokens=self.max_new_tokens, do_sample=False, temperature=1.0, ) gen_dt = time.time() - t2 new_tokens = out_ids[0, inputs["input_ids"].shape[1]:] raw = self.qwen_processor.tokenizer.decode( new_tokens, skip_special_tokens=True, ) pred = self._extract_letter(raw) return { "pred": pred, "raw": raw, "frames_used": "uniform_fallback" if use_uniform else "query_aware", "n_candidates": self.n_candidates, "latency_clip_s": round(clip_dt, 3), "latency_gen_s": round(gen_dt, 3), } @staticmethod def _extract_letter(text: str) -> Optional[str]: s = text or "" m = ANSWER_LINE_RE.search(s) if m: return m.group(1).upper() m = LETTER_RE.search(s) return m.group(1).upper() if m else None __all__ = ["QueryFrames", "NO_FRAME_GAIN_TASKS"]