# Uploaded to the Hugging Face Hub by multimodalart (HF Staff) using huggingface_hub.
# Commit 5c8e53c (verified), file size 22.8 kB.
import os
# Expandable segments to avoid allocator fragmentation under memory spikes
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces # MUST be before any torch/CUDA import
import cv2
import re
import json
import torch
import numpy as np
from PIL import Image
from typing import List, Optional, Tuple
import tempfile
import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Hub identifier of the vision-language model used for every VLM query below.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# ── Load model at module scope (ZeroGPU rule 2) ──────────────────────────────
# The processor and the model are created once at import time and are shared
# by every call to vlm_call().
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Weights are loaded in bfloat16 with the "sdpa" attention implementation and
# then moved to the "cuda" device.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to("cuda")
# ── VLM call helper ──────────────────────────────────────────────────────────
def vlm_call(images: List[Image.Image], question: str, system_prompt: str = "You are a highly strict UI navigation assistant designed to output JSON.") -> str:
    """Ask the module-level VLM a question about some images and return its reply.

    Args:
        images: Images placed, in order, before the question in the user turn.
        question: Text of the user turn.
        system_prompt: Text of the system turn.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off before decoding).
    """
    # One user message: every image first, then the question text.
    user_content = [{"type": "image", "image": picture} for picture in images]
    user_content.append({"type": "text", "text": question})
    chat = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_content},
    ]
    rendered = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # A batch of one prompt; the image list is wrapped to match, or omitted when empty.
    batch_images = [images] if images else None
    model_inputs = processor(
        text=[rendered],
        images=batch_images,
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    with torch.no_grad():
        generated = model.generate(**model_inputs, max_new_tokens=8192, do_sample=False, temperature=1.0)
    # generate() returns prompt + completion; keep only what follows the prompt.
    prompt_len = model_inputs["input_ids"].shape[1]
    completions = processor.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)
    return completions[0]
def parse_json_response(text: str):
    """Extract the first valid JSON object embedded in a text response.

    The text is scanned left to right; at every ``{`` a JSON decode is
    attempted and the first position that yields a complete object wins.
    Unlike a greedy ``{.*}`` match, this tolerates surrounding prose, code
    fences, several objects in one reply and stray braces after the object.

    Args:
        text: Raw model output. Non-string values are treated as "no JSON".

    Returns:
        The decoded ``dict``, or ``None`` when no JSON object can be decoded.
    """
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    brace = text.find("{")
    while brace != -1:
        try:
            # raw_decode parses exactly one JSON value starting at `brace`
            # and ignores whatever follows it.
            candidate, _ = decoder.raw_decode(text, brace)
        except (ValueError, RecursionError):
            # Not a decodable object at this brace; try the next one.
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        brace = text.find("{", brace + 1)
    return None
# ── Video utilities ──────────────────────────────────────────────────────────
def extract_frame(video_path: str, frame_idx: int) -> Optional[Image.Image]:
    """Decode one frame of a video and return it as an RGB PIL image.

    Args:
        video_path: Path of the video file to open.
        frame_idx: Index of the frame to seek to before reading.

    Returns:
        The frame converted from OpenCV's BGR order to RGB, or ``None`` when
        the read does not succeed.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ok, frame_bgr = cap.read()
    finally:
        # Release the capture even if seeking or decoding raises.
        cap.release()
    if not ok or frame_bgr is None:
        return None
    return Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
def compute_color_histogram(img: Image.Image) -> np.ndarray:
    """Build a joint 3-channel colour histogram of an image.

    The histogram has 50 bins per channel over the value range [0, 256) and is
    normalised in place with ``cv2.normalize`` (default settings) before being
    returned.
    """
    pixels = np.array(img)
    bins_per_channel = [50, 50, 50]
    value_ranges = [0, 256, 0, 256, 0, 256]
    histogram = cv2.calcHist([pixels], [0, 1, 2], None, bins_per_channel, value_ranges)
    cv2.normalize(histogram, histogram)
    return histogram
def frame_similarity(hist1: np.ndarray, hist2: np.ndarray) -> float:
    """Return the correlation score (``cv2.HISTCMP_CORREL``) of two histograms as a float."""
    correlation = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
    return float(correlation)
def is_frame_redundant(new_hist: np.ndarray, existing_hists: List[np.ndarray], threshold: float = 0.985) -> bool:
    """Tell whether a histogram is a near-duplicate of any already-kept one.

    Returns ``True`` as soon as one entry of ``existing_hists`` reaches a
    similarity of at least ``threshold`` with ``new_hist``; ``False`` when none
    does (including when ``existing_hists`` is empty).
    """
    return any(
        frame_similarity(new_hist, known) >= threshold
        for known in existing_hists
    )
# ── TASKER core: A* tree search keyframe extraction ─────────────────────────
class VideoSeg:
    """A video segment (a node of the search tree), delimited by two frame indices."""

    def __init__(self, start: int, end: int):
        # Both bounds are stored exactly as given; nothing is validated or reordered.
        self.start, self.end = start, end
def find_visual_change_split_point(video_path: str, seg_start: int, seg_end: int) -> int:
    """Pick the frame inside a segment where the picture changes the most.

    The segment is sampled at a fixed stride (plus its last frame). Colour
    histograms of consecutive readable samples are compared and the later
    frame of the most dissimilar pair becomes the candidate split point.

    Args:
        video_path: Path of the video file.
        seg_start: First frame index of the segment.
        seg_end: Last frame index of the segment.

    Returns:
        The candidate frame index when it lies within the central 15%-85% of
        the segment; otherwise the segment midpoint. The midpoint is also
        returned for segments of length <= 2, when fewer than two samples can
        be read, and when anything goes wrong while scanning (best effort).
    """
    midpoint = (seg_start + seg_end) // 2
    cap = None
    try:
        seg_length = seg_end - seg_start
        if seg_length <= 2:
            return midpoint
        cap = cv2.VideoCapture(video_path)
        num_samples = min(seg_length, 10)
        step = max(1, seg_length // num_samples)
        sample_indices = list(range(seg_start, seg_end, step))
        if sample_indices[-1] != seg_end:
            sample_indices.append(seg_end)
        # Only the histograms are kept; the decoded frames themselves are
        # dropped right away to avoid holding several full frames in memory.
        hists = {}
        for idx in sample_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                hist = cv2.calcHist([frame], [0, 1, 2], None, [50, 50, 50], [0, 256, 0, 256, 0, 256])
                cv2.normalize(hist, hist)
                hists[idx] = hist
        if len(hists) < 2:
            return midpoint
        readable = sorted(hists)
        max_diff = -1
        # Default when no pair beats max_diff (e.g. every comparison is NaN).
        candidate = readable[-1]
        for idx_a, idx_b in zip(readable, readable[1:]):
            diff = 1.0 - cv2.compareHist(hists[idx_a], hists[idx_b], cv2.HISTCMP_CORREL)
            if diff > max_diff:
                max_diff = diff
                candidate = idx_b
        # Reject candidates too close to either end of the segment.
        min_pos = seg_start + int(seg_length * 0.15)
        max_pos = seg_start + int(seg_length * 0.85)
        if candidate < min_pos or candidate > max_pos:
            return midpoint
        return candidate
    except Exception as exc:
        # Best effort: a failed scan must not abort the search, but say why.
        print(f"Split-point search failed for frames {seg_start}-{seg_end}: {exc}")
        return midpoint
    finally:
        # Runs on every exit path, so the capture is never leaked.
        if cap is not None:
            cap.release()
def a_star_select_segment(images: List[Image.Image], goal: str, segment_des: str) -> str:
    """Ask the VLM to choose one segment using the A* criterion.

    The prompt asks for the single candidate that best combines relevance to
    the goal with the magnitude of the visual state change across the segment.

    Args:
        images: Current keyframes, in chronological order.
        goal: Task description inserted into the prompt.
        segment_des: Pre-formatted list of candidate segments.

    Returns:
        The raw text returned by ``vlm_call``; the prompt requests JSON with a
        ``frame_descriptions`` list.
    """
    selection_prompt = f"""You are provided with sequential images sampled from a video.
Each image is labeled with its frame index. The images are shown in chronological order.
Goal: {goal}
Candidate segments (gaps between current frames):
{segment_des}
(A* Strategy - Balance missing goal-relevant info and visual state changes)
Identify ONE single candidate segment that BEST satisfies BOTH conditions simultaneously:
1. GOAL PROXIMITY: The segment likely contains crucial missing actions that are necessary steps toward achieving the Goal.
2. STATE CHANGE MAGNITUDE: The segment whose boundary frames show the MOST different visual states is more likely to contain important operations.
Return JSON format: {{"frame_descriptions": [{{"segment_id": "1", "description": "Best A* candidate: missing goal step + visual state change"}}]}}
"""
    return vlm_call(images, selection_prompt)
def qa_and_reflect(images: List[Image.Image], goal: str) -> Tuple[str, int]:
    """Evaluate whether the current frames are sufficient.

    Two VLM calls are made: the first describes the step-by-step actions
    between consecutive frames, the second rates how continuous that
    description is on a 1-3 scale.

    Args:
        images: Current keyframes, in chronological order.
        goal: Task description inserted into both prompts.

    Returns:
        ``(answer, confidence)`` where ``answer`` is the free-text description
        and ``confidence`` is an int in 1..3. A reply that is missing,
        unparsable or non-numeric counts as 1 (the "must expand" level), and
        out-of-range numbers are clamped into 1..3.
    """
    prompt_qa = f"Task Goal: {goal}\nLook at these sequential frames. Describe the EXACT step-by-step actions that happen transitioning from one frame to the next."
    answer = vlm_call(images, prompt_qa, system_prompt="You are a helpful video analysis assistant.")
    prompt_eval = f"""Task Goal: {goal}
Your sequential analysis: {answer}
Evaluate your confidence level strictly:
1: Severe Jumps (There are completely missing screens or sudden state changes. MUST expand.)
2: Minor Disconnects (The flow makes sense, but some intermediate actions are missing. Should expand.)
3: Strong Continuity (The frames capture all important actions and transitions. No key step is skipped.)
Output JSON exactly like this: {{"confidence": 3}}
"""
    conf_str = vlm_call(images, prompt_eval)
    conf_json = parse_json_response(conf_str)
    raw_confidence = conf_json.get("confidence", 1) if isinstance(conf_json, dict) else 1
    try:
        confidence = int(raw_confidence)
    except (TypeError, ValueError, OverflowError):
        # e.g. null, "high" or a nested value: fall back to the lowest level
        # instead of letting the whole extraction crash.
        confidence = 1
    return answer, max(1, min(3, confidence))
@spaces.GPU(duration=240)
def extract_keyframes(video_path: str, goal: str, search_strategy: str = "a_star", max_frames: int = 10, min_frames: int = 6, min_steps: int = 3, conf_lower: int = 3, progress=gr.Progress()):
    """
    TASKER keyframe extraction: tree-search with VLM-guided segment selection.

    Starting from a few uniformly sampled frames, the gaps between neighbouring
    frames are treated as candidate segments. At every step the VLM picks the
    segment(s) to expand, each picked segment is split at its point of largest
    visual change, and the new frame is kept unless its colour histogram is a
    near-duplicate of a frame already held (the segment is then frozen and not
    offered again). The search stops when ``max_frames`` is reached, when the
    confidence check passes, or when nothing is left to expand.

    Args:
        video_path: Path to the input video. ``None``/empty yields an error
            summary instead of an exception.
        goal: Task query describing what the user wants to see.
        search_strategy: One of "a_star", "bfs", "gbfs", "dijkstra". Any other
            value is handled like "a_star".
        max_frames: Maximum number of keyframes to extract.
        min_frames: Minimum number of frames before confidence check can stop.
        min_steps: Minimum expansion steps before confidence check can stop.
        conf_lower: Confidence threshold (1-3) to stop searching.
        progress: Gradio progress tracker used for status updates.

    Returns:
        ``(gallery, summary)``: a list of (PIL Image, caption) tuples for
        gallery display, plus a Markdown summary string. On unreadable input
        the list is empty and the summary carries the error message.
    """
    read_error = "Error: Could not read video file. Please upload a valid video."
    # video_path may be None or empty (for example when no file was provided);
    # report that the same way as an unreadable file rather than raising.
    if not video_path:
        return [], read_error
    # range() and the frame arithmetic below need ints; tolerate numeric
    # settings that arrive as floats.
    max_frames = int(max_frames)
    min_frames = int(min_frames)
    min_steps = int(min_steps)
    conf_lower = int(conf_lower)

    cap = cv2.VideoCapture(video_path)
    try:
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()
    if num_frames <= 0 or fps <= 0:
        return [], read_error

    # ── Initial uniform sampling ─────────────────────────────────────────────
    init_frames = 4
    content_start = 0
    content_end = num_frames - 1
    if content_end - content_start + 1 <= init_frames:
        sample_idx = list(range(content_start, content_end + 1))
    else:
        interval = max(1, (content_end - content_start + 1) // (init_frames - 1))
        sample_idx = list(range(content_start, content_end + 1, interval))
        if sample_idx[-1] != content_end:
            sample_idx.append(content_end)
    progress(0.1, desc=f"Initial sampling: {len(sample_idx)} frames from {num_frames} total")
    video_segments = [VideoSeg(a, b) for a, b in zip(sample_idx, sample_idx[1:])]

    # Histogram cache for dedup: frame index -> histogram (None if unreadable).
    hist_cache = {}

    def frame_histogram(idx, img=None):
        """Return the cached histogram of frame ``idx``, decoding it only once."""
        if idx not in hist_cache:
            if img is None:
                img = extract_frame(video_path, idx)
            hist_cache[idx] = compute_color_histogram(img) if img is not None else None
        return hist_cache[idx]

    # (start, end) pairs whose split point turned out to be a near-duplicate.
    frozen_segments = set()
    effective_step = 0
    last_confidence = 0
    max_total_attempts = max_frames + 10
    for attempt in range(1, max_total_attempts + 1):
        current_frames = len(sample_idx)
        if current_frames >= max_frames:
            break

        # Extract current frames as images. image_number maps a frame index to
        # the 1-based position of its image in `images`, so the labels in the
        # prompt stay correct even when some frame cannot be decoded.
        images = []
        image_number = {}
        for idx in sample_idx:
            img = extract_frame(video_path, idx)
            if img is not None:
                images.append(img)
                image_number[idx] = len(images)
                frame_histogram(idx, img)
        if not images:
            break
        progress(
            0.1 + 0.6 * (attempt / max_total_attempts),
            desc=f"Step {attempt}: {current_frames} frames, evaluating..."
        )

        # Confidence check.
        # NOTE(review): the strict '>' means min_steps + 1 effective expansions
        # are needed before the check runs; confirm this is intended.
        if current_frames >= min_frames and effective_step > min_steps:
            _, confidence = qa_and_reflect(images, goal)
            last_confidence = confidence
            if confidence >= conf_lower:
                break

        # Build segment descriptions (frozen segments are not offered).
        segment_des_lines = []
        for seg_id, seg in enumerate(video_segments, start=1):
            if (seg.start, seg.end) in frozen_segments:
                continue
            start_img = image_number.get(seg.start, "?")
            end_img = image_number.get(seg.end, "?")
            segment_des_lines.append(
                f" Segment {seg_id}: frames {seg.start}-{seg.end} (Image #{start_img} -> Image #{end_img})"
            )
        segment_des_str = "\n".join(segment_des_lines)
        if not segment_des_str:
            break

        # VLM segment selection
        try:
            if search_strategy == "bfs":
                response = vlm_call(images, f"""You are provided with sequential images sampled from a video.
Goal: {goal}
Candidate segments:
{segment_des_str}
Select MULTIPLE segments that likely contain crucial missing actions.
Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""")
            elif search_strategy == "gbfs":
                response = vlm_call(images, f"""You are provided with sequential images sampled from a video.
Goal: {goal}
Candidate segments:
{segment_des_str}
Select the SINGLE segment MOST LIKELY to contain crucial missing actions.
Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""")
            elif search_strategy == "dijkstra":
                response = vlm_call(images, f"""You are provided with sequential images sampled from a video.
Candidate segments:
{segment_des_str}
Select the SINGLE segment with the MOST significant visual state transition.
Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""")
            else:  # a_star
                response = a_star_select_segment(images, goal, segment_des_str)
            parsed = parse_json_response(response)
        except Exception as e:
            print(f"VLM call error at step {attempt}: {e}")
            parsed = None

        # Only segments that can still yield a new frame may be expanded. With
        # greedy decoding, accepting a frozen or 1-frame segment would repeat
        # the identical no-op step until the attempt budget is exhausted.
        expandable_ids = {
            seg_id
            for seg_id, seg in enumerate(video_segments, start=1)
            if seg.end - seg.start > 1 and (seg.start, seg.end) not in frozen_segments
        }

        # Determine selected segment IDs (malformed entries are ignored).
        selected_seg_ids = set()
        descriptions = parsed.get("frame_descriptions") if isinstance(parsed, dict) else None
        if isinstance(descriptions, list):
            for desc in descriptions:
                if not isinstance(desc, dict):
                    continue
                for key in desc:
                    if key.lower() == "segment_id":
                        nums = re.findall(r'\d+', str(desc[key]).strip())
                        if nums and int(nums[0]) in expandable_ids:
                            selected_seg_ids.add(int(nums[0]))
                        break

        # Fallback: pick longest expandable segment (earliest one on ties).
        if not selected_seg_ids and expandable_ids:
            selected_seg_ids.add(max(
                expandable_ids,
                key=lambda sid: (video_segments[sid-1].end - video_segments[sid-1].start, -sid),
            ))
        if not selected_seg_ids:
            break

        # BFS quota limit: never split more segments than frames still allowed,
        # preferring the longest ones.
        if search_strategy == "bfs" and len(selected_seg_ids) > 1:
            remaining_quota = max_frames - len(sample_idx)
            if len(selected_seg_ids) > remaining_quota:
                sorted_seg_ids = sorted(selected_seg_ids,
                                        key=lambda sid: video_segments[sid-1].end - video_segments[sid-1].start,
                                        reverse=True)
                selected_seg_ids = set(sorted_seg_ids[:remaining_quota])

        # Split selected segments; split_origin remembers which segment each
        # new frame came from so that it can be frozen if the frame is rejected.
        split_origin = {}
        new_segments = []
        for seg_id, seg in enumerate(video_segments, start=1):
            if seg_id in selected_seg_ids:
                sp = find_visual_change_split_point(video_path, seg.start, seg.end)
                split_origin[sp] = (seg.start, seg.end)
                new_segments.append(VideoSeg(seg.start, sp))
                new_segments.append(VideoSeg(sp, seg.end))
            else:
                new_segments.append(VideoSeg(seg.start, seg.end))
        video_segments = new_segments

        # Rebuild sample_idx from the segment boundaries.
        new_sample_idx = sorted({bound for seg in video_segments for bound in (seg.start, seg.end)})

        # Visual deduplication of the frames introduced in this step.
        previous_frames = set(sample_idx)
        new_frames = [idx for idx in new_sample_idx if idx not in previous_frames]
        old_hists = [h for h in (frame_histogram(idx) for idx in sample_idx) if h is not None]
        frames_to_remove = set()
        accepted_new_hists = []
        for new_idx in new_frames:
            new_hist = frame_histogram(new_idx)
            if new_hist is None:
                continue
            if is_frame_redundant(new_hist, old_hists + accepted_new_hists, threshold=0.985):
                frames_to_remove.add(new_idx)
                if new_idx in split_origin:
                    frozen_segments.add(split_origin[new_idx])
            else:
                accepted_new_hists.append(new_hist)
        if frames_to_remove:
            new_sample_idx = [idx for idx in new_sample_idx if idx not in frames_to_remove]
            video_segments = [VideoSeg(a, b) for a, b in zip(new_sample_idx, new_sample_idx[1:])]

        # A step is "effective" only if it actually grew the frame set.
        if len(new_sample_idx) > len(sample_idx):
            effective_step += 1
        sample_idx = new_sample_idx

    progress(0.85, desc="Finalizing keyframes...")

    # Force-fill if too few frames: repeatedly split the widest unfrozen gap.
    if len(sample_idx) < min_frames and last_confidence < conf_lower:
        max_force = min_frames + 5
        for _ in range(max_force):
            if len(sample_idx) >= min_frames:
                break
            max_gap = 0
            max_gap_idx = 0
            for i in range(len(sample_idx) - 1):
                if (sample_idx[i], sample_idx[i+1]) in frozen_segments:
                    continue
                gap = sample_idx[i+1] - sample_idx[i]
                if gap > max_gap:
                    max_gap = gap
                    max_gap_idx = i
            if max_gap <= 1:
                break
            gap_start = sample_idx[max_gap_idx]
            gap_end = sample_idx[max_gap_idx + 1]
            sp = find_visual_change_split_point(video_path, gap_start, gap_end)
            sp_hist = frame_histogram(sp)
            if sp_hist is None:
                break
            existing_hists = [h for h in (frame_histogram(idx) for idx in sample_idx) if h is not None]
            if is_frame_redundant(sp_hist, existing_hists, threshold=0.985):
                frozen_segments.add((gap_start, gap_end))
                continue
            sample_idx.insert(max_gap_idx + 1, sp)

    # Extract final keyframes
    progress(0.95, desc="Extracting final keyframes...")
    gallery = []
    for i, idx in enumerate(sample_idx):
        img = extract_frame(video_path, idx)
        if img is not None:
            timestamp = idx / fps
            mins = int(timestamp // 60)
            secs = int(timestamp % 60)
            percent = (idx / max(1, num_frames - 1)) * 100
            caption = f"Frame {i+1}/{len(sample_idx)} | idx={idx} | {mins:02d}:{secs:02d} | {percent:.1f}%"
            gallery.append((img, caption))
    summary = (
        f"**TASKER {search_strategy.upper()}** extracted **{len(gallery)}** keyframes "
        f"from {num_frames} total frames ({num_frames/fps:.1f}s video).\n\n"
        f"Search stats: {effective_step} effective expansion steps, "
        f"confidence={last_confidence}/3, "
        f"target range {min_frames}-{max_frames} frames."
    )
    progress(1.0, desc="Done!")
    return gallery, summary
# ── Gradio UI ───────────────────────────────────────────────────────────────
# Page-level CSS: centres the header block and sizes its title and subtitle.
CUSTOM_CSS = """
#header { text-align: center; margin-bottom: 20px; }
#header h1 { font-size: 2em; margin-bottom: 5px; }
#header p { color: #666; font-size: 1.1em; }
"""
# The whole interface is declared inside one Blocks context bound to `demo`.
with gr.Blocks(css=CUSTOM_CSS, title="TASKER Keyframe Extractor") as demo:
    # Static header; the id="header" div is what CUSTOM_CSS targets.
    gr.HTML("""
<div id="header">
<h1>TASKER: Task-driven and Scene-aware Keyframe Search</h1>
<p>Extract task-relevant keyframes from a video using VLM-guided tree search (A* / BFS / GBFS / Dijkstra)</p>
</div>
""")
    with gr.Row():
        # Left column: inputs and search settings.
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video", sources=["upload"])
            goal_input = gr.Textbox(
                label="Task Query / Goal",
                placeholder="e.g., How to send an email with an attachment?",
                lines=2,
            )
            # The choices are the strategy names that extract_keyframes()
            # branches on.
            strategy_input = gr.Dropdown(
                choices=["a_star", "bfs", "gbfs", "dijkstra"],
                value="a_star",
                label="Search Strategy",
                info="A* balances goal-relevance and visual changes. BFS explores broadly. GBFS focuses on goal. Dijkstra focuses on visual changes.",
            )
            # Slider defaults mirror the keyword defaults of extract_keyframes().
            with gr.Accordion("Advanced Settings", open=False):
                max_frames_slider = gr.Slider(4, 16, value=10, step=1, label="Max Keyframes")
                min_frames_slider = gr.Slider(2, 8, value=6, step=1, label="Min Keyframes (before confidence check)")
                min_steps_slider = gr.Slider(1, 8, value=3, step=1, label="Min Search Steps")
                conf_slider = gr.Slider(1, 3, value=3, step=1, label="Confidence Threshold (3=strictest)")
            extract_btn = gr.Button("Extract Keyframes", variant="primary")
        # Right column: results.
        with gr.Column(scale=2):
            summary_output = gr.Markdown(label="Summary")
            gallery_output = gr.Gallery(
                label="Extracted Keyframes",
                columns=3,
                height=600,
                object_fit="contain",
            )
    # Inputs are listed in the positional order of extract_keyframes()'s
    # parameters; its (gallery, summary) return value fills the two outputs.
    extract_btn.click(
        fn=extract_keyframes,
        inputs=[
            video_input,
            goal_input,
            strategy_input,
            max_frames_slider,
            min_frames_slider,
            min_steps_slider,
            conf_slider,
        ],
        outputs=[gallery_output, summary_output],
    )
# Start the app as soon as the module is executed.
demo.launch()