Spaces:
Runtime error
Runtime error
| import os | |
| import asyncio | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from io import BytesIO | |
| from moviepy import VideoFileClip | |
| import matplotlib.pyplot as plt | |
| import base64 | |
| from TStar.TStarFramework import run_tstar | |
def img2base64(image_path):
    """Read an image file and return its contents base64-encoded as text.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        str: The file's raw bytes, base64-encoded and decoded to UTF-8 text,
        suitable for embedding in a ``data:`` URI.
    """
    # A context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
def create_timeline(frame_times, duration):
    """Render sampled frame timestamps as a timeline image.

    A dotted gray line spans ``[0, duration]`` on the x-axis and every
    timestamp in ``frame_times`` is drawn on it as a red dot.

    Args:
        frame_times: Sized sequence of timestamps in seconds.
        duration: Upper bound of the x-axis in seconds.

    Returns:
        PIL.Image.Image: The figure rendered to PNG in memory and reopened.
    """
    baseline_y = 0.5
    figure, axes = plt.subplots(figsize=(10, 2))
    axes.set_xlim(0, duration)
    axes.hlines(baseline_y, 0, duration, colors="gray", linestyles="dotted")
    marker_heights = [baseline_y] * len(frame_times)
    axes.plot(frame_times, marker_heights, 'ro')
    axes.set_xlabel("Time (s)")

    png_buffer = BytesIO()
    figure.savefig(png_buffer, format='png')
    png_buffer.seek(0)
    # The pixels now live in the buffer, so the figure can be released.
    plt.close(figure)
    return Image.open(png_buffer)
def analyze_and_sample_frames(
    video_file,
    question,
    openai_api_key,
    num_frames=8,
    batch=1,
    total_batches=3,
    device="cuda:0",
):
    """Run the T* keyframe search on a video and prepare results for the UI.

    Calls the backend ``run_tstar`` to locate keyframes relevant to
    ``question``, decodes those frames from the video and renders a timeline
    of their timestamps.

    Args:
        video_file: Path to the video file on disk.
        question: Question about the video; an empty value is replaced by
            ``"No question provided"``.
        openai_api_key: API key forwarded to ``run_tstar``.
        num_frames: Number of keyframes to search for (``search_nframes``).
        batch: 1-based index of this batch; only recorded in the metadata.
        total_batches: Total number of batches; only recorded in the metadata.
        device: Device string forwarded to ``run_tstar``.

    Returns:
        tuple: ``(metadata, frames, frame_times, timeline_image)``:
        - metadata: dict with the batch info, question, answer, grounding
          objects and frame timestamps;
        - frames: list of PIL images for the Gradio Gallery;
        - frame_times: keyframe timestamps (seconds) as reported by the backend;
        - timeline_image: PIL image of the timeline with the keyframes marked.
        Returns ``(None, None, None, None)`` when the video file is missing.
    """
    # Guard against None as well: os.path.exists(None) raises TypeError.
    if not video_file or not os.path.exists(video_file):
        print("video_file does not exist:", video_file)
        return None, None, None, None
    if not question:
        question = "No question provided"
    # The UI does not collect answer options separately (the sample question
    # embeds them in its text), so the question is sent as free-form.
    options = "Freeform Question"
    # NOTE: search parameters are identical for every batch; they could be
    # scaled with ``batch`` / ``total_batches`` (e.g. a larger search_budget
    # for later batches).
    results = run_tstar(
        video_path=video_file,
        question=question,
        options=options,
        grounder="gpt-4o",
        heuristic="owl-vit",
        device=device,
        search_nframes=num_frames,
        grid_rows=4,
        grid_cols=4,
        confidence_threshold=0.6,
        search_budget=0.5,
        output_dir='./output',
        openai_api_key=openai_api_key,
    )
    # Pull the fields the UI needs out of the backend result.
    frame_times = results.get("Frame Timestamps", [])
    answer = results.get("Answer", "No answer")
    grounding_objects = results.get("Grounding Objects", [])

    frames, video_duration = _extract_frames(video_file, frame_times)
    timeline_image = create_timeline(frame_times, duration=video_duration)

    metadata = {
        "batch": batch,
        "total_batches": total_batches,
        "question": question,
        "answer": answer,
        "grounding_objects": grounding_objects,
        "frame_times": frame_times,
    }
    return metadata, frames, frame_times, timeline_image


def _extract_frames(video_file, frame_times):
    """Decode the frames at ``frame_times``; return ``(frames, duration)``.

    Timestamps are clamped into the decodable range of the clip. The clip is
    always closed, even if decoding fails.
    """
    clip = VideoFileClip(video_file)
    try:
        video_duration = clip.duration
        fps = clip.fps or 0
        # Latest safe timestamp: one frame before the end. moviepy may warn
        # or fail when asked for a frame at exactly ``duration``.
        last_time = max(video_duration - 1.0 / fps, 0.0) if fps else video_duration
        frames = []
        for t in frame_times:
            safe_t = min(max(float(t), 0.0), last_time)
            frame_img = clip.get_frame(safe_t)  # (H, W, 3) numpy array
            frames.append(Image.fromarray(frame_img.astype(np.uint8)))
    finally:
        clip.close()
    return frames, video_duration
def switch_batch(state_batches, selected_batch):
    """Show the stored results of the batch picked in the batch dropdown.

    Args:
        state_batches: List of ``(timeline_image, frames, metadata)`` tuples,
            one per processed batch, in processing order.
        selected_batch: Dropdown label of the form ``"Batch <n>"``, where
            ``n`` is the 1-based position in ``state_batches``.

    Returns:
        tuple: Updates for the timeline image, frame gallery and metadata
        JSON, followed by the selected label. ``(None, None, None, None)``
        is returned when the selection is empty or malformed, or names a
        batch that does not exist.
    """
    empty = (None, None, None, None)
    if not selected_batch:
        return empty
    try:
        batch_index = int(selected_batch.split()[-1]) - 1
    except (ValueError, IndexError):
        return empty
    batches = state_batches or []
    # Reject out-of-range labels explicitly: "Batch 0" would otherwise map
    # to index -1 and silently show the last batch.
    if not 0 <= batch_index < len(batches):
        return empty
    timeline_image, frames, metadata = batches[batch_index]
    return (
        gr.update(value=timeline_image, visible=True),
        gr.update(value=frames, visible=True),
        gr.update(value=metadata, visible=True),
        selected_batch,
    )
async def process_video_iteratively_with_state(video_file, question_input, openai_api_key_input, state_batches, current_display_batch, total_batches=1, num_frames=8):
    """Run the keyframe analysis batch by batch, streaming results to the UI.

    This async generator serves as a Gradio event handler. Each yielded
    tuple holds, in order: the timeline image, the frame gallery, the
    metadata JSON, the status text, the batch dropdown, the accumulated
    batch list, and the label of the batch currently displayed.

    Args:
        video_file: Path of the uploaded video (falsy when nothing is uploaded).
        question_input: Question text forwarded to the analysis.
        openai_api_key_input: API key forwarded to the analysis.
        state_batches: Session list of ``(timeline_image, frames, metadata)``
            tuples. New batches are appended in place, and it also holds
            batches from earlier runs.
        current_display_batch: Dropdown label of the batch being shown, or None.
        total_batches: Number of analysis batches to run.
        num_frames: Number of keyframes to request per batch.

    Yields:
        tuple: The seven output values described above.
    """
    if not video_file:
        yield None, None, None, "No video uploaded!", None, state_batches, current_display_batch
        return

    produced_any = False
    for batch in range(1, total_batches + 1):
        # The analysis blocks (model inference + video decoding); run it in a
        # worker thread so the event loop can keep serving other sessions.
        metadata, frames, frame_times, timeline_image = await asyncio.to_thread(
            analyze_and_sample_frames,
            video_file,
            question=question_input,
            openai_api_key=openai_api_key_input,
            num_frames=num_frames,
            batch=batch,
            total_batches=total_batches,
        )
        if metadata is None:
            continue
        produced_any = True

        # Dropdown labels are 1-based positions in the accumulated
        # ``state_batches``, which also holds earlier runs. Derive them from
        # its length, not from this run's loop counter.
        previous_label = f"Batch {len(state_batches)}" if state_batches else None
        state_batches.append((timeline_image, frames, metadata))
        new_label = f"Batch {len(state_batches)}"
        batch_choices = [f"Batch {i}" for i in range(1, len(state_batches) + 1)]
        # Follow the newest batch if nothing was selected yet, or if the user
        # was looking at the previously newest one.
        if current_display_batch is None or current_display_batch == previous_label:
            current_display_batch = new_label
        yield (
            gr.update(value=timeline_image, visible=True),
            gr.update(value=frames, visible=True),
            gr.update(value=metadata, visible=True),
            f"Processing Batch: {batch} / Total Batches: {total_batches}",
            gr.update(choices=batch_choices, value=new_label, visible=True),
            state_batches,
            current_display_batch,
        )
        await asyncio.sleep(0.5)

    if not produced_any:
        # Without this, the UI would receive no update at all.
        yield (
            gr.update(),
            gr.update(),
            gr.update(),
            "No batch could be processed (video file not found).",
            gr.update(),
            state_batches,
            current_display_batch,
        )
def generate_header(base64_logo, title="⭐ T - Efficient Long Video QA Tool"):
    """Return the HTML banner displayed at the top of the app.

    Args:
        base64_logo: Base64 text of a PNG, embedded through a ``data:`` URI.
        title: Heading text shown above the logo.

    Returns:
        str: HTML holding the title, the centered logo and a
        "How to Use?" sub-heading.
    """
    header_html = f"""
<h1 style="text-align: center; font-size: 3em; color: #4CAF50; font-family: 'Open Sans', sans-serif; margin-bottom: 20px;">{title}</h1>
<div style="display: flex; justify-content: center; align-items: center; height: 333px;">
<img src="data:image/png;base64,{base64_logo}" alt="Logo" style="width: auto; height: 300px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
</div>
<div style="display: flex; justify-content: center; align-items: center; margin-top: 20px;">
<h2 style="text-align: center; font-size: 2em; color: #333; margin-bottom: 30px;">📖 How to Use?</h2>
</div>
"""
    return header_html
def generate_instruction(step, title, description):
    """Return the HTML for one "how to use" instruction card.

    Args:
        step: Step number shown in the card heading.
        title: Short step name shown after the number.
        description: Body text (may itself contain HTML) placed in the card.

    Returns:
        str: A styled ``<div>`` card with the heading and description.
    """
    card_html = f"""
<div style="background-color: #F9F9F9; padding: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); height: 150px; display: flex; flex-direction: column; justify-content: flex-start;">
<h3 style="font-size: 1.5em; color: #4CAF50; font-family: 'Open Sans', sans-serif; margin-bottom: 10px;">Step {step}: {title}</h3>
<p style="font-size: 1em; color: #666; line-height: 1.5; margin: 0;">
{description}
</p>
</div>
"""
    return card_html
def create_ui_components(default_video_path):
    """Create the input and output widgets of the demo page.

    Intended to be called inside an active ``gr.Blocks`` context; the
    components appear on the page in the order they are created here.

    Args:
        default_video_path: Path of the sample video preloaded into both the
            file picker and the preview player.

    Returns:
        tuple: ``(openai_api_key_input, video_input, question_input,
        submit_button, video_preview, state_batches, current_display_batch,
        output_timeline, output_frames, batch_status, batch_selector,
        output_metadata)``.
    """
    # Two-column layout. Gradio expects integer ``scale`` values, so the
    # preview column gets 1.5x the upload column's width via scales 2 and 3.
    with gr.Row(equal_height=True, elem_id="video-container"):
        with gr.Column(scale=2, min_width=300):
            # Left column: video upload.
            gr.Markdown("""
<br>
<h2 style="color: #333;">Upload Your Video</h2>
<p style="color: #666; font-size: 0.9em;">You can upload a sample video or provide your own video for analysis.</p>
""")
            video_input = gr.File(
                label="Select Video",
                type="filepath",
                value=default_video_path,
                interactive=True,
            )
        with gr.Column(scale=3, min_width=400):
            # Right column: video preview.
            gr.Markdown("""
<br>
<h2 style="color: #333;">Video Preview</h2>
<p style="color: #666; font-size: 0.9em;">View your video here before starting the analysis.</p>
""")
            video_preview = gr.Video(
                label="Preview",
                value=default_video_path,
                visible=True,
                autoplay=True,
                loop=True,
            )
    # API key input: masked so the key is not displayed on screen, and empty
    # by default so the placeholder text is never submitted as a key.
    openai_api_key_input = gr.Textbox(
        label="Provide your OpenAI API Key",
        placeholder="sk-...",
        value="",
        type="password",
        elem_id="openai-api-key-input",
    )
    # Question input, prefilled with a multiple-choice sample question.
    question_input = gr.Textbox(
        label="Ask a Question",
        placeholder="",
        value="Where's the microwave? A. Under the cabinet B. On top of the refrigerator C. Next to the stove D. Beside the sink E. In the pantry",
        type="text",
        elem_id="question-input",
    )
    submit_button = gr.Button(
        "Analyze!",
        elem_id="analyze-button",
    )
    # Per-session state shared between the analysis and batch-switch handlers.
    state_batches = gr.State([])  # (timeline_image, frames, metadata) per batch
    current_display_batch = gr.State(None)  # label of the batch being shown
    # Result widgets; hidden until the first batch has been processed.
    output_timeline = gr.Image(label="Video Timeline", type="pil", visible=False)
    output_frames = gr.Gallery(label="Sampled Frames", columns=8, visible=False, height=200)
    batch_status = gr.Text(label="Batch Status", value="No Batch Processed Yet", visible=True)
    batch_selector = gr.Dropdown(choices=[], label="Select Batch", visible=False)
    output_metadata = gr.JSON(label="Video Metadata", visible=False)
    return (
        openai_api_key_input,
        video_input,
        question_input,
        submit_button,
        video_preview,
        state_batches,
        current_display_batch,
        output_timeline,
        output_frames,
        batch_status,
        batch_selector,
        output_metadata,
    )
def update_video_preview(video_file, default_video_path):
    """Point the preview player at the currently selected video.

    Args:
        video_file: Value of the file picker. With ``gr.File(type="filepath")``
            this is a path string; objects exposing a ``.name`` path are
            accepted too. Falsy when the picker is cleared.
        default_video_path: Path shown when no file is selected.

    Returns:
        A ``gr.update`` that sets the preview's value, keeps it visible and
        enables autoplay and looping.
    """
    if not video_file:
        video_path = default_video_path
    elif isinstance(video_file, (str, os.PathLike)):
        # type="filepath" delivers a plain path, which has no ``.name``.
        video_path = os.fspath(video_file)
    else:
        video_path = video_file.name
    return gr.update(value=video_path, visible=True, autoplay=True, loop=True)
if __name__ == "__main__":
    # Sample assets shipped with the app.
    sample_video_path = "data/sample.mp4"
    logo_path = "data/logo.png"
    base64_logo = img2base64(logo_path)

    with gr.Blocks() as demo:
        # Page header: title, logo and "How to Use?" heading.
        gr.Markdown(generate_header(base64_logo))

        # One instruction card per usage step.
        steps = [
            ("Upload", "Sample video is provided. You can also upload your own!<br>Click <strong>Video Preview</strong> to preview it."),
            ("Analyze", "Ask a question and click <strong>'Analyze'</strong>.<br>The system will track keyframes to answer your question."),
            ("Visualize", "View keyframes with their sample distribution.<br>Explore keyframe tracking dynamics visually!"),
        ]
        with gr.Row(equal_height=True, elem_id="instructions-container"):
            for i, (title, description) in enumerate(steps, start=1):
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown(generate_instruction(i, title, description))

        (
            openai_api_key_input,
            video_input,
            question_input,
            submit_button,
            video_preview,
            state_batches,
            current_display_batch,
            output_timeline,
            output_frames,
            batch_status,
            batch_selector,
            output_metadata,
        ) = create_ui_components(sample_video_path)

        # Refresh the preview whenever a new file is picked.
        video_input.change(
            fn=update_video_preview,
            inputs=[video_input, gr.State(sample_video_path)],
            outputs=video_preview,
        )

        # Stream analysis results batch by batch.
        submit_button.click(
            fn=process_video_iteratively_with_state,
            inputs=[video_input, question_input, openai_api_key_input, state_batches, current_display_batch],
            outputs=[
                output_timeline,
                output_frames,
                output_metadata,
                batch_status,
                batch_selector,
                state_batches,
                current_display_batch,
            ],
        )

        # Show a previously processed batch when it is picked in the dropdown.
        batch_selector.change(
            fn=switch_batch,
            inputs=[state_batches, batch_selector],
            outputs=[output_timeline, output_frames, output_metadata, current_display_batch],
        )

    # Generator handlers need the request queue; enabling it explicitly also
    # covers Gradio versions where it is off by default.
    demo.queue()

    # Explicit launch arguments override Gradio's GRADIO_SERVER_NAME /
    # GRADIO_SERVER_PORT environment variables. Honour them when the host
    # provides them (e.g. Hugging Face Spaces); otherwise keep the local
    # defaults.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "8088"))
    # Public share links are unsupported on Hugging Face Spaces, which
    # define SPACE_ID; only request one when running elsewhere.
    share = "SPACE_ID" not in os.environ
    demo.launch(share=share, server_name=server_name, server_port=server_port)