import os # Expandable segments to avoid allocator fragmentation under memory spikes os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import spaces # MUST be before any torch/CUDA import import cv2 import re import json import torch import numpy as np from PIL import Image from typing import List, Optional, Tuple import tempfile import gradio as gr from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct" # ── Load model at module scope (ZeroGPU rule 2) ────────────────────────────── processor = AutoProcessor.from_pretrained(MODEL_ID) model = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, attn_implementation="sdpa", ).to("cuda") # ── VLM call helper ────────────────────────────────────────────────────────── def vlm_call(images: List[Image.Image], question: str, system_prompt: str = "You are a highly strict UI navigation assistant designed to output JSON.") -> str: """Call the local VLM with images and a question, return text response.""" content = [] for img in images: content.append({"type": "image", "image": img}) content.append({"type": "text", "text": question}) messages = [ {"role": "system", "content": [{"type": "text", "text": system_prompt}]}, {"role": "user", "content": content}, ] text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor( text=[text], images=[images] if images else None, padding=True, return_tensors="pt", ).to("cuda") with torch.no_grad(): output_ids = model.generate(**inputs, max_new_tokens=8192, do_sample=False, temperature=1.0) # Trim the input tokens from output input_len = inputs["input_ids"].shape[1] output_text = processor.batch_decode( output_ids[:, input_len:], skip_special_tokens=True )[0] return output_text def parse_json_response(text: str): """Extract a JSON object from a text response.""" try: match = re.search(r'\{.*\}', text, re.DOTALL) if match: return json.loads(match.group(0)) except Exception: pass return None # ── Video utilities ────────────────────────────────────────────────────────── def extract_frame(video_path: str, frame_idx: int) -> Optional[Image.Image]: """Extract a single frame from the video as PIL Image.""" cap = cv2.VideoCapture(video_path) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, frame = cap.read() cap.release() if not ret: return None return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) def compute_color_histogram(img: Image.Image) -> np.ndarray: """Compute a normalized 3-channel color histogram.""" arr = np.array(img) hist = cv2.calcHist([arr], [0, 1, 2], None, [50, 50, 50], [0, 256, 0, 256, 0, 256]) cv2.normalize(hist, hist) return hist def frame_similarity(hist1: np.ndarray, hist2: np.ndarray) -> float: """Compare two color histograms using correlation.""" return float(cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)) def is_frame_redundant(new_hist: np.ndarray, existing_hists: List[np.ndarray], threshold: float = 0.985) -> bool: """Check if a new frame is too similar to existing ones.""" for h in existing_hists: if frame_similarity(new_hist, h) >= threshold: return True return False # ── TASKER core: A* tree search keyframe extraction ───────────────────────── class VideoSeg: """A video segment (tree node).""" def __init__(self, start: int, end: int): self.start = start self.end = end def find_visual_change_split_point(video_path: str, seg_start: int, seg_end: int) -> int: """Find the frame with the largest visual change in a segment.""" midpoint = (seg_start + seg_end) // 2 try: seg_length = seg_end - seg_start if seg_length <= 2: return midpoint cap = cv2.VideoCapture(video_path) num_samples = min(seg_length, 10) step = max(1, seg_length // num_samples) sample_indices = list(range(seg_start, seg_end, step)) if sample_indices[-1] != seg_end: sample_indices.append(seg_end) frames = {} hists = {} for idx in sample_indices: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if ret: frames[idx] = frame hist = cv2.calcHist([frame], [0, 1, 2], None, [50, 50, 50], [0, 256, 0, 256, 0, 256]) cv2.normalize(hist, hist) hists[idx] = hist if len(frames) < 2: cap.release() return midpoint sorted_indices = sorted(frames.keys()) max_diff = -1 best_a, best_b = sorted_indices[0], sorted_indices[-1] for i in range(len(sorted_indices) - 1): idx_a, idx_b = sorted_indices[i], sorted_indices[i + 1] if idx_a in hists and idx_b in hists: diff = 1.0 - cv2.compareHist(hists[idx_a], hists[idx_b], cv2.HISTCMP_CORREL) if diff > max_diff: max_diff = diff best_a, best_b = idx_a, idx_b candidate = best_b cap.release() # Clamp to valid range min_pos = seg_start + int(seg_length * 0.15) max_pos = seg_start + int(seg_length * 0.85) if candidate < min_pos or candidate > max_pos: return midpoint return candidate except Exception: return midpoint def a_star_select_segment(images: List[Image.Image], goal: str, segment_des: str) -> str: """A* strategy: balance goal-relevance and UI state changes.""" prompt = f"""You are provided with sequential images sampled from a video. Each image is labeled with its frame index. The images are shown in chronological order. Goal: {goal} Candidate segments (gaps between current frames): {segment_des} (A* Strategy - Balance missing goal-relevant info and visual state changes) Identify ONE single candidate segment that BEST satisfies BOTH conditions simultaneously: 1. GOAL PROXIMITY: The segment likely contains crucial missing actions that are necessary steps toward achieving the Goal. 2. STATE CHANGE MAGNITUDE: The segment whose boundary frames show the MOST different visual states is more likely to contain important operations. Return JSON format: {{"frame_descriptions": [{{"segment_id": "1", "description": "Best A* candidate: missing goal step + visual state change"}}]}} """ return vlm_call(images, prompt) def qa_and_reflect(images: List[Image.Image], goal: str) -> Tuple[str, int]: """Evaluate whether current frames are sufficient.""" prompt_qa = f"Task Goal: {goal}\nLook at these sequential frames. Describe the EXACT step-by-step actions that happen transitioning from one frame to the next." answer = vlm_call(images, prompt_qa, system_prompt="You are a helpful video analysis assistant.") prompt_eval = f"""Task Goal: {goal} Your sequential analysis: {answer} Evaluate your confidence level strictly: 1: Severe Jumps (There are completely missing screens or sudden state changes. MUST expand.) 2: Minor Disconnects (The flow makes sense, but some intermediate actions are missing. Should expand.) 3: Strong Continuity (The frames capture all important actions and transitions. No key step is skipped.) Output JSON exactly like this: {{"confidence": 3}} """ conf_str = vlm_call(images, prompt_eval) conf_json = parse_json_response(conf_str) confidence = conf_json.get("confidence", 1) if conf_json else 1 return answer, int(confidence) @spaces.GPU(duration=240) def extract_keyframes(video_path: str, goal: str, search_strategy: str = "a_star", max_frames: int = 10, min_frames: int = 6, min_steps: int = 3, conf_lower: int = 3, progress=gr.Progress()): """ TASKER keyframe extraction: tree-search with VLM-guided segment selection. Args: video_path: Path to the input video. goal: Task query describing what the user wants to see. search_strategy: One of "a_star", "bfs", "gbfs", "dijkstra". max_frames: Maximum number of keyframes to extract. min_frames: Minimum number of frames before confidence check can stop. min_steps: Minimum expansion steps before confidence check can stop. conf_lower: Confidence threshold (1-3) to stop searching. Returns: List of (PIL Image, caption) tuples for gallery display, plus a summary string. """ cap = cv2.VideoCapture(video_path) num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) cap.release() if num_frames <= 0 or fps <= 0: return [], "Error: Could not read video file. Please upload a valid video." # ── Initial uniform sampling ───────────────────────────────────────────── init_frames = 4 content_start = 0 content_end = num_frames - 1 if content_end - content_start + 1 <= init_frames: sample_idx = list(range(content_start, content_end + 1)) else: interval = max(1, (content_end - content_start + 1) // (init_frames - 1)) sample_idx = list(range(content_start, content_end + 1, interval)) if sample_idx[-1] != content_end: sample_idx.append(content_end) progress(0.1, desc=f"Initial sampling: {len(sample_idx)} frames from {num_frames} total") video_segments = [VideoSeg(sample_idx[i-1], sample_idx[i]) for i in range(1, len(sample_idx))] # Histogram cache for dedup hist_cache = {} frozen_segments = set() effective_step = 0 last_confidence = 0 max_total_attempts = max_frames + 10 for attempt in range(1, max_total_attempts + 1): current_frames = len(sample_idx) if current_frames >= max_frames: break # Extract current frames as images images = [] for idx in sample_idx: img = extract_frame(video_path, idx) if img is not None: images.append(img) if not images: break progress( 0.1 + 0.6 * (attempt / max_total_attempts), desc=f"Step {attempt}: {current_frames} frames, evaluating..." ) # Confidence check if current_frames >= min_frames and effective_step > min_steps: _, confidence = qa_and_reflect(images, goal) last_confidence = confidence if confidence >= conf_lower: break else: if current_frames < min_frames: pass # forced expansion # Build segment descriptions frame_to_img_idx = {frame: i + 1 for i, frame in enumerate(sample_idx)} segment_des_lines = [] for i, seg in enumerate(video_segments): seg_id = i + 1 if (seg.start, seg.end) in frozen_segments: continue start_img = frame_to_img_idx.get(seg.start, "?") end_img = frame_to_img_idx.get(seg.end, "?") segment_des_lines.append( f" Segment {seg_id}: frames {seg.start}-{seg.end} (Image #{start_img} -> Image #{end_img})" ) segment_des_str = "\n".join(segment_des_lines) if not segment_des_str: break # VLM segment selection try: if search_strategy == "bfs": response = vlm_call(images, f"""You are provided with sequential images sampled from a video. Goal: {goal} Candidate segments: {segment_des_str} Select MULTIPLE segments that likely contain crucial missing actions. Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""") elif search_strategy == "gbfs": response = vlm_call(images, f"""You are provided with sequential images sampled from a video. Goal: {goal} Candidate segments: {segment_des_str} Select the SINGLE segment MOST LIKELY to contain crucial missing actions. Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""") elif search_strategy == "dijkstra": response = vlm_call(images, f"""You are provided with sequential images sampled from a video. Candidate segments: {segment_des_str} Select the SINGLE segment with the MOST significant visual state transition. Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""") else: # a_star response = a_star_select_segment(images, goal, segment_des_str) parsed = parse_json_response(response) except Exception as e: print(f"VLM call error at step {attempt}: {e}") parsed = None # Determine selected segment IDs selected_seg_ids = set() if parsed and "frame_descriptions" in parsed: for desc in parsed["frame_descriptions"]: for key in desc: if key.lower() == "segment_id": val = str(desc[key]).strip() nums = re.findall(r'\d+', val) if nums: seg_id = int(nums[0]) if 1 <= seg_id <= len(video_segments): selected_seg_ids.add(seg_id) break # Fallback: pick longest segment if not selected_seg_ids: longest_seg_id = None longest_len = 0 for i, seg in enumerate(video_segments): seg_len = seg.end - seg.start if seg_len > longest_len and seg_len > 1 and (seg.start, seg.end) not in frozen_segments: longest_len = seg_len longest_seg_id = i + 1 if longest_seg_id is not None: selected_seg_ids.add(longest_seg_id) if not selected_seg_ids: break # BFS quota limit if search_strategy == "bfs" and len(selected_seg_ids) > 1: remaining_quota = max_frames - len(sample_idx) if remaining_quota <= 0: break if len(selected_seg_ids) > remaining_quota: sorted_seg_ids = sorted(selected_seg_ids, key=lambda sid: video_segments[sid-1].end - video_segments[sid-1].start, reverse=True) selected_seg_ids = set(sorted_seg_ids[:remaining_quota]) # Split selected segments split_origin = {} new_segments = [] seg_counter = 0 for i, seg in enumerate(video_segments): seg_id = i + 1 if seg_id in selected_seg_ids: if seg.end - seg.start <= 1: seg_counter += 1 new_segments.append(VideoSeg(seg.start, seg.end)) else: sp = find_visual_change_split_point(video_path, seg.start, seg.end) split_origin[sp] = (seg.start, seg.end) seg_counter += 1 new_segments.append(VideoSeg(seg.start, sp)) seg_counter += 1 new_segments.append(VideoSeg(sp, seg.end)) else: seg_counter += 1 new_segments.append(VideoSeg(seg.start, seg.end)) video_segments = new_segments # Rebuild sample_idx sample_idx_set = set() for seg in video_segments: sample_idx_set.add(seg.start) sample_idx_set.add(seg.end) new_sample_idx = sorted(list(sample_idx_set)) # Visual deduplication new_frames = [idx for idx in new_sample_idx if idx not in set(sample_idx)] old_sample_set = set(sample_idx) # Compute histograms for old frames old_hists = [] for idx in sample_idx: img = extract_frame(video_path, idx) if img is not None: old_hists.append(compute_color_histogram(img)) frames_to_remove = [] accepted_new_hists = [] for new_idx in new_frames: new_img = extract_frame(video_path, new_idx) if new_img is None: continue new_hist = compute_color_histogram(new_img) all_compare_hists = old_hists + accepted_new_hists if is_frame_redundant(new_hist, all_compare_hists, threshold=0.985): frames_to_remove.append(new_idx) if new_idx in split_origin: frozen_segments.add(split_origin[new_idx]) else: accepted_new_hists.append(new_hist) if frames_to_remove: new_sample_idx = [idx for idx in new_sample_idx if idx not in frames_to_remove] new_sample_idx = sorted(new_sample_idx) video_segments = [VideoSeg(new_sample_idx[i-1], new_sample_idx[i]) for i in range(1, len(new_sample_idx))] actually_added = len(new_sample_idx) > len(sample_idx) sample_idx = new_sample_idx if actually_added: effective_step += 1 progress(0.85, desc="Finalizing keyframes...") # Force-fill if too few frames if len(sample_idx) < min_frames and last_confidence < conf_lower: max_force = min_frames + 5 for _ in range(max_force): if len(sample_idx) >= min_frames: break max_gap = 0 max_gap_idx = 0 for i in range(len(sample_idx) - 1): if (sample_idx[i], sample_idx[i+1]) in frozen_segments: continue gap = sample_idx[i+1] - sample_idx[i] if gap > max_gap: max_gap = gap max_gap_idx = i if max_gap <= 1: break sp = find_visual_change_split_point(video_path, sample_idx[max_gap_idx], sample_idx[max_gap_idx + 1]) sp_img = extract_frame(video_path, sp) if sp_img is None: break sp_hist = compute_color_histogram(sp_img) existing_hists = [] for idx in sample_idx: img = extract_frame(video_path, idx) if img is not None: existing_hists.append(compute_color_histogram(img)) if is_frame_redundant(sp_hist, existing_hists, threshold=0.985): frozen_segments.add((sample_idx[max_gap_idx], sample_idx[max_gap_idx + 1])) continue sample_idx.insert(max_gap_idx + 1, sp) # Extract final keyframes progress(0.95, desc="Extracting final keyframes...") gallery = [] for i, idx in enumerate(sample_idx): img = extract_frame(video_path, idx) if img is not None: timestamp = idx / fps if fps > 0 else 0 mins = int(timestamp // 60) secs = int(timestamp % 60) percent = (idx / max(1, num_frames - 1)) * 100 caption = f"Frame {i+1}/{len(sample_idx)} | idx={idx} | {mins:02d}:{secs:02d} | {percent:.1f}%" gallery.append((img, caption)) summary = ( f"**TASKER {search_strategy.upper()}** extracted **{len(gallery)}** keyframes " f"from {num_frames} total frames ({num_frames/fps:.1f}s video).\n\n" f"Search stats: {effective_step} effective expansion steps, " f"confidence={last_confidence}/3, " f"target range {min_frames}-{max_frames} frames." ) progress(1.0, desc="Done!") return gallery, summary # ── Gradio UI ─────────────────────────────────────────────────────────────── CUSTOM_CSS = """ #header { text-align: center; margin-bottom: 20px; } #header h1 { font-size: 2em; margin-bottom: 5px; } #header p { color: #666; font-size: 1.1em; } """ with gr.Blocks(css=CUSTOM_CSS, title="TASKER Keyframe Extractor") as demo: gr.HTML("""
Extract task-relevant keyframes from a video using VLM-guided tree search (A* / BFS / GBFS / Dijkstra)