import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np

MID = "apple/FastVLM-7B"
IMAGE_TOKEN_INDEX = -200  # placeholder id spliced where the image embedding goes

# Load model and tokenizer
print("Loading FastVLM model...")
tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
print("Model loaded successfully!")


def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Extract frames from a video file using the chosen sampling strategy."""
    num_frames = int(num_frames)  # the UI slider may deliver this as a float
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        cap.release()
        return []

    if sampling_method == "uniform":
        # Evenly spaced frames across the whole video
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    elif sampling_method == "first":
        # Take the first N frames
        indices = list(range(min(num_frames, total_frames)))
    elif sampling_method == "last":
        # Take the last N frames
        start = max(0, total_frames - num_frames)
        indices = list(range(start, total_frames))
    else:  # middle
        # Take N frames from the middle of the video
        start = max(0, (total_frames - num_frames) // 2)
        indices = list(range(start, min(start + num_frames, total_frames)))

    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV reads BGR; PIL expects RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
    cap.release()
    return frames


def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate a caption for a single frame."""
    # Build the chat, marking where the image goes with an <image> placeholder
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split("<image>", 1)

    # Tokenize the text on either side of the image placeholder
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

    # Splice in the image token id
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids)

    # Preprocess the image with the model's own vision tower processor
    px = model.get_vision_tower().image_processor(
        images=image, return_tensors="pt"
    )["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    # Generate
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
        )
    caption = tok.decode(out[0], skip_special_tokens=True)

    # The decoded string includes the prompt; keep only the generated part
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()
    return caption


def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress(),
) -> tuple:
    """Process a video and generate per-frame captions."""
    if not video_path:
        return "Please upload a video first.", None, None

    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)
    if not frames:
        return "Failed to extract frames from video.", None, None

    # Pick a prompt based on the selected caption mode
    if caption_mode == "Detailed Description":
        prompt = "Describe this image in detail, including all visible objects, actions, and the overall scene."
    elif caption_mode == "Brief Summary":
        prompt = "Provide a brief one-sentence description of what's happening in this image."
    elif caption_mode == "Action Recognition":
        prompt = "What action or activity is taking place in this image? Focus on the main action."
    else:  # Custom
        prompt = custom_prompt if custom_prompt else "Describe this image."

    captions = []
    frame_previews = []
    for i, frame in enumerate(frames):
        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
        caption = caption_frame(frame, prompt)
        captions.append(f"**Frame {i + 1}:** {caption}")
        frame_previews.append(frame)

    progress(1.0, desc="Generating summary...")

    # Combine the per-frame captions into a single report. For simplicity we
    # concatenate them rather than running a second summarization pass.
    full_caption = "\n\n".join(captions)
    if len(frames) > 1:
        video_summary = f"## Video Analysis ({len(frames)} frames analyzed)\n\n{full_caption}"
    else:
        video_summary = f"## Video Analysis\n\n{full_caption}"
    return video_summary, frame_previews, video_path


# Create the Gradio interface
with gr.Blocks(css="""
    .video-container { height: calc(100vh - 100px) !important; }
    .sidebar { height: calc(100vh - 100px) !important; overflow-y: auto; }
""") as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        # Main video display
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                height=600,
                elem_classes=["video-container"],
                autoplay=True,
                loop=True,
            )

        # Sidebar with controls
        with gr.Sidebar(width=400, elem_classes=["sidebar"]):
            gr.Markdown("## ⚙️ Settings")

            with gr.Group():
                gr.Markdown("### Frame Sampling")
                num_frames = gr.Slider(
                    minimum=1,
                    maximum=16,
                    value=8,
                    step=1,
                    label="Number of Frames to Analyze",
                    info="More frames = better understanding but slower processing",
                )
                sampling_method = gr.Radio(
                    choices=["uniform", "first", "last", "middle"],
                    value="uniform",
                    label="Sampling Method",
                    info="How to select frames from the video",
                )

            with gr.Group():
                gr.Markdown("### Caption Settings")
                caption_mode = gr.Radio(
                    choices=["Detailed Description", "Brief Summary", "Action Recognition", "Custom"],
                    value="Detailed Description",
                    label="Caption Mode",
                )
                custom_prompt = gr.Textbox(
                    label="Custom Prompt",
                    placeholder="Enter your custom prompt here...",
                    visible=False,
                    lines=3,
                )

            process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")

            gr.Markdown("### 📝 Results")
            output_text = gr.Markdown(
                value="Upload a video and click 'Analyze Video' to begin.",
                elem_classes=["output-text"],
            )

            with gr.Accordion("🖼️ Analyzed Frames", open=False):
                frame_gallery = gr.Gallery(
                    label="Extracted Frames",
                    show_label=False,
                    columns=2,
                    rows=4,
                    object_fit="contain",
                    height="auto",
                )

    # Show/hide the custom prompt box based on the selected mode
    def toggle_custom_prompt(mode):
        return gr.Textbox(visible=(mode == "Custom"))

    caption_mode.change(
        toggle_custom_prompt,
        inputs=[caption_mode],
        outputs=[custom_prompt],
    )

    # Upload handler
    def handle_upload(video):
        if video:
            return video, "Video loaded! Click 'Analyze Video' to generate captions."
        return None, "Upload a video to begin."

    video_display.upload(
        handle_upload,
        inputs=[video_display],
        outputs=[video_display, output_text],
    )

    # Process button
    process_btn.click(
        process_video,
        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt],
        outputs=[output_text, frame_gallery, video_display],
    )

demo.launch()
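# Optional variant (an addition, not part of the original app): each click on
# "Analyze Video" runs one generate() call per sampled frame, so a request can
# take tens of seconds. Enabling Gradio's request queue before launching keeps
# concurrent users from timing out; a minimal sketch of the launch line:
#
#     demo.queue(max_size=8).launch()
#
# max_size=8 is an illustrative value, not a tuned setting.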