Spaces:
Running on Zero
Running on Zero
| import os | |
| # Expandable segments to avoid allocator fragmentation under memory spikes | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import spaces # MUST be before any torch/CUDA import | |
| import cv2 | |
| import re | |
| import json | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from typing import List, Optional, Tuple | |
| import tempfile | |
| import gradio as gr | |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration | |
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

# -- Load model at module scope (ZeroGPU rule 2) ------------------------------
# The processor and the model are created once, at import time, and are then
# shared as module globals by every call to vlm_call().
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda")
# -- VLM call helper ----------------------------------------------------------
def vlm_call(images: List[Image.Image], question: str, system_prompt: str = "You are a highly strict UI navigation assistant designed to output JSON.") -> str:
    """Send images plus a text question to the module-level VLM and return its reply.

    Args:
        images: Images placed, in order, before the question in the user turn.
        question: Text appended after the images in the user turn.
        system_prompt: Text of the system turn.

    Returns:
        The decoded text of the newly generated tokens (the prompt tokens are
        sliced off before decoding).
    """
    # User turn: one entry per image, followed by the question text.
    user_content = [{"type": "image", "image": picture} for picture in images]
    user_content.append({"type": "text", "text": question})

    conversation = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_content},
    ]
    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    # The processor receives a batch of one prompt; with no images it gets None.
    batched_images = [images] if images else None
    model_inputs = processor(
        text=[chat_text],
        images=batched_images,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        generated = model.generate(
            **model_inputs, max_new_tokens=8192, do_sample=False, temperature=1.0
        )

    # generate() returns prompt + completion; keep only the completion part.
    prompt_len = model_inputs["input_ids"].shape[1]
    completions = processor.batch_decode(
        generated[:, prompt_len:], skip_special_tokens=True
    )
    return completions[0]
def parse_json_response(text: str):
    """Extract the first JSON object embedded in a text response.

    The text is scanned for ``{`` characters and a JSON decode is attempted at
    each one, so surrounding prose, markdown fences, or trailing text that
    itself contains braces does not prevent a valid object from being found.

    Args:
        text: Raw model output that may contain a JSON object.

    Returns:
        The first successfully decoded JSON object as a ``dict``, or ``None``
        when ``text`` is not a string or contains no decodable object.
    """
    if not isinstance(text, str):
        return None

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and ignores
            # whatever follows it in the string.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)
    return None
# -- Video utilities ----------------------------------------------------------
def extract_frame(video_path: str, frame_idx: int) -> Optional[Image.Image]:
    """Read a single frame from a video and return it as an RGB PIL image.

    Args:
        video_path: Path of the video file to open.
        frame_idx: Frame position to seek to before reading.

    Returns:
        The frame converted from BGR to RGB as a PIL image, or ``None`` when
        the read does not succeed.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
    finally:
        # Release the capture even if seeking or reading raises.
        cap.release()
    if not ret or frame is None:
        return None
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
def compute_color_histogram(img: Image.Image) -> np.ndarray:
    """Build a 50x50x50-bin histogram over the image's first three channels.

    The histogram is normalized in place by ``cv2.normalize`` (called with its
    default settings) before being returned.
    """
    pixels = np.array(img)
    bins_per_channel = [50, 50, 50]
    value_ranges = [0, 256, 0, 256, 0, 256]
    histogram = cv2.calcHist([pixels], [0, 1, 2], None, bins_per_channel, value_ranges)
    cv2.normalize(histogram, histogram)
    return histogram
def frame_similarity(hist1: np.ndarray, hist2: np.ndarray) -> float:
    """Return the ``cv2.HISTCMP_CORREL`` score of two histograms as a float."""
    correlation = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
    return float(correlation)
def is_frame_redundant(new_hist: np.ndarray, existing_hists: List[np.ndarray], threshold: float = 0.985) -> bool:
    """Tell whether a histogram matches any already-kept histogram.

    Args:
        new_hist: Histogram of the candidate frame.
        existing_hists: Histograms of frames that are already kept.
        threshold: Similarity at or above which the candidate is redundant.

    Returns:
        ``True`` as soon as one existing histogram reaches ``threshold``;
        ``False`` otherwise (including when ``existing_hists`` is empty).
    """
    return any(
        frame_similarity(new_hist, kept_hist) >= threshold
        for kept_hist in existing_hists
    )
# -- TASKER core: A* tree search keyframe extraction --------------------------
class VideoSeg:
    """A video segment (tree node) delimited by two frame indices."""

    def __init__(self, start: int, end: int):
        # Boundary frame indices of the segment.
        self.start, self.end = start, end
def find_visual_change_split_point(video_path: str, seg_start: int, seg_end: int) -> int:
    """Pick a split frame where sampled consecutive frames differ the most.

    Up to about ten frames are sampled across the segment (the end frame is
    always included). Adjacent samples are compared by colour-histogram
    correlation, and the later frame of the most different pair becomes the
    candidate split point.

    Args:
        video_path: Path of the video file to sample from.
        seg_start: First frame index of the segment.
        seg_end: Last frame index of the segment.

    Returns:
        The candidate frame index when it lies within the central 15%-85% of
        the segment; otherwise the segment midpoint. The midpoint is also
        returned for segments of length <= 2, when fewer than two samples can
        be read, and when any error occurs while sampling.
    """
    midpoint = (seg_start + seg_end) // 2
    seg_length = seg_end - seg_start
    if seg_length <= 2:
        return midpoint

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        num_samples = min(seg_length, 10)
        step = max(1, seg_length // num_samples)
        sample_indices = list(range(seg_start, seg_end, step))
        if sample_indices[-1] != seg_end:
            sample_indices.append(seg_end)

        # Only the histograms are needed for comparison, so decoded frames are
        # not retained.
        hists = {}
        for idx in sample_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                hist = cv2.calcHist([frame], [0, 1, 2], None, [50, 50, 50], [0, 256, 0, 256, 0, 256])
                cv2.normalize(hist, hist)
                hists[idx] = hist
        if len(hists) < 2:
            return midpoint

        sorted_indices = sorted(hists)
        max_diff = -1
        candidate = sorted_indices[-1]
        for idx_a, idx_b in zip(sorted_indices, sorted_indices[1:]):
            diff = 1.0 - cv2.compareHist(hists[idx_a], hists[idx_b], cv2.HISTCMP_CORREL)
            if diff > max_diff:
                max_diff = diff
                candidate = idx_b

        # Reject split points too close to either end of the segment.
        min_pos = seg_start + int(seg_length * 0.15)
        max_pos = seg_start + int(seg_length * 0.85)
        if candidate < min_pos or candidate > max_pos:
            return midpoint
        return candidate
    except Exception:
        # Best effort: any failure while sampling falls back to the midpoint.
        return midpoint
    finally:
        # Release the capture on every path, including the error path.
        if cap is not None:
            cap.release()
def a_star_select_segment(images: List[Image.Image], goal: str, segment_des: str) -> str:
    """Ask the VLM for the one segment that best balances goal relevance and visual change.

    Args:
        images: Currently sampled frames, in chronological order.
        goal: Task goal inserted into the prompt.
        segment_des: Pre-formatted list of candidate segments.

    Returns:
        The raw text reply from ``vlm_call``.
    """
    prompt_lines = [
        "You are provided with sequential images sampled from a video.",
        "Each image is labeled with its frame index. The images are shown in chronological order.",
        f"Goal: {goal}",
        "Candidate segments (gaps between current frames):",
        f"{segment_des}",
        "(A* Strategy - Balance missing goal-relevant info and visual state changes)",
        "Identify ONE single candidate segment that BEST satisfies BOTH conditions simultaneously:",
        "1. GOAL PROXIMITY: The segment likely contains crucial missing actions that are necessary steps toward achieving the Goal.",
        "2. STATE CHANGE MAGNITUDE: The segment whose boundary frames show the MOST different visual states is more likely to contain important operations.",
        'Return JSON format: {"frame_descriptions": [{"segment_id": "1", "description": "Best A* candidate: missing goal step + visual state change"}]}',
        "",  # the prompt ends with a newline
    ]
    return vlm_call(images, "\n".join(prompt_lines))
def qa_and_reflect(images: List[Image.Image], goal: str) -> Tuple[str, int]:
    """Describe the frame sequence, then have the VLM rate its continuity.

    Two VLM calls are made: the first produces a step-by-step description of
    the transitions between frames; the second feeds that description back and
    asks for a confidence level in JSON.

    Args:
        images: Currently sampled frames, in chronological order.
        goal: Task goal inserted into both prompts.

    Returns:
        ``(answer, confidence)`` where ``answer`` is the description text and
        ``confidence`` is the integer reported by the model. It defaults to 1
        when the reply has no parseable JSON, no ``"confidence"`` key, or a
        value that cannot be converted to an integer.
    """
    prompt_qa = f"Task Goal: {goal}\nLook at these sequential frames. Describe the EXACT step-by-step actions that happen transitioning from one frame to the next."
    answer = vlm_call(images, prompt_qa, system_prompt="You are a helpful video analysis assistant.")

    prompt_eval = (
        f"Task Goal: {goal}\n"
        f"Your sequential analysis: {answer}\n"
        "Evaluate your confidence level strictly:\n"
        "1: Severe Jumps (There are completely missing screens or sudden state changes. MUST expand.)\n"
        "2: Minor Disconnects (The flow makes sense, but some intermediate actions are missing. Should expand.)\n"
        "3: Strong Continuity (The frames capture all important actions and transitions. No key step is skipped.)\n"
        'Output JSON exactly like this: {"confidence": 3}\n'
    )
    conf_str = vlm_call(images, prompt_eval)
    conf_json = parse_json_response(conf_str)

    raw_confidence = conf_json.get("confidence", 1) if isinstance(conf_json, dict) else 1
    try:
        # Accept 3, "3" and "3.0"; anything non-numeric counts as the lowest
        # confidence instead of aborting the whole extraction.
        confidence = int(float(raw_confidence))
    except (TypeError, ValueError, OverflowError):
        confidence = 1
    return answer, confidence
def _initial_sample_indices(num_frames: int, init_frames: int = 4) -> List[int]:
    """Return evenly spaced frame indices that always include the first and last frame."""
    if num_frames <= init_frames:
        return list(range(num_frames))
    interval = max(1, num_frames // (init_frames - 1))
    indices = list(range(0, num_frames, interval))
    if indices[-1] != num_frames - 1:
        indices.append(num_frames - 1)
    return indices


def _request_segment_selection(images: List[Image.Image], goal: str, segment_des_str: str, search_strategy: str) -> str:
    """Send the strategy-specific segment-selection prompt to the VLM and return its reply."""
    header = "You are provided with sequential images sampled from a video.\n"
    json_hint = 'Return JSON: {"frame_descriptions": [{"segment_id": "1", "description": "..."}]}'
    if search_strategy == "bfs":
        return vlm_call(images, (
            f"{header}"
            f"Goal: {goal}\n"
            "Candidate segments:\n"
            f"{segment_des_str}\n"
            "Select MULTIPLE segments that likely contain crucial missing actions.\n"
            f"{json_hint}"
        ))
    if search_strategy == "gbfs":
        return vlm_call(images, (
            f"{header}"
            f"Goal: {goal}\n"
            "Candidate segments:\n"
            f"{segment_des_str}\n"
            "Select the SINGLE segment MOST LIKELY to contain crucial missing actions.\n"
            f"{json_hint}"
        ))
    if search_strategy == "dijkstra":
        # This prompt deliberately carries no goal line.
        return vlm_call(images, (
            f"{header}"
            "Candidate segments:\n"
            f"{segment_des_str}\n"
            "Select the SINGLE segment with the MOST significant visual state transition.\n"
            f"{json_hint}"
        ))
    # Any other value is handled as "a_star".
    return a_star_select_segment(images, goal, segment_des_str)


def _parse_selected_segment_ids(parsed, num_segments: int) -> set:
    """Collect the valid 1-based segment ids named in a parsed VLM reply."""
    selected = set()
    if not isinstance(parsed, dict):
        return selected
    descriptions = parsed.get("frame_descriptions")
    if not isinstance(descriptions, list):
        return selected
    for desc in descriptions:
        if not isinstance(desc, dict):
            continue
        for key, value in desc.items():
            if str(key).lower() != "segment_id":
                continue
            # Accept values such as 2, "2" or "Segment 2": use the first number.
            nums = re.findall(r'\d+', str(value).strip())
            if nums:
                seg_id = int(nums[0])
                if 1 <= seg_id <= num_segments:
                    selected.add(seg_id)
            break
    return selected


def _longest_open_segment_id(video_segments: List[VideoSeg], frozen_segments: set) -> Optional[int]:
    """Return the 1-based id of the longest splittable, non-frozen segment, or None."""
    best_id = None
    best_len = 1  # segments of length <= 1 cannot be split
    for seg_id, seg in enumerate(video_segments, start=1):
        seg_len = seg.end - seg.start
        if seg_len > best_len and (seg.start, seg.end) not in frozen_segments:
            best_id, best_len = seg_id, seg_len
    return best_id


# NOTE(review): `spaces` is imported at the top of this file, but no
# `@spaces.GPU` decorator is visible on any function here. ZeroGPU normally
# expects one on the GPU entry point - confirm whether it should be applied.
def extract_keyframes(video_path: str, goal: str, search_strategy: str = "a_star", max_frames: int = 10, min_frames: int = 6, min_steps: int = 3, conf_lower: int = 3, progress=gr.Progress()):
    """
    TASKER keyframe extraction: tree-search with VLM-guided segment selection.

    Starting from a small uniform sample, the gaps between sampled frames are
    treated as candidate segments. On each step the VLM picks segment(s) to
    expand, each picked segment is split at its largest visual change, and the
    new frame is kept only if its colour histogram is not redundant with the
    frames already kept. Segments whose split frame is rejected are frozen and
    not offered again.

    Args:
        video_path: Path to the input video.
        goal: Task query describing what the user wants to see.
        search_strategy: One of "a_star", "bfs", "gbfs", "dijkstra"; any other
            value is handled as "a_star".
        max_frames: Maximum number of keyframes to extract.
        min_frames: Minimum number of frames before confidence check can stop.
        min_steps: Number of effective expansion steps that must be exceeded
            before the confidence check runs.
        conf_lower: Confidence threshold (1-3) to stop searching.
        progress: Gradio progress tracker, called with a fraction and a
            description.

    Returns:
        ``(gallery, summary)``: a list of (PIL Image, caption) tuples for
        gallery display, plus a summary string. When the video cannot be read
        the gallery is empty and the summary holds an error message.
    """
    if not video_path:
        # Nothing was uploaded; avoid handing None to OpenCV.
        return [], "Error: Could not read video file. Please upload a valid video."

    # UI sliders may deliver numeric values that are not plain ints.
    max_frames, min_frames = int(max_frames), int(min_frames)
    min_steps, conf_lower = int(min_steps), int(conf_lower)

    cap = cv2.VideoCapture(video_path)
    try:
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()
    if num_frames <= 0 or fps <= 0:
        return [], "Error: Could not read video file. Please upload a valid video."

    # Per-call caches so each frame is decoded and histogrammed at most once.
    frame_cache = {}
    hist_cache = {}

    def get_frame(idx: int) -> Optional[Image.Image]:
        if idx not in frame_cache:
            frame_cache[idx] = extract_frame(video_path, idx)
        return frame_cache[idx]

    def get_hist(idx: int) -> Optional[np.ndarray]:
        if idx not in hist_cache:
            img = get_frame(idx)
            hist_cache[idx] = None if img is None else compute_color_histogram(img)
        return hist_cache[idx]

    # -- Initial uniform sampling ---------------------------------------------
    sample_idx = _initial_sample_indices(num_frames)
    progress(0.1, desc=f"Initial sampling: {len(sample_idx)} frames from {num_frames} total")
    video_segments = [VideoSeg(a, b) for a, b in zip(sample_idx, sample_idx[1:])]

    frozen_segments = set()  # (start, end) pairs that must not be split again
    effective_step = 0  # steps that actually added at least one frame
    last_confidence = 0
    max_total_attempts = max_frames + 10

    for attempt in range(1, max_total_attempts + 1):
        current_frames = len(sample_idx)
        if current_frames >= max_frames:
            break

        # Decode the current samples; only decodable frames are sent to the VLM.
        shown_frames = []
        images = []
        for idx in sample_idx:
            img = get_frame(idx)
            if img is not None:
                shown_frames.append(idx)
                images.append(img)
        if not images:
            break

        progress(
            0.1 + 0.6 * (attempt / max_total_attempts),
            desc=f"Step {attempt}: {current_frames} frames, evaluating..."
        )

        # Confidence check (runs only once more than min_steps effective steps
        # have happened and at least min_frames frames are sampled).
        if current_frames >= min_frames and effective_step > min_steps:
            _, confidence = qa_and_reflect(images, goal)
            last_confidence = confidence
            if confidence >= conf_lower:
                break

        # Build segment descriptions. Image numbers refer to positions in the
        # list of images actually sent; undecodable endpoints show as "?".
        frame_to_img_idx = {frame: pos for pos, frame in enumerate(shown_frames, start=1)}
        segment_des_lines = []
        for seg_id, seg in enumerate(video_segments, start=1):
            if (seg.start, seg.end) in frozen_segments:
                continue
            start_img = frame_to_img_idx.get(seg.start, "?")
            end_img = frame_to_img_idx.get(seg.end, "?")
            segment_des_lines.append(
                f" Segment {seg_id}: frames {seg.start}-{seg.end} (Image #{start_img} -> Image #{end_img})"
            )
        segment_des_str = "\n".join(segment_des_lines)
        if not segment_des_str:
            break

        # VLM segment selection (best effort: failures fall through to the
        # longest-segment fallback below).
        try:
            response = _request_segment_selection(images, goal, segment_des_str, search_strategy)
            parsed = parse_json_response(response)
        except Exception as e:
            print(f"VLM call error at step {attempt}: {e}")
            parsed = None

        # Keep only picks that can still yield a new frame: frozen segments
        # would reproduce the same rejected split point, and segments of
        # length <= 1 have no interior frame.
        selected_seg_ids = set()
        for sid in _parse_selected_segment_ids(parsed, len(video_segments)):
            picked = video_segments[sid - 1]
            if picked.end - picked.start > 1 and (picked.start, picked.end) not in frozen_segments:
                selected_seg_ids.add(sid)

        # Fallback: pick longest segment
        if not selected_seg_ids:
            fallback_id = _longest_open_segment_id(video_segments, frozen_segments)
            if fallback_id is None:
                break
            selected_seg_ids.add(fallback_id)

        # BFS quota limit: never split more segments than frames remain.
        if search_strategy == "bfs" and len(selected_seg_ids) > 1:
            remaining_quota = max_frames - len(sample_idx)
            if remaining_quota <= 0:
                break
            if len(selected_seg_ids) > remaining_quota:
                sorted_seg_ids = sorted(
                    selected_seg_ids,
                    key=lambda sid: video_segments[sid - 1].end - video_segments[sid - 1].start,
                    reverse=True,
                )
                selected_seg_ids = set(sorted_seg_ids[:remaining_quota])

        # Split selected segments
        split_origin = {}  # split frame -> (start, end) of the segment it came from
        new_segments = []
        for seg_id, seg in enumerate(video_segments, start=1):
            if seg_id in selected_seg_ids and seg.end - seg.start > 1:
                sp = find_visual_change_split_point(video_path, seg.start, seg.end)
                split_origin[sp] = (seg.start, seg.end)
                new_segments.append(VideoSeg(seg.start, sp))
                new_segments.append(VideoSeg(sp, seg.end))
            else:
                new_segments.append(VideoSeg(seg.start, seg.end))
        video_segments = new_segments

        # Rebuild sample_idx
        new_sample_idx = sorted({bound for seg in video_segments for bound in (seg.start, seg.end)})

        # Visual deduplication of the newly introduced frames
        known_frames = set(sample_idx)
        new_frames = [idx for idx in new_sample_idx if idx not in known_frames]
        old_hists = [h for h in (get_hist(idx) for idx in sample_idx) if h is not None]

        rejected = set()
        accepted_new_hists = []
        for new_idx in new_frames:
            new_hist = get_hist(new_idx)
            # Undecodable split frames are dropped as well: they could never
            # be shown to the VLM or appear in the gallery.
            if new_hist is None or is_frame_redundant(new_hist, old_hists + accepted_new_hists, threshold=0.985):
                rejected.add(new_idx)
                frame_cache.pop(new_idx, None)  # do not keep rejected images alive
                if new_idx in split_origin:
                    frozen_segments.add(split_origin[new_idx])
            else:
                accepted_new_hists.append(new_hist)

        if rejected:
            new_sample_idx = [idx for idx in new_sample_idx if idx not in rejected]
            video_segments = [VideoSeg(a, b) for a, b in zip(new_sample_idx, new_sample_idx[1:])]

        actually_added = len(new_sample_idx) > len(sample_idx)
        sample_idx = new_sample_idx
        if actually_added:
            effective_step += 1

    progress(0.85, desc="Finalizing keyframes...")

    # Force-fill if too few frames
    if len(sample_idx) < min_frames and last_confidence < conf_lower:
        max_force = min_frames + 5
        for _ in range(max_force):
            if len(sample_idx) >= min_frames:
                break
            # Find the widest gap that has not been frozen.
            max_gap = 0
            max_gap_idx = 0
            for i, (left, right) in enumerate(zip(sample_idx, sample_idx[1:])):
                if (left, right) in frozen_segments:
                    continue
                gap = right - left
                if gap > max_gap:
                    max_gap = gap
                    max_gap_idx = i
            if max_gap <= 1:
                break
            left, right = sample_idx[max_gap_idx], sample_idx[max_gap_idx + 1]
            sp = find_visual_change_split_point(video_path, left, right)
            sp_hist = get_hist(sp)
            if sp_hist is None:
                break
            existing_hists = [h for h in (get_hist(idx) for idx in sample_idx) if h is not None]
            if is_frame_redundant(sp_hist, existing_hists, threshold=0.985):
                frozen_segments.add((left, right))
                frame_cache.pop(sp, None)
                continue
            sample_idx.insert(max_gap_idx + 1, sp)

    # Extract final keyframes
    progress(0.95, desc="Extracting final keyframes...")
    gallery = []
    for position, idx in enumerate(sample_idx, start=1):
        img = get_frame(idx)
        if img is None:
            continue
        timestamp = idx / fps  # fps > 0 is guaranteed by the check above
        mins = int(timestamp // 60)
        secs = int(timestamp % 60)
        percent = (idx / max(1, num_frames - 1)) * 100
        caption = f"Frame {position}/{len(sample_idx)} | idx={idx} | {mins:02d}:{secs:02d} | {percent:.1f}%"
        gallery.append((img, caption))

    summary = (
        f"**TASKER {search_strategy.upper()}** extracted **{len(gallery)}** keyframes "
        f"from {num_frames} total frames ({num_frames/fps:.1f}s video).\n\n"
        f"Search stats: {effective_step} effective expansion steps, "
        f"confidence={last_confidence}/3, "
        f"target range {min_frames}-{max_frames} frames."
    )
    progress(1.0, desc="Done!")
    return gallery, summary
# -- Gradio UI ----------------------------------------------------------------
# Page-level CSS; the "#header" selectors target the <div id="header"> emitted
# by the gr.HTML block below.
CUSTOM_CSS = """
#header { text-align: center; margin-bottom: 20px; }
#header h1 { font-size: 2em; margin-bottom: 5px; }
#header p { color: #666; font-size: 1.1em; }
"""
with gr.Blocks(css=CUSTOM_CSS, title="TASKER Keyframe Extractor") as demo:
    gr.HTML("""
<div id="header">
<h1>TASKER: Task-driven and Scene-aware Keyframe Search</h1>
<p>Extract task-relevant keyframes from a video using VLM-guided tree search (A* / BFS / GBFS / Dijkstra)</p>
</div>
""")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video", sources=["upload"])
            goal_input = gr.Textbox(
                label="Task Query / Goal",
                placeholder="e.g., How to send an email with an attachment?",
                lines=2,
            )
            strategy_input = gr.Dropdown(
                choices=["a_star", "bfs", "gbfs", "dijkstra"],
                value="a_star",
                label="Search Strategy",
                info="A* balances goal-relevance and visual changes. BFS explores broadly. GBFS focuses on goal. Dijkstra focuses on visual changes.",
            )
            # Slider defaults mirror the keyword defaults of extract_keyframes
            # (max_frames=10, min_frames=6, min_steps=3, conf_lower=3).
            with gr.Accordion("Advanced Settings", open=False):
                max_frames_slider = gr.Slider(4, 16, value=10, step=1, label="Max Keyframes")
                min_frames_slider = gr.Slider(2, 8, value=6, step=1, label="Min Keyframes (before confidence check)")
                min_steps_slider = gr.Slider(1, 8, value=3, step=1, label="Min Search Steps")
                conf_slider = gr.Slider(1, 3, value=3, step=1, label="Confidence Threshold (3=strictest)")
            extract_btn = gr.Button("Extract Keyframes", variant="primary")
        # Right column: results.
        with gr.Column(scale=2):
            summary_output = gr.Markdown(label="Summary")
            gallery_output = gr.Gallery(
                label="Extracted Keyframes",
                columns=3,
                height=600,
                object_fit="contain",
            )
    # The inputs are listed in the positional order of extract_keyframes'
    # parameters (video_path, goal, search_strategy, max_frames, min_frames,
    # min_steps, conf_lower); the outputs match its (gallery, summary) return.
    extract_btn.click(
        fn=extract_keyframes,
        inputs=[
            video_input,
            goal_input,
            strategy_input,
            max_frames_slider,
            min_frames_slider,
            min_steps_slider,
            conf_slider,
        ],
        outputs=[gallery_output, summary_output],
    )
# Start the app when this module is executed.
demo.launch()