# Tempo / app.py
# visioncairgroup
# Update app.py
# c149b09 verified
import io
import os
import time
import torch
import numpy as np
import gradio as gr
import multiprocessing
from decord import cpu, VideoReader
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.colors as mcolors
from scipy.interpolate import make_interp_spline
from PIL import Image
from tempo.builder import load_pretrained_model
from tempo.conversation import conv_templates, SeparatorStyle
from tempo.constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
IMAGE_TOKEN_INDEX,
)
from tempo.mm_datautils import (
compute_segment_timestamp,
KeywordsStoppingCriteria,
process_qwen_content,
tokenizer_image_token,
)
import spaces
from huggingface_hub import snapshot_download
def get_real_cpu_cores():
    """Return the number of CPU cores available for video decoding threads.

    Uses the process affinity mask when the platform exposes
    ``os.sched_getaffinity`` (e.g. HF Spaces), otherwise falls back to the
    machine-wide CPU count.
    """
    affinity_fn = getattr(os, "sched_getaffinity", None)
    if affinity_fn is None:
        # Local environments where the affinity API does not exist
        return multiprocessing.cpu_count()
    # HF Spaces: count only the CPUs this process is allowed to run on
    return len(affinity_fn(0))
def compute_sample_indices(
    total_frames: int,
    original_fps: float,
    video_fps: float = 2.0,
    min_frames_num: int = 4,
    max_frames_num: int = 1024
) -> list[int]:
    """Pick evenly spaced frame indices that approximate ``video_fps``.

    Args:
        total_frames: Number of frames in the clip.
        original_fps: Native frame rate of the clip; ``None`` or a
            non-positive value falls back to ``video_fps``.
        video_fps: Desired sampling rate in frames per second.
        min_frames_num: Lower bound on the number of sampled frames.
        max_frames_num: Upper bound on the number of sampled frames
            (applied after the lower bound, so it wins on conflict).

    Returns:
        Frame indices in ``[0, total_frames - 1]``. Indices may repeat when
        the clip has fewer frames than ``min_frames_num``.
    """
    first, last = 0, total_frames - 1
    span = last - first + 1

    # Nothing to spread over: a clip with at most one frame yields the first index.
    if span <= 1:
        return [first]

    source_fps = video_fps if (original_fps is None or original_fps <= 0) else original_fps
    duration_s = span / source_fps

    wanted = max(1, round(duration_s * video_fps))
    count = min(max(wanted, min_frames_num), max_frames_num)
    if count == 1:
        return [last]

    positions = np.linspace(first, last, count)
    picked = np.clip(np.round(positions).astype(int), first, last)
    return picked.tolist()
def load_video(video_path: str, video_fps: float = 2.0, max_frames: int = 1024) -> tuple:
    """Decode a video and return uniformly sampled frames plus the effective FPS.

    Args:
        video_path: Path to the video file to decode.
        video_fps: Target sampling rate in frames per second.
        max_frames: Upper bound on the number of frames to sample.

    Returns:
        ``(images, real_fps)`` where ``images`` is the array returned by
        ``vr.get_batch(...).asnumpy()`` and ``real_fps`` is the number of
        sampled frames divided by the clip duration.

    Raises:
        ValueError: If the video contains no frames.
    """
    available_cores = get_real_cpu_cores()
    # Leave one core free and cap the decoder at 16 threads.
    optimal_threads = min(max(1, available_cores - 1), 16)
    print(f"[Profiling] Detected {available_cores} CPU cores. Decord using {optimal_threads} threads.")

    vr = VideoReader(video_path, ctx=cpu(0), num_threads=optimal_threads)
    total_frames = len(vr)
    if total_frames <= 0:
        raise ValueError("Video contains no decodable frames.")

    original_fps = vr.get_avg_fps()
    # Some containers report a missing/zero/NaN frame rate. Apply the same
    # fallback as compute_sample_indices so the duration below cannot divide
    # by zero and stays consistent with the sampling step.
    if original_fps is None or not original_fps > 0:
        original_fps = video_fps

    frame_idx = compute_sample_indices(total_frames, original_fps, video_fps, max_frames_num=max_frames)
    images = vr.get_batch(frame_idx).asnumpy()

    clip_duration = total_frames / original_fps
    real_fps = len(images) / clip_duration if clip_duration > 0 else video_fps
    return images, real_fps
def generate_allocation_plot(allocations):
    """Render the per-segment token allocation as a heat strip plus a line chart.

    Args:
        allocations: Sequence of token counts, one per temporal segment.
            ``None`` or an empty sequence produces a blank white image.

    Returns:
        A ``PIL.Image`` containing the rendered PNG (or a blank 1600x350
        image when there is nothing to plot).
    """
    if allocations is None or len(allocations) == 0:
        # if disable_dynamic_compress is True, we return a blank image
        return Image.new('RGB', (1600, 350), color='white')

    allocations = np.array(allocations)
    num_segments = len(allocations)
    # Lower/upper reference levels for the colour scale and the dashed guides.
    # NOTE(review): 4 and 128 look like per-segment token bounds — confirm against the model.
    floor_val = 4
    vmax_val = max(128, allocations.max())

    # Scope the font settings to this figure instead of mutating global rcParams.
    with plt.rc_context({'font.size': 14, 'font.family': 'serif'}):
        fig = plt.figure(figsize=(16, 3.5), layout='constrained')
        try:
            gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1.0], hspace=0.05)

            # --- top strip: heat map of the raw allocations ---
            ax_heat = fig.add_subplot(gs[0])
            ax_heat.set_title(" ", pad=50)
            colors = ["#EBF5FB", "#85C1E9", "#F2D7D5", "#E74C3C", "#641E16"]
            cmap_custom = mcolors.LinearSegmentedColormap.from_list("custom_heat", colors)
            ax_heat.imshow(
                [allocations],
                cmap=cmap_custom,
                aspect='auto',
                extent=[0.5, num_segments + 0.5, 0, 1],
                vmin=floor_val,
                vmax=vmax_val,
            )
            ax_heat.set_yticks([])
            ax_heat.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
            for spine in ax_heat.spines.values():
                spine.set_linewidth(1.2)

            # --- bottom panel: allocation curve ---
            ax_line = fig.add_subplot(gs[1], sharex=ax_heat)
            x = np.arange(1, num_segments + 1)
            if num_segments > 3:
                # A cubic spline needs at least 4 points; clip the overshoot it can introduce.
                spl = make_interp_spline(x, allocations, k=3)
                x_smooth = np.linspace(1, num_segments, 800)
                y_smooth = np.clip(spl(x_smooth), floor_val, vmax_val)
            else:
                x_smooth = x
                y_smooth = allocations

            line_color = '#1A252C'
            fill_color = '#D5D8DC'
            ax_line.plot(x_smooth, y_smooth, color=line_color, linewidth=2.0)
            ax_line.fill_between(x_smooth, y_smooth, color=fill_color, alpha=0.4)
            ax_line.axhline(vmax_val, color='#C0392B', linestyle='--', linewidth=1.2, alpha=0.8)
            ax_line.axhline(floor_val, color='#2980B9', linestyle='--', linewidth=1.2, alpha=0.8)
            ax_line.set_xlim(0.5, num_segments + 0.5)
            ax_line.set_ylim(0, vmax_val + 12)
            ax_line.set_ylabel("Tokens / Seg", fontsize=14, fontweight='bold')
            ax_line.set_xlabel("Temporal Segments", fontsize=14, fontweight='bold')
            ax_line.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            ax_line.spines['top'].set_visible(False)
            ax_line.spines['right'].set_visible(False)
            ax_line.spines['bottom'].set_linewidth(1.2)
            ax_line.spines['left'].set_linewidth(1.2)
            ax_line.grid(axis='y', linestyle=':', color='gray', alpha=0.5)

            buf = io.BytesIO()
            # Save this specific figure rather than pyplot's implicit "current" one.
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100, transparent=True)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# ------------------------------------------------------------------
# Model initialization (runs once at import time).
# ------------------------------------------------------------------
# Hub repository that hosts the Tempo checkpoint.
model_id = "Vision-CAIR/Tempo-6B"
print(f"[Init] Downloading/Loading weights from {model_id}...")
# The value returned by snapshot_download is used below as the checkpoint path.
MODEL_PATH = snapshot_download(repo_id=model_id)
print(f"[Init] Loading Tempo model from {MODEL_PATH}...")
# load_pretrained_model yields the tokenizer, the model and the image processor(s);
# image_processor is indexed with [0] later in predict().
tokenizer, model, image_processor = load_pretrained_model(
    MODEL_PATH,
    device_map="cuda",
    use_flash_attn=False
)
# Context length applied to both the model config and the tokenizer.
FIXED_MAX_LENGTH = 16384
model.config.tokenizer_model_max_length = FIXED_MAX_LENGTH
tokenizer.model_max_length = FIXED_MAX_LENGTH
# Inference only: switch to eval mode, then cast the weights to bfloat16.
model.eval()
model.to(torch.bfloat16)
print(f"[Init] Model loaded! Max context length set to {FIXED_MAX_LENGTH}.")
# ==========================================
# inference
# ==========================================
@spaces.GPU
def predict(video_path, query, max_frames, visual_token_budget, temperature, max_new_tokens, disable_dynamic_compress):
    """Answer ``query`` about the video at ``video_path`` with the Tempo model.

    Args:
        video_path: Path of the uploaded video (falsy when nothing was uploaded).
        query: The user's question.
        max_frames: Maximum number of frames to sample from the video.
        visual_token_budget: Value written to ``model.config.visual_token_budget``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        disable_dynamic_compress: When true, turns off ``dynamic_compress`` on
            the first auxiliary vision tower.

    Returns:
        A ``(response_text, plot_image, stats_text)`` tuple. Error paths return
        the error message as ``response_text`` with ``None`` for the plot and
        an empty stats string, so the tuple always matches the three output
        components this handler is wired to.
    """
    # Every exit path must return three values: the UI binds this function to
    # three outputs, and a lone string would make Gradio raise instead of
    # displaying the message.
    if not video_path:
        return "⚠️ Error: Please upload a video first.", None, ""
    if not query:
        return "⚠️ Error: Please enter a question.", None, ""

    print(f"\n[Request] Video: {video_path} | Query: {query}")
    model.config.visual_token_budget = int(visual_token_budget)
    model.get_vision_tower_aux_list()[0].dynamic_compress = not disable_dynamic_compress

    # video process
    start_prep_time = time.perf_counter()
    try:
        video_frames, real_fps = load_video(video_path, video_fps=2.0, max_frames=int(max_frames))
    except Exception as e:
        return f"⚠️ Error loading video: {str(e)}", None, ""

    # process local compressor inputs
    frame_windows, frame_stride = 8, 8
    vlm_inputs = process_qwen_content(
        video_frames, "video", query, image_processor[0], real_fps, frame_windows, frame_stride, is_eval=True
    )
    vlm_inputs = {key: v.cuda() for key, v in vlm_inputs.items()}

    # compute timestamp for each segment
    num_segments = len(vlm_inputs["video_grid_thw"])
    seg_timestamps = compute_segment_timestamp(
        num_segments, tokenizer, real_fps, frame_stride, frame_windows
    )

    # stat info
    segment_duration = frame_windows / real_fps
    stats_info = f"🎬 Video Stats: Total Segments: {num_segments} | Segment Duration: {segment_duration:.2f}s | Real FPS: {real_fps:.2f}"

    # prompt
    if getattr(model.config, "mm_use_im_start_end", False):
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + query
    else:
        qs = DEFAULT_IMAGE_TOKEN + "\n" + query
    conv_version = "qwen"
    conv = conv_templates[conv_version].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    # tokenization
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # Reset the list that is read back after generation for the allocation plot.
    model._demo_count_allocations = []
    start_infer_time = time.perf_counter()

    # generating
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=None,  # visual content is passed through vlm_inputs rather than raw images
            image_sizes=None,
            do_sample=(temperature > 0),
            temperature=temperature if temperature > 0 else None,
            max_new_tokens=int(max_new_tokens),
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            vlm_inputs=vlm_inputs,
            seg_timestamps=seg_timestamps,
        )
    end_infer_time = time.perf_counter()

    if isinstance(output_ids, tuple):
        output_ids = output_ids[0]

    prep_duration = start_infer_time - start_prep_time
    infer_duration = end_infer_time - start_infer_time
    total_duration = end_infer_time - start_prep_time
    stats_info += f"\n⚑ Profiling : Prep Time: {prep_duration:.2f}s | Inference Time: {infer_duration:.2f}s | Total: {total_duration:.2f}s"

    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Drop a trailing stop string if the decoder emitted it.
    if pred.endswith(stop_str):
        pred = pred[: -len(stop_str)].strip()

    # token allocation plot
    allocations_data = model._demo_count_allocations
    plot_img = generate_allocation_plot(allocations_data)
    return pred, plot_img, stats_info
# ==========================================
# UI
# ==========================================
# Build the Gradio interface: inputs on the left, model outputs on the right,
# followed by a set of clickable examples.
with gr.Blocks(title="Tempo Video Understanding", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# ⏱️ Tempo: Small Vision-Language Models are Smart Compressors for Long Video Understanding
Upload a video and ask any question! Tempo dynamically compresses visual tokens based on your query to achieve SOTA performance.
**[🏠 Project Page](https://feielysia.github.io/tempo-page)** | **[💻 GitHub](https://github.com/FeiElysia/Tempo)** | **[📄 Paper](https://arxiv.org/abs/2604.08120)** | **[👨‍💻 @Junjie Fei](https://feielysia.github.io/)**
*⏳ **Slow preprocessing?** Try Examples 4 & 5 below, decrease `Max Sampled Frames` in Advanced Settings, or check our [GitHub](https://github.com/FeiElysia/Tempo) for full-speed local deployment.*
"""
    )
    with gr.Row():
        # left column: inputs
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video")
            example_poster = gr.Image(label="Video Poster", interactive=False, height=150, visible=False)
            query_input = gr.Textbox(label="Your Question", placeholder="e.g., What is the person doing in the video?", lines=3)
            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", variant="secondary")
                submit_btn = gr.Button("🚀 Generate Response", variant="primary")
            # hyperparameters
            with gr.Accordion("Advanced Settings", open=False):
                max_frames_slider = gr.Slider(minimum=16, maximum=2048, value=1024, step=16, label="Max Sampled Frames")
                budget_slider = gr.Slider(minimum=64, maximum=16384, value=8192, step=64, label="Visual Token Budget")
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature (0 = Greedy)")
                max_tokens_slider = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max New Tokens")
                disable_compress_chk = gr.Checkbox(label="Disable Dynamic Compression (Baseline)", value=False)
        # right column: outputs
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Tempo Response", lines=12, interactive=False)
            stats_text = gr.Textbox(label="📊 Video Segment Stats", lines=1, interactive=False)
            output_plot = gr.Image(label="Query-Aware Visual Feature Intensity (Visual Token Allocation)", interactive=False, height=180)

    # clicking submit_btn or pressing enter in query_input will trigger prediction;
    # both triggers share the same input/output wiring so they cannot drift apart.
    predict_inputs = [
        video_input,
        query_input,
        max_frames_slider,
        budget_slider,
        temp_slider,
        max_tokens_slider,
        disable_compress_chk,
    ]
    predict_outputs = [output_text, output_plot, stats_text]
    submit_btn.click(fn=predict, inputs=predict_inputs, outputs=predict_outputs)
    query_input.submit(fn=predict, inputs=predict_inputs, outputs=predict_outputs)

    # Reset every input and output component (one None per listed output).
    clear_btn.click(
        fn=lambda: (None, None, None, None, None, None),
        inputs=None,
        outputs=[video_input, example_poster, query_input, output_text, stats_text, output_plot]
    )

    # Examples
    gr.Markdown("---")
    gr.Markdown("### 💡 Try an Example")
    gr.Examples(
        examples=[
            [
                "examples/hsr_helloworld.mp4",
                "Task: Please examine the provided media and answer the following three questions regarding the specific puppy in the scene:\n"
                "Q1: What is the primary fur color of the puppy positioned on the swing?\n"
                "Q2: Specify the exact time interval (in seconds, e.g., XX-XXs) during which the puppy is seen sitting on the swing.\n"
                "Q3: Provide a brief description of the puppy's appearance and its surroundings.",
                "examples/meme_hsr_helloworld.png"
            ],
            [
                "examples/hsr_helloworld.mp4",
                "Task: Please analyze the provided video and answer the following 7 questions precisely.\n"
                "Q1: How many performers are visible on the stage?\n"
                "Q2: Describe the architectural elements in the background. What historical civilization do they remind you of?\n"
                "Q3: What is happening in the night sky above the performers, and what does this suggest about the event?\n"
                "Q4: List the hair colors of the performers in order from left to right.\n"
                "Q5: Identify the specific musical instrument being played by the performer located on the far left of the stage.\n"
                "Q6: What is the specific time interval (in seconds, e.g., XX-XXs) during which this fireworks performance scene occurs in the video?\n"
                "Q7: Look at the audience in the foreground. How does their silhouette-like depiction affect the viewer's perspective of the stage?",
                "examples/performance_hsr_helloworld.png"
            ],
            [
                "examples/honkai3_becauseofyou.mp4",
                "What text appears in the center of the video behind a sea of pink flowers?",
                "examples/ocr_honkai3_becauseofyou.png"
            ],
            [
                "examples/videomme_fFjv93ACGo8.mp4",
                "How many red socks are above the fireplace at the end of this video?",
                "examples/cover_videomme_fFjv93ACGo8.png"
            ],
            [
                "examples/videomme_FsLaTZmP6Uw.mp4",
                "Which year was the game held?",
                "examples/cover_videomme_FsLaTZmP6Uw.png"
            ],
            [
                "examples/honkai3_becauseofyou.mp4",
                "Describe the video in detail.",
                "examples/description_honkai3_becauseofyou.png"
            ]
        ],
        inputs=[video_input, query_input, example_poster],
        cache_examples=False,
    )
# Script entry point: enable the request queue, then start the Gradio server.
if __name__ == "__main__":
    demo.queue().launch()