| import io |
| import os |
| import time |
| import torch |
| import numpy as np |
| import gradio as gr |
| import multiprocessing |
| from decord import cpu, VideoReader |
|
|
| import matplotlib.pyplot as plt |
| import matplotlib.ticker as ticker |
| import matplotlib.colors as mcolors |
| from scipy.interpolate import make_interp_spline |
| from PIL import Image |
|
|
| from tempo.builder import load_pretrained_model |
| from tempo.conversation import conv_templates, SeparatorStyle |
| from tempo.constants import ( |
| DEFAULT_IM_END_TOKEN, |
| DEFAULT_IM_START_TOKEN, |
| DEFAULT_IMAGE_TOKEN, |
| IMAGE_TOKEN_INDEX, |
| ) |
| from tempo.mm_datautils import ( |
| compute_segment_timestamp, |
| KeywordsStoppingCriteria, |
| process_qwen_content, |
| tokenizer_image_token, |
| ) |
|
|
| import spaces |
| from huggingface_hub import snapshot_download |
|
|
def get_real_cpu_cores():
    """Return how many CPU cores this process can use for video decoding.

    Uses ``os.sched_getaffinity`` when the platform provides it and falls
    back to ``multiprocessing.cpu_count()`` otherwise.
    """
    try:
        affinity_of = os.sched_getaffinity
    except AttributeError:
        # Platform does not expose scheduler affinity.
        return multiprocessing.cpu_count()
    return len(affinity_of(0))
|
|
def compute_sample_indices(
    total_frames: int,
    original_fps: float,
    video_fps: float = 2.0,
    min_frames_num: int = 4,
    max_frames_num: int = 1024
) -> list[int]:
    """Choose evenly spaced frame indices covering the whole video.

    Args:
        total_frames: Number of frames in the video.
        original_fps: Native frame rate; ``None`` or a non-positive value
            makes the function use ``video_fps`` instead.
        video_fps: Desired sampling rate in frames per second.
        min_frames_num: Lower bound on the number of sampled frames.
        max_frames_num: Upper bound on the number of sampled frames
            (applied after the lower bound, so it wins on conflict).

    Returns:
        Frame indices as Python ints, spread from the first to the last
        frame. Very short videos can yield repeated indices because the
        lower bound is still honoured.
    """
    first, last = 0, total_frames - 1
    span = last - first + 1
    if span <= 1:
        return [first]

    # Without a usable native rate, treat the video as already at the target rate.
    source_fps = video_fps if (original_fps is None or original_fps <= 0) else original_fps

    wanted = max(1, round((span / source_fps) * video_fps))
    count = min(max(wanted, min_frames_num), max_frames_num)
    if count == 1:
        return [last]

    positions = np.linspace(first, last, count)
    picked = np.clip(np.round(positions).astype(int), first, last)
    return picked.tolist()
|
|
def load_video(video_path: str, video_fps: float = 2.0, max_frames: int = 1024) -> tuple:
    """Decode a uniformly sampled subset of frames from a video file.

    Args:
        video_path: Path of the video to open with decord.
        video_fps: Target sampling rate in frames per second.
        max_frames: Upper bound on the number of sampled frames.

    Returns:
        ``(images, real_fps)`` where ``images`` is the numpy batch returned
        by decord for the sampled indices and ``real_fps`` is the effective
        sampling rate (sampled frames divided by the clip duration).

    Raises:
        ValueError: If the video reports no frames.
    """
    available_cores = get_real_cpu_cores()
    # Leave one core free and cap the decoder at 16 threads.
    optimal_threads = min(max(1, available_cores - 1), 16)
    print(f"[Profiling] Detected {available_cores} CPU cores. Decord using {optimal_threads} threads.")

    vr = VideoReader(video_path, ctx=cpu(0), num_threads=optimal_threads)
    total_frames = len(vr)
    if total_frames <= 0:
        # Fail with a clear message instead of requesting frame 0 of an empty video.
        raise ValueError(f"Video contains no decodable frames: {video_path}")

    original_fps = vr.get_avg_fps()
    if original_fps is None or not original_fps > 0:
        # Same fallback compute_sample_indices applies; also covers NaN and
        # prevents a ZeroDivisionError in the duration computation below.
        original_fps = video_fps

    frame_idx = compute_sample_indices(total_frames, original_fps, video_fps, max_frames_num=max_frames)
    images = vr.get_batch(frame_idx).asnumpy()

    clip_duration = total_frames / original_fps if original_fps > 0 else 0.0
    real_fps = len(images) / clip_duration if clip_duration > 0 else video_fps

    return images, real_fps
|
|
def generate_allocation_plot(allocations):
    """
    Render the per-segment token allocation as a heat strip above a line chart.

    Args:
        allocations: Sequence with one numeric value per temporal segment.
            ``None`` or an empty sequence yields a blank placeholder.

    Returns:
        A PIL image: the rendered PNG, or a blank white 1600x350 image when
        there is nothing to plot.
    """
    if allocations is None or len(allocations) == 0:
        return Image.new('RGB', (1600, 350), color='white')

    allocations = np.array(allocations)
    num_segments = len(allocations)

    # NOTE(review): 4 and 128 look like the per-segment minimum and nominal
    # maximum token counts — confirm against the model configuration.
    lower_bound = 4
    vmax_val = max(128, allocations.max())

    plt.rcParams.update({'font.size': 14, 'font.family': 'serif'})
    fig = plt.figure(figsize=(16, 3.5), layout='constrained')
    try:
        gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1.0], hspace=0.05)

        # Top strip: one coloured cell per segment.
        ax_heat = fig.add_subplot(gs[0])
        ax_heat.set_title(" ", pad=50)

        colors = ["#EBF5FB", "#85C1E9", "#F2D7D5", "#E74C3C", "#641E16"]
        cmap_custom = mcolors.LinearSegmentedColormap.from_list("custom_heat", colors)

        ax_heat.imshow(
            [allocations],
            cmap=cmap_custom,
            aspect='auto',
            extent=[0.5, num_segments + 0.5, 0, 1],
            vmin=lower_bound,
            vmax=vmax_val,
        )
        ax_heat.set_yticks([])
        ax_heat.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        for spine in ax_heat.spines.values():
            spine.set_linewidth(1.2)

        # Bottom panel: allocation curve sharing the segment axis.
        ax_line = fig.add_subplot(gs[1], sharex=ax_heat)
        x = np.arange(1, num_segments + 1)

        if num_segments > 3:
            # A cubic spline needs at least four points; clip the smoothed
            # curve so overshoot stays within the plotted range.
            spl = make_interp_spline(x, allocations, k=3)
            x_smooth = np.linspace(1, num_segments, 800)
            y_smooth = np.clip(spl(x_smooth), lower_bound, vmax_val)
        else:
            x_smooth = x
            y_smooth = allocations

        line_color = '#1A252C'
        fill_color = '#D5D8DC'

        ax_line.plot(x_smooth, y_smooth, color=line_color, linewidth=2.0)
        ax_line.fill_between(x_smooth, y_smooth, color=fill_color, alpha=0.4)

        # Dashed guides at the upper and lower bounds.
        ax_line.axhline(vmax_val, color='#C0392B', linestyle='--', linewidth=1.2, alpha=0.8)
        ax_line.axhline(lower_bound, color='#2980B9', linestyle='--', linewidth=1.2, alpha=0.8)

        ax_line.set_xlim(0.5, num_segments + 0.5)
        ax_line.set_ylim(0, vmax_val + 12)
        ax_line.set_ylabel("Tokens / Seg", fontsize=14, fontweight='bold')
        ax_line.set_xlabel("Temporal Segments", fontsize=14, fontweight='bold')
        ax_line.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        ax_line.spines['top'].set_visible(False)
        ax_line.spines['right'].set_visible(False)
        ax_line.spines['bottom'].set_linewidth(1.2)
        ax_line.spines['left'].set_linewidth(1.2)
        ax_line.grid(axis='y', linestyle=':', color='gray', alpha=0.5)

        buf = io.BytesIO()
        # Save this specific figure rather than pyplot's implicit "current" one.
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100, transparent=True)
    finally:
        # Always release the figure, even if plotting failed part-way.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
|
|
# ---------------------------------------------------------------------------
# Model bootstrap: executed once when the module is imported.
# ---------------------------------------------------------------------------
model_id = "Vision-CAIR/Tempo-6B"
print(f"[Init] Downloading/Loading weights from {model_id}...")
MODEL_PATH = snapshot_download(repo_id=model_id)

print(f"[Init] Loading Tempo model from {MODEL_PATH}...")
tokenizer, model, image_processor = load_pretrained_model(
    MODEL_PATH, device_map="cuda", use_flash_attn=False
)

# Use one context-length limit for both the tokenizer and the model config.
FIXED_MAX_LENGTH = 16384
tokenizer.model_max_length = FIXED_MAX_LENGTH
model.config.tokenizer_model_max_length = FIXED_MAX_LENGTH

# Cast to bfloat16 and switch to evaluation mode.
model.to(torch.bfloat16)
model.eval()
print(f"[Init] Model loaded! Max context length set to {FIXED_MAX_LENGTH}.")
|
|
|
|
| |
| |
| |
def _predict_error(message):
    """Wrap an error message as a value for each of predict's three outputs."""
    # (response text, allocation plot, stats text)
    return message, None, ""


@spaces.GPU
def predict(video_path, query, max_frames, visual_token_budget, temperature, max_new_tokens, disable_dynamic_compress):
    """Answer a question about a video and report token-allocation statistics.

    Args:
        video_path: Path of the uploaded video; empty/None is rejected.
        query: The user's question; empty or whitespace-only is rejected.
        max_frames: Upper bound on the number of frames sampled from the video.
        visual_token_budget: Stored in ``model.config.visual_token_budget``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        disable_dynamic_compress: When true, ``dynamic_compress`` is turned off
            on the first auxiliary vision tower.

    Returns:
        A 3-tuple ``(response_text, allocation_plot, stats_text)``. On a
        validation or video-loading error the first element carries the
        error message, the plot is ``None`` and the stats text is empty, so
        all three wired outputs always receive a value.
    """
    if not video_path:
        return _predict_error("⚠️ Error: Please upload a video first.")
    if not query or not query.strip():
        return _predict_error("⚠️ Error: Please enter a question.")

    print(f"\n[Request] Video: {video_path} | Query: {query}")

    # Apply the per-request settings to the shared model instance.
    model.config.visual_token_budget = int(visual_token_budget)
    model.get_vision_tower_aux_list()[0].dynamic_compress = not disable_dynamic_compress

    start_prep_time = time.perf_counter()
    try:
        video_frames, real_fps = load_video(video_path, video_fps=2.0, max_frames=int(max_frames))
    except Exception as e:
        # UI boundary: surface any decoding failure as a message, not a crash.
        return _predict_error(f"⚠️ Error loading video: {str(e)}")

    frame_windows, frame_stride = 8, 8
    vlm_inputs = process_qwen_content(
        video_frames, "video", query, image_processor[0], real_fps, frame_windows, frame_stride, is_eval=True
    )
    vlm_inputs = {key: v.cuda() for key, v in vlm_inputs.items()}

    num_segments = len(vlm_inputs["video_grid_thw"])
    seg_timestamps = compute_segment_timestamp(
        num_segments, tokenizer, real_fps, frame_stride, frame_windows
    )

    segment_duration = frame_windows / real_fps
    stats_info = f"🎬 Video Stats: Total Segments: {num_segments} | Segment Duration: {segment_duration:.2f}s | Real FPS: {real_fps:.2f}"

    # Put the image placeholder token(s) in front of the question.
    if getattr(model.config, "mm_use_im_start_end", False):
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + query
    else:
        qs = DEFAULT_IMAGE_TOKEN + "\n" + query

    conv_version = "qwen"
    conv = conv_templates[conv_version].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # Reset before generation; read back afterwards for the plot.
    # NOTE(review): presumably filled by the model during generate() — confirm.
    model._demo_count_allocations = []

    start_infer_time = time.perf_counter()

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=None,
            image_sizes=None,
            do_sample=(temperature > 0),
            temperature=temperature if temperature > 0 else None,
            max_new_tokens=int(max_new_tokens),
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            vlm_inputs=vlm_inputs,
            seg_timestamps=seg_timestamps,
        )

    end_infer_time = time.perf_counter()

    if isinstance(output_ids, tuple):
        output_ids = output_ids[0]

    prep_duration = start_infer_time - start_prep_time
    infer_duration = end_infer_time - start_infer_time
    total_duration = end_infer_time - start_prep_time
    stats_info += f"\n⚡ Profiling : Prep Time: {prep_duration:.2f}s | Inference Time: {infer_duration:.2f}s | Total: {total_duration:.2f}s"

    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Guard against an empty stop string: "x"[:-0] would erase the whole answer.
    if stop_str and pred.endswith(stop_str):
        pred = pred[: -len(stop_str)].strip()

    allocations_data = model._demo_count_allocations
    plot_img = generate_allocation_plot(allocations_data)

    return pred, plot_img, stats_info
|
|
| |
| |
| |
# Gradio UI: inputs and settings on the left, model outputs on the right.
with gr.Blocks(title="Tempo Video Understanding", theme=gr.themes.Soft()) as demo:
    # NOTE(review): the emoji in the labels below were restored from
    # mis-decoded bytes. The glyphs for "Project Page", "Paper",
    # "Generate Response" and "Video Segment Stats" could not be recovered
    # exactly and are best guesses — confirm against the intended design.
    gr.Markdown(
        """
        # ⏱️ Tempo: Small Vision-Language Models are Smart Compressors for Long Video Understanding
        Upload a video and ask any question! Tempo dynamically compresses visual tokens based on your query to achieve SOTA performance.
        **[🌐 Project Page](https://feielysia.github.io/tempo-page)** | **[💻 GitHub](https://github.com/FeiElysia/Tempo)** | **[📄 Paper](https://arxiv.org/abs/2604.08120)** | **[👨‍💻 @Junjie Fei](https://feielysia.github.io/)**

        *⏳ **Slow preprocessing?** Try Examples 4 & 5 below, decrease `Max Sampled Frames` in Advanced Settings, or check our [GitHub](https://github.com/FeiElysia/Tempo) for full-speed local deployment.*
        """
    )

    with gr.Row():
        # Left column: video, question, actions and advanced settings.
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video")
            example_poster = gr.Image(label="Video Poster", interactive=False, height=150, visible=False)
            query_input = gr.Textbox(label="Your Question", placeholder="e.g., What is the person doing in the video?", lines=3)
            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", variant="secondary")
                submit_btn = gr.Button("🚀 Generate Response", variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                max_frames_slider = gr.Slider(minimum=16, maximum=2048, value=1024, step=16, label="Max Sampled Frames")
                budget_slider = gr.Slider(minimum=64, maximum=16384, value=8192, step=64, label="Visual Token Budget")
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature (0 = Greedy)")
                max_tokens_slider = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max New Tokens")
                disable_compress_chk = gr.Checkbox(label="Disable Dynamic Compression (Baseline)", value=False)

        # Right column: answer, statistics and the allocation plot.
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Tempo Response", lines=12, interactive=False)
            stats_text = gr.Textbox(label="📊 Video Segment Stats", lines=1, interactive=False)
            output_plot = gr.Image(label="Query-Aware Visual Feature Intensity (Visual Token Allocation)", interactive=False, height=180)

    # Both triggers (button click and Enter in the question box) share one wiring.
    predict_inputs = [
        video_input,
        query_input,
        max_frames_slider,
        budget_slider,
        temp_slider,
        max_tokens_slider,
        disable_compress_chk,
    ]
    predict_outputs = [output_text, output_plot, stats_text]

    submit_btn.click(fn=predict, inputs=predict_inputs, outputs=predict_outputs)
    query_input.submit(fn=predict, inputs=predict_inputs, outputs=predict_outputs)

    # Reset every input and output widget.
    clear_btn.click(
        fn=lambda: (None, None, None, None, None, None),
        inputs=None,
        outputs=[video_input, example_poster, query_input, output_text, stats_text, output_plot]
    )

    gr.Markdown("---")
    gr.Markdown("### 💡 Try an Example")
    gr.Examples(
        examples=[
            [
                "examples/hsr_helloworld.mp4",
                "Task: Please examine the provided media and answer the following three questions regarding the specific puppy in the scene:\n"
                "Q1: What is the primary fur color of the puppy positioned on the swing?\n"
                "Q2: Specify the exact time interval (in seconds, e.g., XX-XXs) during which the puppy is seen sitting on the swing.\n"
                "Q3: Provide a brief description of the puppy's appearance and its surroundings.",
                "examples/meme_hsr_helloworld.png"
            ],
            [
                "examples/hsr_helloworld.mp4",
                "Task: Please analyze the provided video and answer the following 7 questions precisely.\n"
                "Q1: How many performers are visible on the stage?\n"
                "Q2: Describe the architectural elements in the background. What historical civilization do they remind you of?\n"
                "Q3: What is happening in the night sky above the performers, and what does this suggest about the event?\n"
                "Q4: List the hair colors of the performers in order from left to right.\n"
                "Q5: Identify the specific musical instrument being played by the performer located on the far left of the stage.\n"
                "Q6: What is the specific time interval (in seconds, e.g., XX-XXs) during which this fireworks performance scene occurs in the video?\n"
                "Q7: Look at the audience in the foreground. How does their silhouette-like depiction affect the viewer's perspective of the stage?",
                "examples/performance_hsr_helloworld.png"
            ],
            [
                "examples/honkai3_becauseofyou.mp4",
                "What text appears in the center of the video behind a sea of pink flowers?",
                "examples/ocr_honkai3_becauseofyou.png"
            ],
            [
                "examples/videomme_fFjv93ACGo8.mp4",
                "How many red socks are above the fireplace at the end of this video?",
                "examples/cover_videomme_fFjv93ACGo8.png"
            ],
            [
                "examples/videomme_FsLaTZmP6Uw.mp4",
                "Which year was the game held?",
                "examples/cover_videomme_FsLaTZmP6Uw.png"
            ],
            [
                "examples/honkai3_becauseofyou.mp4",
                "Describe the video in detail.",
                "examples/description_honkai3_becauseofyou.png"
            ]
        ],
        inputs=[video_input, query_input, example_poster],
        cache_examples=False,
    )
|
|
if __name__ == "__main__":
    # Turn on request queuing, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()
|
|
|
|