"""Gradio gallery for browsing anomalous-event-detection results.

For each sample clip the app shows the input video, a recorded Q&A chat
session (from a JSON file), and the VGGT, optical-flow and YOLO result videos.
"""

import asyncio
import json
import os
import sys
from pathlib import Path

import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

if sys.platform == "win32":
    # Use the selector event loop on Windows instead of the default Proactor loop.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


# Sample indices used in the file names (Assault002 ... Assault008, 7 clips).
# Tabs and gallery labels show them 1-based, i.e. id 2 is displayed as "Sample 1".
SAMPLE_IDS = range(2, 9)

# Relative path template for every artifact of a sample, keyed by SAMPLES entry.
_ARTIFACT_TEMPLATES = {
    "input": "video/Assault{i:03d}_x264.mp4",
    "optical_flow": "optical_flow/Assault{i:03d}_x264.mp4",
    "yolo": "yolo/Assault{i:03d}_x264.mp4",
    "vggt": "vggt/Assault{i:03d}_x264.mp4",
    # "pcd": "pcd/Assault{i:03d}_x264.pcd",  # point clouds currently disabled
    "qa": "qa/Assault{i:03d}_x264.json",
}


def get_absolute_path(relative_path):
    """Return *relative_path* resolved against the current working directory."""
    return os.path.abspath(relative_path)


def check_file_exists(file_path):
    """Return the absolute path of *file_path* if it exists, else ``None``.

    A warning is printed for every missing file.
    """
    abs_path = get_absolute_path(file_path)
    if os.path.exists(abs_path):
        return abs_path
    print(f"Warning: File not found: {file_path}")
    return None


def _build_samples():
    """Map each sample id to ``{artifact_key: absolute path or None}``."""
    return {
        i: {
            key: check_file_exists(template.format(i=i))
            for key, template in _ARTIFACT_TEMPLATES.items()
        }
        for i in SAMPLE_IDS
    }


SAMPLES = _build_samples()


def _read_qa_json(qa_file):
    """Parse *qa_file* as UTF-8 JSON; return ``None`` if the path is empty or missing."""
    if not qa_file or not os.path.exists(qa_file):
        return None
    with open(qa_file, "r", encoding="utf-8") as f:
        return json.load(f)


def load_qa_data(qa_file):
    """Load the Q&A pairs of *qa_file* as ``[[question, answer], ...]``.

    This pair format is what is passed to ``gr.Chatbot(value=...)``. Missing
    ``question``/``answer`` fields become empty strings. Returns ``[]`` when the
    file is missing or cannot be read/parsed (the error is printed, since this
    runs while the UI is being built).
    """
    try:
        qa_data = _read_qa_json(qa_file)
        if qa_data is None:
            return []
        return [
            [qa_pair.get("question", ""), qa_pair.get("answer", "")]
            for qa_pair in qa_data.get("qa_pairs", [])
        ]
    except Exception as e:
        print(f"Error loading QA data {qa_file}: {e}")
        return []


def get_qa_metadata(qa_file):
    """Return summary metadata of the Q&A session stored in *qa_file*.

    The dict has ``video_name``, ``timestamp``, ``model_path`` (default ``''``),
    ``max_num_frames`` (default ``0``) and ``total_questions`` (number of
    ``qa_pairs``). Returns ``{}`` when the file is missing or unreadable.
    """
    try:
        qa_data = _read_qa_json(qa_file)
        if qa_data is None:
            return {}
        return {
            "video_name": qa_data.get("video_name", ""),
            "timestamp": qa_data.get("timestamp", ""),
            "model_path": qa_data.get("model_path", ""),
            "max_num_frames": qa_data.get("max_num_frames", 0),
            "total_questions": len(qa_data.get("qa_pairs", [])),
        }
    except Exception as e:
        print(f"Error loading QA metadata {qa_file}: {e}")
        return {}


def load_point_cloud_plotly(pcd_file, max_points=10000):
    """Load a point-cloud file and return an interactive plotly 3D scatter figure.

    Args:
        pcd_file: Path to a point-cloud file readable by open3d, or ``None``.
        max_points: If the cloud has more points than this, a random subset of
            this size is plotted (for browser performance). ``None`` disables
            subsampling.

    Returns:
        A ``go.Figure``, or ``None`` if the file is missing, empty, or loading
        fails (including open3d not being installed); errors are printed.
    """
    try:
        if not pcd_file or not os.path.exists(pcd_file):
            return None

        # open3d is only needed for this optional panel; import it lazily so the
        # rest of the app works without it (ImportError is handled below).
        import open3d as o3d

        pcd = o3d.io.read_point_cloud(pcd_file)
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None

        if len(points) == 0:
            return None

        if max_points is not None and len(points) > max_points:
            indices = np.random.choice(len(points), max_points, replace=False)
            points = points[indices]
            if colors is not None:
                colors = colors[indices]

        if colors is not None and len(colors) > 0:
            # Scale colours in [0, 1] to 0-255 for CSS rgb() strings.
            if colors.max() <= 1.0:
                colors = (colors * 255).astype(int)
            marker = dict(
                size=1.7,
                color=[f"rgb({r},{g},{b})" for r, g, b in colors],
            )
        else:
            marker = dict(
                size=2,
                color=points[:, 2],  # Color by Z coordinate
                colorscale="Viridis",
                showscale=True,
            )

        fig = go.Figure(data=[go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode="markers",
            marker=marker,
            text=[f"Point {i}" for i in range(len(points))],
            # %{text} already reads "Point N"; <br> is plotly's line break.
            hovertemplate="%{text}<br>X: %{x}<br>Y: %{y}<br>Z: %{z}",
        )])

        fig.update_layout(
            title=f"3D Point Cloud Visualization - {os.path.basename(pcd_file)}",
            scene=dict(
                xaxis_title="X",
                yaxis_title="Y",
                zaxis_title="Z",
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
                bgcolor="rgb(10, 10, 10)",
            ),
            margin=dict(l=0, r=0, t=50, b=0),
            paper_bgcolor="rgb(20, 20, 20)",
            plot_bgcolor="rgb(20, 20, 20)",
            font=dict(color="white"),
        )
        return fig
    except Exception as e:
        print(f"Error loading point cloud {pcd_file}: {e}")
        return None


def create_sample_gallery(sample_id):
    """Return ``(input, optical_flow, yolo, vggt, point_cloud_figure)`` for a sample.

    Video entries are absolute paths or ``None``; the figure is ``None`` when no
    point cloud is available.
    """
    sample = SAMPLES[sample_id]
    # "pcd" is absent from SAMPLES while point clouds are disabled; .get() avoids a KeyError.
    pcd_plot = load_point_cloud_plotly(sample.get("pcd"))
    return (
        sample["input"],  # Input video
        sample["optical_flow"],  # Optical flow video
        sample["yolo"],  # YOLO video
        sample["vggt"],  # VGGT video
        pcd_plot,  # Point cloud plot
    )


# (SAMPLES key, caption suffix) for the overview gallery, in display order.
_GALLERY_LABELS = (
    ("input", "Input"),
    ("optical_flow", "Optical Flow"),
    ("yolo", "YOLO"),
    ("vggt", "VGGT"),
)


def create_overview_gallery():
    """Return ``[(video_path, caption), ...]`` for every existing video of every sample.

    Captions use the same 1-based sample numbering as the tabs.
    """
    gallery_items = []
    for sample_id in SAMPLE_IDS:
        sample = SAMPLES[sample_id]
        for key, label in _GALLERY_LABELS:
            if sample[key]:  # only add items that exist
                gallery_items.append((sample[key], f"Sample {sample_id - 1} - {label}"))
    return gallery_items


# Custom CSS for better styling
custom_css = """
/* .gradio-container {
    max-width: 1200px !important;
} */
.gallery-item {
    border-radius: 10px;
}
h1 {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 30px;
}
.tab-nav {
    margin-bottom: 20px;
}
.qa-section-header {
    font-size: 1.2em;
    color: #2c3e50;
    margin-top: 20px;
}
.qa-metadata {
    background-color: #f8f9fa;
    padding: 15px;
    border-radius: 8px;
    border-left: 4px solid #007bff;
}
.qa-info {
    background-color: #e7f3ff;
    padding: 10px;
    border-radius: 5px;
    font-style: italic;
}
"""


def _video_panel(heading, path, label, missing_msg):
    """Render *heading* followed by a video player for *path*, or *missing_msg* if absent."""
    gr.Markdown(heading)
    if path:
        gr.Video(value=path, label=label, show_label=True)
    else:
        gr.Markdown(missing_msg)


def _qa_session_markdown(qa_metadata):
    """Format the metadata dict from get_qa_metadata() as a markdown summary."""
    # Keep "YYYY-MM-DD HH:MM:SS" from an ISO-like timestamp; fall back to N/A when empty.
    timestamp = str(qa_metadata.get("timestamp") or "N/A")[:19].replace("T", " ")
    return "\n".join([
        "**📊 Chat Session Info:**",
        f"- **Video:** {qa_metadata.get('video_name') or 'N/A'}",
        f"- **Total Questions:** {qa_metadata.get('total_questions', 0)}",
        f"- **Max Frames:** {qa_metadata.get('max_num_frames', 0)}",
        f"- **Timestamp:** {timestamp}",
    ])


def _build_qa_panel(qa_file):
    """Render the Q&A column: session metadata, chat history, and an explanatory note."""
    gr.Markdown("### 💬 Q&A Chat History")
    if not qa_file:
        gr.Markdown("❌ Q&A file not found for this sample")
        return

    qa_metadata = get_qa_metadata(qa_file)
    if qa_metadata:
        gr.Markdown(_qa_session_markdown(qa_metadata))

    qa_history = load_qa_data(qa_file)
    if not qa_history:
        gr.Markdown("❌ No Q&A data available for this sample")
        return

    gr.Chatbot(
        value=qa_history,
        label="Video Analysis Q&A",
        show_label=True,
        height=500,
        # NOTE(review): Gradio documents avatar_images as image paths/URLs;
        # confirm emoji strings render as intended in the installed version.
        avatar_images=["👤", "🤖"],
    )
    gr.Markdown(
        "💡 **About this Q&A:** Questions asked by humans about the video content "
        "and answers from an AI model trained for video analysis."
    )


def _build_sample_tab(sample_id):
    """Render the detailed-view tab for one sample (labelled ``sample_id - 1``)."""
    sample = SAMPLES[sample_id]
    display_id = sample_id - 1
    with gr.Tab(f"🎬 Sample {display_id}"):
        gr.Markdown(f"## Sample {display_id} - Detailed View")

        # Top row: input video + Q&A chat history
        with gr.Row():
            with gr.Column(scale=1):
                _video_panel("### 📹 Input Video", sample["input"],
                             "Original Input", "❌ Input video not found")
            with gr.Column(scale=1, min_width=400):
                _build_qa_panel(sample["qa"])

        # VGGT and point cloud in a row
        with gr.Row():
            with gr.Column():
                _video_panel("### 🎮 VGGT", sample["vggt"],
                             "VGGT Processing", "❌ VGGT video not found")
            with gr.Column():
                # Point-cloud panel disabled: "pcd" is commented out of
                # _ARTIFACT_TEMPLATES. To re-enable, restore it and render
                # gr.Plot(value=load_point_cloud_plotly(sample.get("pcd"))) here.
                pass

        # Bottom section: optical flow and YOLO side by side
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        _video_panel("### 🌊 Optical Flow", sample["optical_flow"],
                                     "Motion Analysis", "❌ Optical flow video not found")
                    with gr.Column():
                        _video_panel("### 🎯 YOLO Detection", sample["yolo"],
                                     "Object Detection", "❌ YOLO video not found")


def build_demo():
    """Build and return the Gradio Blocks app with one tab per sample."""
    with gr.Blocks(css=custom_css, title="Anomalous Event Detection") as app:
        gr.Markdown("# 🎥 Results Gallery")
        with gr.Tabs():
            for sample_id in SAMPLE_IDS:
                _build_sample_tab(sample_id)
        # A side-by-side "🔍 Compare" tab was sketched previously; it can be
        # rebuilt on top of create_sample_gallery(sample_id).
    return app


demo = build_demo()


def _print_file_status():
    """Print which artifacts were found for each sample (debug aid)."""
    print("=== File Status Check ===")
    for sample_id in SAMPLE_IDS:
        print(f"\nSample {sample_id}:")
        for key, path in SAMPLES[sample_id].items():
            status = "✅ Found" if path else "❌ Missing"
            print(f" {key}: {status}")


if __name__ == "__main__":
    _print_file_status()
    print("\n=== Starting Gradio App ===")
    demo.launch(
        share=True,  # also creates a public gradio.live share link
        server_name="127.0.0.1",
        server_port=7861,
        show_error=True,
        inbrowser=True,
    )