import gradio as gr
import os
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from pathlib import Path
import sys
import asyncio
import json
# NOTE(review): presumably works around the Windows Proactor event loop
# breaking selector-based socket code used by gradio's server stack — confirm.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Get absolute paths for all files
def get_absolute_path(relative_path):
    """Resolve relative_path against the current working directory.

    Args:
        relative_path: A (possibly relative) filesystem path string.

    Returns:
        The normalized absolute path string.
    """
    absolute = os.path.abspath(relative_path)
    return absolute
def check_file_exists(file_path):
    """Return the absolute path of file_path if it exists on disk.

    Args:
        file_path: Path (relative or absolute) to look up.

    Returns:
        The absolute path string when the file exists, otherwise None
        (a warning is printed for the missing file).
    """
    abs_path = get_absolute_path(file_path)
    if not os.path.exists(abs_path):
        # Warn with the caller-supplied path so the message matches the config.
        print(f"Warning: File not found: {file_path}")
        return None
    return abs_path
# Define the data structure with absolute paths.
# SAMPLES maps each Assault video index (2..8) to a dict of asset paths;
# any asset missing on disk is recorded as None.
SAMPLES = {}
for idx in range(2, 9):  # Updated to handle 8 assault videos
    relative_paths = {
        "input": f"video/Assault{idx:03d}_x264.mp4",  # Updated naming pattern
        "optical_flow": f"optical_flow/Assault{idx:03d}_x264.mp4",
        "yolo": f"yolo/Assault{idx:03d}_x264.mp4",
        "vggt": f"vggt/Assault{idx:03d}_x264.mp4",
        # "pcd": f"pcd/Assault{idx:03d}_x264.pcd",
        "qa": f"qa/Assault{idx:03d}_x264.json",
    }
    # Convert to absolute paths, keeping None for anything not found.
    SAMPLES[idx] = {
        key: check_file_exists(path) for key, path in relative_paths.items()
    }
def load_qa_data(qa_file):
    """Load Q&A pairs from a JSON file, formatted for gr.Chatbot.

    Args:
        qa_file: Path to a JSON file containing a top-level "qa_pairs" list
            of {"question": ..., "answer": ...} objects, or None.

    Returns:
        A list of [question, answer] two-element lists; empty list when the
        file is missing, unreadable, or malformed.
    """
    try:
        if not qa_file or not os.path.exists(qa_file):
            return []
        with open(qa_file, 'r', encoding='utf-8') as fh:
            payload = json.load(fh)
        # One [user, assistant] pair per entry, defaulting missing fields to "".
        return [
            [pair.get('question', ''), pair.get('answer', '')]
            for pair in payload.get('qa_pairs', [])
        ]
    except Exception as e:
        print(f"Error loading QA data {qa_file}: {e}")
        return []
def get_qa_metadata(qa_file):
    """Extract session metadata from a Q&A JSON file.

    Args:
        qa_file: Path to the JSON file, or None.

    Returns:
        Dict with keys video_name, timestamp, model_path, max_num_frames and
        total_questions; empty dict on any failure or when the file is absent.
    """
    try:
        if not qa_file or not os.path.exists(qa_file):
            return {}
        with open(qa_file, 'r', encoding='utf-8') as fh:
            payload = json.load(fh)
        return {
            'video_name': payload.get('video_name', ''),
            'timestamp': payload.get('timestamp', ''),
            'model_path': payload.get('model_path', ''),
            'max_num_frames': payload.get('max_num_frames', 0),
            # Derived field: number of recorded Q&A pairs.
            'total_questions': len(payload.get('qa_pairs', [])),
        }
    except Exception as e:
        print(f"Error loading QA metadata {qa_file}: {e}")
        return {}
def load_point_cloud_plotly(pcd_file):
    """Load a point cloud and create an interactive 3D Plotly visualization.

    Args:
        pcd_file: Path to a point-cloud file readable by Open3D, or None.

    Returns:
        A plotly Figure with a Scatter3d trace, or None when the file is
        missing/empty/unreadable or Open3D is unavailable.

    Fixes:
        - `o3d` was referenced but never imported anywhere in this file,
          so every call raised NameError; it is now imported lazily here so
          the rest of the app still runs without Open3D installed.
        - The hovertemplate string literals were broken across physical
          lines (invalid syntax); restored with `<br>` separators.
    """
    try:
        if not pcd_file or not os.path.exists(pcd_file):
            return None

        # Lazy import: a missing open3d package is caught by the except
        # below and reported instead of crashing at module import time.
        import open3d as o3d

        pcd = o3d.io.read_point_cloud(pcd_file)
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None
        if len(points) == 0:
            return None

        # Subsample points if too many (for rendering performance).
        if len(points) > 10000:
            indices = np.random.choice(len(points), 10000, replace=False)
            points = points[indices]
            if colors is not None:
                colors = colors[indices]

        hover = 'Point %{text}<br>X: %{x}<br>Y: %{y}<br>Z: %{z}'
        labels = [f'Point {i}' for i in range(len(points))]

        if colors is not None and len(colors) > 0:
            # Open3D stores colors as floats in [0, 1]; Plotly rgb() wants 0-255.
            if colors.max() <= 1.0:
                colors = (colors * 255).astype(int)
            color_rgb = [f'rgb({r},{g},{b})' for r, g, b in colors]
            marker = dict(size=1.7, color=color_rgb)
        else:
            # No per-point colors: shade by height (Z) with a colorbar.
            marker = dict(
                size=2,
                color=points[:, 2],
                colorscale='Viridis',
                showscale=True,
            )

        fig = go.Figure(data=[go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=marker,
            text=labels,
            hovertemplate=hover,
        )])

        fig.update_layout(
            title=f'3D Point Cloud Visualization - {os.path.basename(pcd_file)}',
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                bgcolor='rgb(10, 10, 10)',
            ),
            margin=dict(l=0, r=0, t=50, b=0),
            paper_bgcolor='rgb(20, 20, 20)',
            plot_bgcolor='rgb(20, 20, 20)',
            font=dict(color='white')
        )
        return fig
    except Exception as e:
        print(f"Error loading point cloud {pcd_file}: {e}")
        return None
def create_sample_gallery(sample_id):
    """Return all media assets for one sample, for the comparison view.

    Args:
        sample_id: Key into the module-level SAMPLES dict.

    Returns:
        Tuple (input, optical_flow, yolo, vggt, pcd_plot); individual
        entries may be None when the underlying asset is missing.

    Fix: the "pcd" entry is currently commented out of SAMPLES, so
    indexing sample["pcd"] raised KeyError — use .get() and skip the
    point-cloud load when the key is absent.
    """
    sample = SAMPLES[sample_id]
    pcd_path = sample.get("pcd")
    # Only attempt the (expensive) point-cloud load when a path is present.
    pcd_plot = load_point_cloud_plotly(pcd_path) if pcd_path else None
    return (
        sample["input"],         # Input video
        sample["optical_flow"],  # Optical flow video
        sample["yolo"],          # YOLO video
        sample["vggt"],          # VGGT video
        pcd_plot                 # Point cloud plot (or None)
    )
def create_overview_gallery(samples=None):
    """Create an overview listing of every available sample asset.

    Args:
        samples: Optional mapping of sample id -> asset dict; defaults to
            the module-level SAMPLES.

    Returns:
        List of (file_path, caption) tuples, skipping missing assets.

    Fix: the loop was hard-coded to range(1, 6), but SAMPLES is keyed
    2..8, so SAMPLES[1] raised KeyError on the first iteration — iterate
    the mapping's actual keys (sorted for stable ordering) instead.
    """
    if samples is None:
        samples = SAMPLES
    # (asset key, human-readable caption suffix) in display order.
    asset_labels = [
        ("input", "Input"),
        ("optical_flow", "Optical Flow"),
        ("yolo", "YOLO"),
        ("vggt", "VGGT"),
    ]
    gallery_items = []
    for i in sorted(samples):
        sample = samples[i]
        for key, label in asset_labels:
            # Only add items that exist (missing assets are stored as None).
            if sample.get(key):
                gallery_items.append((sample[key], f"Sample {i} - {label}"))
    return gallery_items
# Custom CSS for better styling
custom_css = """
# .gradio-container {
# max-width: 1200px !important;
# }
.gallery-item {
border-radius: 10px;
}
h1 {
text-align: center;
color: #2c3e50;
margin-bottom: 30px;
}
.tab-nav {
margin-bottom: 20px;
}
.qa-section-header {
font-size: 1.2em;
color: #2c3e50;
margin-top: 20px;
}
.qa-metadata {
background-color: #f8f9fa;
padding: 15px;
border-radius: 8px;
border-left: 4px solid #007bff;
}
.qa-info {
background-color: #e7f3ff;
padding: 10px;
border-radius: 5px;
font-style: italic;
}
"""
# Create the Gradio interface.
# Layout per sample tab: (input video | Q&A chat) on top, then VGGT next to a
# placeholder column reserved for the (currently disabled) point-cloud viewer,
# then optical-flow and YOLO videos side by side.
with gr.Blocks(css=custom_css, title="Anomalous Event Detection") as demo:
    gr.Markdown("# 🎥 Results Gallery")
    with gr.Tabs() as tabs:
        # Individual sample tabs: SAMPLES is keyed 2..8 but the tab labels
        # are displayed as Sample 1..7 (i-1).
        for i in range(2, 9):
            with gr.Tab(f"🎬 Sample {i-1}"):
                gr.Markdown(f"## Sample {i-1} - Detailed View")
                sample = SAMPLES[i]
                # Top Row: Input Video + Chat History
                with gr.Row():
                    # Left Column: Input Video (narrower)
                    with gr.Column(scale=1):
                        gr.Markdown("### 📹 Input Video")
                        if sample["input"]:
                            input_video = gr.Video(
                                value=sample["input"],
                                label="Original Input",
                                show_label=True
                            )
                        else:
                            gr.Markdown("❌ Input video not found")
                    # Right Column: Q&A Chat History
                    with gr.Column(scale=1, min_width=400):
                        gr.Markdown("### 💬 Q&A Chat History")
                        if sample["qa"]:
                            # Load QA metadata
                            qa_metadata = get_qa_metadata(sample["qa"])
                            if qa_metadata:
                                gr.Markdown(f"""
**📊 Chat Session Info:**
- **Video:** {qa_metadata.get('video_name', 'N/A')}
- **Total Questions:** {qa_metadata.get('total_questions', 0)}
- **Max Frames:** {qa_metadata.get('max_num_frames', 0)}
- **Timestamp:** {qa_metadata.get('timestamp', 'N/A')[:19].replace('T', ' ')}
""")
                            # Load and display chat history
                            qa_history = load_qa_data(sample["qa"])
                            if qa_history:
                                chatbot = gr.Chatbot(
                                    value=qa_history,
                                    label="Video Analysis Q&A",
                                    show_label=True,
                                    height=500,
                                    avatar_images=["👤", "🤖"]
                                )
                                gr.Markdown("""
💡 **About this Q&A:** Questions asked by humans about the video content and answers from an AI model trained for video analysis.
""")
                            else:
                                gr.Markdown("❌ No Q&A data available for this sample")
                        else:
                            gr.Markdown("❌ Q&A file not found for this sample")
                # VGGT and Point Cloud in a row
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### 🎮 VGGT")
                        if sample["vggt"]:
                            vggt_video = gr.Video(
                                value=sample["vggt"],
                                label="VGGT Processing",
                                show_label=True
                            )
                        else:
                            gr.Markdown("❌ VGGT video not found")
                    with gr.Column():
                        # Empty placeholder column: the 3D point-cloud viewer
                        # below is disabled but the two-column layout is kept.
                        pass
                        # gr.Markdown("### ☁️ 3D Point Cloud")
                        # if sample["pcd"]:
                        #     try:
                        #         pcd_plot = gr.Plot(
                        #             value=load_point_cloud_plotly(sample["pcd"]),
                        #             label="Interactive 3D Point Cloud",
                        #             show_label=True
                        #         )
                        #     except Exception as e:
                        #         gr.Markdown(f"❌ Error loading point cloud: {str(e)}")
                        # else:
                        #     gr.Markdown("❌ Point cloud file not found")
                # Bottom Section: Other Analysis Results
                with gr.Row():
                    with gr.Column(scale=2):
                        # Optical Flow and YOLO in a row
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("### 🌊 Optical Flow")
                                if sample["optical_flow"]:
                                    optical_flow_video = gr.Video(
                                        value=sample["optical_flow"],
                                        label="Motion Analysis",
                                        show_label=True
                                    )
                                else:
                                    gr.Markdown("❌ Optical flow video not found")
                            with gr.Column():
                                gr.Markdown("### 🎯 YOLO Detection")
                                if sample["yolo"]:
                                    yolo_video = gr.Video(
                                        value=sample["yolo"],
                                        label="Object Detection",
                                        show_label=True
                                    )
                                else:
                                    gr.Markdown("❌ YOLO video not found")
        # Comparison tab (disabled)
        # with gr.Tab("🔍 Compare"):
        #     gr.Markdown("## Compare Different Samples")
        #     gr.Markdown("Select two samples to compare side by side")
        #     with gr.Row():
        #         sample1_dropdown = gr.Dropdown(
        #             choices=list(range(1, 6)),
        #             value=1,
        #             label="Sample 1"
        #         )
        #         sample2_dropdown = gr.Dropdown(
        #             choices=list(range(1, 6)),
        #             value=2,
        #             label="Sample 2"
        #         )
        #     with gr.Row():
        #         with gr.Column():
        #             gr.Markdown("### Sample 1")
        #             comp_input1 = gr.Video(label="Input")
        #             comp_optical1 = gr.Video(label="Optical Flow")
        #             comp_yolo1 = gr.Video(label="YOLO")
        #             comp_vggt1 = gr.Video(label="VGGT")
        #             comp_pcd1 = gr.Plot(label="Point Cloud")
        #         with gr.Column():
        #             gr.Markdown("### Sample 2")
        #             comp_input2 = gr.Video(label="Input")
        #             comp_optical2 = gr.Video(label="Optical Flow")
        #             comp_yolo2 = gr.Video(label="YOLO")
        #             comp_vggt2 = gr.Video(label="VGGT")
        #             comp_pcd2 = gr.Plot(label="Point Cloud")
        #     # Update comparison when dropdowns change
        #     def update_comparison(sample1_id, sample2_id):
        #         try:
        #             sample1_results = create_sample_gallery(sample1_id)
        #             sample2_results = create_sample_gallery(sample2_id)
        #             return sample1_results + sample2_results
        #         except Exception as e:
        #             print(f"Error updating comparison: {e}")
        #             return [None] * 10
        #     for dropdown in [sample1_dropdown, sample2_dropdown]:
        #         dropdown.change(
        #             update_comparison,
        #             inputs=[sample1_dropdown, sample2_dropdown],
        #             outputs=[
        #                 comp_input1, comp_optical1, comp_yolo1, comp_vggt1, comp_pcd1,
        #                 comp_input2, comp_optical2, comp_yolo2, comp_vggt2, comp_pcd2
        #             ]
        #         )
if __name__ == "__main__":
    # Print file status for debugging so missing assets are obvious at startup.
    print("=== File Status Check ===")
    for i in range(2, 9):
        print(f"\nSample {i}:")
        for key, path in SAMPLES[i].items():
            # path is an absolute path string, or None when the file was missing.
            status = "✅ Found" if path else "❌ Missing"
            print(f" {key}: {status}")
    print(f"\n=== Starting Gradio App ===")
    # share=True publishes a temporary public gradio.live URL in addition to
    # the local server on 127.0.0.1:7861; inbrowser opens it automatically.
    demo.launch(
        share=True,
        server_name="127.0.0.1",
        server_port=7861,
        show_error=True,
        inbrowser=True
    )