|
|
import gradio as gr |
|
|
import os |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
import plotly.express as px |
|
|
from pathlib import Path |
|
|
import sys |
|
|
import asyncio |
|
|
import json |
|
|
# On Windows, asyncio's default Proactor event loop breaks some libraries
# that expect selector-based sockets; force the selector policy before any
# event loop is created.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
|
|
|
|
|
|
def get_absolute_path(relative_path):
    """Resolve *relative_path* against the current working directory.

    Returns the absolute path as a string; symlinks are not resolved,
    only normalization and absolutization are applied.
    """
    absolute_path = os.path.abspath(relative_path)
    return absolute_path
|
|
|
|
|
def check_file_exists(file_path):
    """Resolve *file_path* to an absolute path if it exists on disk.

    Prints a warning and returns None when the path does not exist,
    so callers can treat a missing asset as an absent value.
    """
    abs_path = get_absolute_path(file_path)
    if not os.path.exists(abs_path):
        print(f"Warning: File not found: {file_path}")
        return None
    return abs_path
|
|
|
|
|
|
|
|
# Resolve every per-sample asset path up front. SAMPLES maps sample id
# (2-8) to a dict of asset-kind -> absolute path, or None for files that
# are missing on disk (check_file_exists prints a warning for those).
SAMPLES = {}
for sample_id in range(2, 9):
    stem = f"Assault{sample_id:03d}_x264"
    asset_paths = {
        "input": f"video/{stem}.mp4",
        "optical_flow": f"optical_flow/{stem}.mp4",
        "yolo": f"yolo/{stem}.mp4",
        "vggt": f"vggt/{stem}.mp4",
        "qa": f"qa/{stem}.json",
    }
    SAMPLES[sample_id] = {
        key: check_file_exists(path) for key, path in asset_paths.items()
    }
|
|
|
|
|
def load_qa_data(qa_file):
    """Read a QA JSON file and convert it to chatbot [question, answer] pairs.

    Returns an empty list when the path is falsy, the file is missing, or
    any read/parse error occurs (errors are printed, never raised).
    """
    try:
        if not qa_file or not os.path.exists(qa_file):
            return []

        with open(qa_file, 'r', encoding='utf-8') as f:
            qa_data = json.load(f)

        # Each entry becomes a [user, bot] pair for gr.Chatbot; missing
        # fields degrade to empty strings rather than failing.
        return [
            [pair.get('question', ''), pair.get('answer', '')]
            for pair in qa_data.get('qa_pairs', [])
        ]
    except Exception as e:
        print(f"Error loading QA data {qa_file}: {e}")
        return []
|
|
|
|
|
def get_qa_metadata(qa_file):
    """Extract session-level metadata fields from a QA JSON file.

    Returns a dict with video_name, timestamp, model_path, max_num_frames
    and total_questions, or an empty dict when the file is missing or any
    read/parse error occurs (errors are printed, never raised).
    """
    try:
        if not qa_file or not os.path.exists(qa_file):
            return {}

        with open(qa_file, 'r', encoding='utf-8') as f:
            qa_data = json.load(f)

        metadata = {
            field: qa_data.get(field, default)
            for field, default in (
                ('video_name', ''),
                ('timestamp', ''),
                ('model_path', ''),
                ('max_num_frames', 0),
            )
        }
        metadata['total_questions'] = len(qa_data.get('qa_pairs', []))
        return metadata
    except Exception as e:
        print(f"Error loading QA metadata {qa_file}: {e}")
        return {}
|
|
|
|
|
def load_point_cloud_plotly(pcd_file):
    """Load a point cloud file and build a 3D plotly scatter figure.

    Returns a plotly Figure, or None when the file is missing/empty or any
    error occurs (errors are printed, not raised).

    NOTE(review): `o3d` (open3d) is never imported in this file as shown.
    If that holds for the whole module, every call raises NameError, which
    the broad `except` below swallows into a silent None — confirm that
    `import open3d as o3d` exists at the top of the file, or add it.
    """
    try:
        if not pcd_file or not os.path.exists(pcd_file):
            return None

        # Read points (N, 3) and optional per-point RGB colors (N, 3).
        pcd = o3d.io.read_point_cloud(pcd_file)
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None

        if len(points) == 0:
            return None

        # Random-downsample to at most 10k points so the browser-side plot
        # stays responsive; colors are subsampled with the same indices.
        if len(points) > 10000:
            indices = np.random.choice(len(points), 10000, replace=False)
            points = points[indices]
            if colors is not None:
                colors = colors[indices]

        if colors is not None and len(colors) > 0:
            # Colors in [0, 1] (the usual Open3D convention) are scaled to
            # 0-255 integers for plotly's 'rgb(r,g,b)' string format.
            if colors.max() <= 1.0:
                colors = (colors * 255).astype(int)
            color_rgb = [f'rgb({r},{g},{b})' for r, g, b in colors]

            fig = go.Figure(data=[go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=1.7,
                    color=color_rgb,
                ),
                text=[f'Point {i}' for i in range(len(points))],
                hovertemplate='<b>Point %{text}</b><br>X: %{x}<br>Y: %{y}<br>Z: %{z}<extra></extra>'
            )])
        else:
            # No color channel: color by height (Z) with a Viridis colorbar.
            fig = go.Figure(data=[go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=points[:, 2],
                    colorscale='Viridis',
                    showscale=True
                ),
                text=[f'Point {i}' for i in range(len(points))],
                hovertemplate='<b>Point %{text}</b><br>X: %{x}<br>Y: %{y}<br>Z: %{z}<extra></extra>'
            )])

        # Dark-themed layout with an isometric-style default camera.
        fig.update_layout(
            title=f'3D Point Cloud Visualization - {os.path.basename(pcd_file)}',
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                bgcolor='rgb(10, 10, 10)',
            ),
            margin=dict(l=0, r=0, t=50, b=0),
            paper_bgcolor='rgb(20, 20, 20)',
            plot_bgcolor='rgb(20, 20, 20)',
            font=dict(color='white')
        )

        return fig
    except Exception as e:
        print(f"Error loading point cloud {pcd_file}: {e}")
        return None
|
|
|
|
|
def create_sample_gallery(sample_id):
    """Return the media paths and point-cloud figure for one sample.

    Returns a 5-tuple: (input video path, optical-flow video path, yolo
    video path, vggt video path, plotly point-cloud figure or None).
    """
    sample = SAMPLES[sample_id]

    # BUG FIX: SAMPLES entries carry no "pcd" key (only input/optical_flow/
    # yolo/vggt/qa), so the original sample["pcd"] raised KeyError. Use
    # .get() so a missing point cloud degrades to a None figure.
    pcd_plot = load_point_cloud_plotly(sample.get("pcd"))

    return (
        sample["input"],
        sample["optical_flow"],
        sample["yolo"],
        sample["vggt"],
        pcd_plot,
    )
|
|
|
|
|
def create_overview_gallery():
    """Create an overview gallery of (file_path, caption) tuples for all samples.

    Assets whose file was not found on disk (path is None) are skipped.
    """
    gallery_items = []
    # BUG FIX: the original iterated range(1, 6), but SAMPLES is keyed 2-8,
    # so SAMPLES[1] raised KeyError and samples 6-8 were never shown.
    # Iterate the ids that actually exist instead of a hard-coded range.
    for i in sorted(SAMPLES):
        sample = SAMPLES[i]
        for key, caption in (
            ("input", "Input"),
            ("optical_flow", "Optical Flow"),
            ("yolo", "YOLO"),
            ("vggt", "VGGT"),
        ):
            if sample[key]:
                gallery_items.append((sample[key], f"Sample {i} - {caption}"))
    return gallery_items
|
|
|
|
|
|
|
|
# Custom CSS injected into the app via gr.Blocks(css=custom_css).
# NOTE(review): the leading '#' lines look like an attempt to comment out
# the .gradio-container width cap, but '#' is not a CSS comment delimiter
# (CSS uses /* ... */); browsers will simply drop them as invalid rules,
# which happens to have the intended effect here.
custom_css = """
# .gradio-container {
#     max-width: 1200px !important;
# }
.gallery-item {
    border-radius: 10px;
}
h1 {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 30px;
}
.tab-nav {
    margin-bottom: 20px;
}
.qa-section-header {
    font-size: 1.2em;
    color: #2c3e50;
    margin-top: 20px;
}
.qa-metadata {
    background-color: #f8f9fa;
    padding: 15px;
    border-radius: 8px;
    border-left: 4px solid #007bff;
}
.qa-info {
    background-color: #e7f3ff;
    padding: 10px;
    border-radius: 5px;
    font-style: italic;
}
"""
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: one tab per sample (SAMPLES is keyed 2-8 but tabs are labelled
# "Sample 1"-"Sample 7"), each showing the input video, the Q&A chat history
# loaded from JSON, and the processed VGGT / optical-flow / YOLO videos.
# All components are created statically at import time from SAMPLES.
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css, title="Anomalous Event Detection") as demo:
    gr.Markdown("# 🎥 Results Gallery")

    with gr.Tabs() as tabs:
        # Sample id i is shown to the user as "Sample {i-1}".
        for i in range(2, 9):
            with gr.Tab(f"🎬 Sample {i-1}"):
                gr.Markdown(f"## Sample {i-1} - Detailed View")

                sample = SAMPLES[i]

                with gr.Row():
                    # Left column: the raw input video (or a missing-file note).
                    with gr.Column(scale=1):
                        gr.Markdown("### 📹 Input Video")
                        if sample["input"]:
                            input_video = gr.Video(
                                value=sample["input"],
                                label="Original Input",
                                show_label=True
                            )
                        else:
                            gr.Markdown("❌ Input video not found")

                    # Right column: session metadata plus the Q&A chatbot.
                    with gr.Column(scale=1, min_width=400):
                        gr.Markdown("### 💬 Q&A Chat History")

                        if sample["qa"]:
                            qa_metadata = get_qa_metadata(sample["qa"])
                            if qa_metadata:
                                gr.Markdown(f"""
**📊 Chat Session Info:**
- **Video:** {qa_metadata.get('video_name', 'N/A')}
- **Total Questions:** {qa_metadata.get('total_questions', 0)}
- **Max Frames:** {qa_metadata.get('max_num_frames', 0)}
- **Timestamp:** {qa_metadata.get('timestamp', 'N/A')[:19].replace('T', ' ')}
""")

                            qa_history = load_qa_data(sample["qa"])
                            if qa_history:
                                # NOTE(review): avatar_images normally takes
                                # image file paths/URLs, not emoji strings, and
                                # the [user, bot] pair list is the legacy
                                # "tuples" Chatbot format — confirm both against
                                # the installed Gradio version.
                                chatbot = gr.Chatbot(
                                    value=qa_history,
                                    label="Video Analysis Q&A",
                                    show_label=True,
                                    height=500,
                                    avatar_images=["👤", "🤖"]
                                )

                                gr.Markdown("""
💡 **About this Q&A:** Questions asked by humans about the video content and answers from an AI model trained for video analysis.
""")
                            else:
                                gr.Markdown("❌ No Q&A data available for this sample")
                        else:
                            gr.Markdown("❌ Q&A file not found for this sample")

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### 🎮 VGGT")

                        if sample["vggt"]:
                            vggt_video = gr.Video(
                                value=sample["vggt"],
                                label="VGGT Processing",
                                show_label=True
                            )
                        else:
                            gr.Markdown("❌ VGGT video not found")

                    # Intentionally empty column: keeps the VGGT video at
                    # half the row width.
                    with gr.Column():
                        pass

                with gr.Row():
                    with gr.Column(scale=2):
                        # Side-by-side motion analysis and object detection.
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("### 🌊 Optical Flow")
                                if sample["optical_flow"]:
                                    optical_flow_video = gr.Video(
                                        value=sample["optical_flow"],
                                        label="Motion Analysis",
                                        show_label=True
                                    )
                                else:
                                    gr.Markdown("❌ Optical flow video not found")

                            with gr.Column():
                                gr.Markdown("### 🎯 YOLO Detection")
                                if sample["yolo"]:
                                    yolo_video = gr.Video(
                                        value=sample["yolo"],
                                        label="Object Detection",
                                        show_label=True
                                    )
                                else:
                                    gr.Markdown("❌ YOLO video not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Report which asset files were resolved at import time, then launch.
    print("=== File Status Check ===")
    for sample_id, files in SAMPLES.items():
        print(f"\nSample {sample_id}:")
        for key, path in files.items():
            status = "✅ Found" if path else "❌ Missing"
            print(f"  {key}: {status}")

    print(f"\n=== Starting Gradio App ===")
    demo.launch(
        share=True,
        server_name="127.0.0.1",
        server_port=7861,
        show_error=True,
        inbrowser=True
    )
|
|
|