import os
import gradio as gr
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from typing import List

from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# ========== Configuration ==========
MODEL_ID = "WoWolf/Qwen2_5vl-7b-fm-tuned"
MAX_FRAMES = 48
MAX_NEW_TOKENS = 128
TEMPERATURE = 1.0
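# TEMPERATURE > 0.0 enables sampling in answer() below; setting it to 0.0
# falls back to greedy decoding (see gen_kwargs in the inference section).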

# ========== Video Examples Configuration ==========
VIDEO_EXAMPLES = {
    "1_raw.mp4": {
        "path": "1_raw.mp4",
        "questions": ["What's happening in this video?", "Which hand holds the pen?"],
    },
    "4_raw.mp4": {
        "path": "4_raw.mp4",
        "questions": ["What's happening in this video?", "What is the main action in the video?"],
    },
    "6_raw.mp4": {
        "path": "6_raw.mp4",
        "questions": ["What's happening in this video?", "What's the right hand doing?"],
    },
}
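# Note: the paths above are relative, so the example videos must sit next to
# this script (e.g. at the root of the Space repository).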

# ========== Load Model & Processor ==========
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
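# bfloat16 halves the weight memory relative to float32, and
# device_map="auto" lets accelerate place the weights on the available device.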

# ========== Video Frame Extraction ==========
def extract_video_frames(video_path: str, max_frames: int = 8) -> List[Image.Image]:
    """Extract key frames from video using OpenCV"""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames == 0:
        cap.release()
        return frames
    # Select frame indices evenly across the video
    frame_indices = np.linspace(0, total_frames - 1, min(max_frames, total_frames), dtype=int)
    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; PIL expects RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
    cap.release()
    return frames
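# Caveat: seeking with CAP_PROP_POS_FRAMES is not frame-accurate for every
# codec/container, so sampled frames may land slightly off the requested
# indices; for short example clips this is usually acceptable.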

# ========== Message Builder ==========
SYSTEM_PROMPT = (
    "You are a helpful assistant that watches a user-provided video and answers "
    "questions about it concisely and accurately."
)

def build_messages(frames: List[Image.Image], question: str, fps: float = 1.0):
    """Build messages in Qwen-VL format"""
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": frames,
                    "fps": fps,
                },
                {"type": "text", "text": question},
            ],
        },
    ]
    return messages
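# Note: the frames are sampled evenly across the clip, so the fps value passed
# here is only nominal metadata for the processor, not the true sampling rate.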

# ========== Helper Functions ==========
def update_video_display(video_name):
    """Update the video preview and example questions when a video is selected"""
    if video_name is None:
        return None, ""
    video_info = VIDEO_EXAMPLES[video_name]
    video_path = video_info["path"]
    example_questions = "\n".join([f"• {q}" for q in video_info["questions"]])
    return video_path, example_questions

def fill_question(video_name, question_idx):
    """Fill the question textbox with the selected example question"""
    if video_name is None:
        return ""
    questions = VIDEO_EXAMPLES[video_name]["questions"]
    if 0 <= question_idx < len(questions):
        return questions[question_idx]
    return ""

# ========== Inference ==========
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call
def answer(video_name, question):
    if video_name is None:
        return "Please select a video first."
    if not question or question.strip() == "":
        question = "Describe this video in detail."
    video_path = VIDEO_EXAMPLES[video_name]["path"]
    # Extract frames from video
    frames = extract_video_frames(video_path, max_frames=MAX_FRAMES)
    if not frames:
        return "Error: Unable to extract frames from video."
    # Build messages
    messages = build_messages(frames, question, fps=1.0)
    # Apply chat template
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Process vision info
    image_inputs, video_inputs = process_vision_info(messages)
    # Prepare inputs
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    # Generation settings
    gen_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=(TEMPERATURE > 0.0),
        temperature=TEMPERATURE if TEMPERATURE > 0 else None,
        pad_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
    )
    # Generate
    generated_ids = model.generate(**inputs, **gen_kwargs)
    # Trim the prompt tokens so only the newly generated answer is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text.strip()

# ========== Gradio UI ==========
with gr.Blocks(title="Video Q&A with Qwen2.5-VL-7B") as demo:
    gr.Markdown(
        """
        # FoundationMotion: Auto-Labeling and Reasoning about Spatial Movement in Videos
        Select a video, ask a question, and get an answer!
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Video selector dropdown
            video_selector = gr.Dropdown(
                choices=list(VIDEO_EXAMPLES.keys()),
                label="Select a Video",
                value=None,
                interactive=True,
            )
            # Video display (read-only)
            video_display = gr.Video(
                label="Video Preview",
                height=400,
                interactive=False,
            )
        with gr.Column(scale=1):
            # Example questions display
            example_questions_display = gr.Textbox(
                label="Example Questions (click buttons below to use)",
                lines=3,
                interactive=False,
            )
            # Buttons for quick question selection
            with gr.Row():
                q1_btn = gr.Button("Use Question 1", size="sm")
                q2_btn = gr.Button("Use Question 2", size="sm")
            question = gr.Textbox(
                label="Your Question",
                placeholder="Type your question or click an example button above",
                lines=2,
            )
            ask_btn = gr.Button("Ask", variant="primary")
    output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)

    # Event handlers
    video_selector.change(
        fn=update_video_display,
        inputs=[video_selector],
        outputs=[video_display, example_questions_display],
    )
    q1_btn.click(
        fn=lambda v: fill_question(v, 0),
        inputs=[video_selector],
        outputs=[question],
    )
    q2_btn.click(
        fn=lambda v: fill_question(v, 1),
        inputs=[video_selector],
        outputs=[question],
    )
    ask_btn.click(
        fn=answer,
        inputs=[video_selector, question],
        outputs=[output],
    )

# ========== Launch ==========
if __name__ == "__main__":
    demo.launch()