Wither Lloyd committed on
Commit
09ecac3
·
1 Parent(s): 1240813

Add application file

Browse files
Files changed (1) hide show
  1. app.py +422 -0
app.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import open3d as o3d
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+ from pathlib import Path
8
+ import sys
9
+ import asyncio
10
+ import json
11
# On Windows, Python's default Proactor event loop is incompatible with some
# libraries that expect selector-based sockets; force the selector policy.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
13
+
14
# Path helpers used when registering sample assets.
def get_absolute_path(relative_path):
    """Return *relative_path* resolved to an absolute path (cwd-relative)."""
    return os.path.abspath(relative_path)

def check_file_exists(file_path):
    """Return the absolute path of *file_path*, or None when the file is missing.

    A warning is printed for missing files so absent assets are visible at
    startup rather than failing silently later.
    """
    resolved = get_absolute_path(file_path)
    if not os.path.exists(resolved):
        print(f"Warning: File not found: {file_path}")
        return None
    return resolved
27
+
28
# SAMPLES maps sample index -> {artifact name: absolute path or None}.
# Paths are validated once at import time; missing files become None so the
# UI can degrade gracefully instead of crashing.
SAMPLES = {}
for idx in range(2, 9):  # Assault002 .. Assault008
    relative_paths = {
        "input": f"video/Assault{idx:03d}_x264.mp4",
        "optical_flow": f"optical_flow/Assault{idx:03d}_x264.mp4",
        "yolo": f"yolo/Assault{idx:03d}_x264.mp4",
        "vggt": f"vggt/Assault{idx:03d}_x264.mp4",
        # "pcd": f"pcd/Assault{idx:03d}_x264.pcd",
        "qa": f"qa/Assault{idx:03d}_x264.json",
    }
    # Resolve each path and record None for anything that does not exist.
    SAMPLES[idx] = {name: check_file_exists(rel) for name, rel in relative_paths.items()}
45
+
46
def load_qa_data(qa_file):
    """Read a QA JSON file and return chat pairs for a Gradio chatbot.

    Each entry of the returned list is ``[question, answer]``. Returns an
    empty list when the file is missing or unreadable (errors are printed,
    never raised, so the UI keeps rendering).
    """
    if not qa_file or not os.path.exists(qa_file):
        return []
    try:
        with open(qa_file, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        # Missing keys default to '' so a partially-filled pair still renders.
        return [
            [pair.get('question', ''), pair.get('answer', '')]
            for pair in payload.get('qa_pairs', [])
        ]
    except Exception as e:
        print(f"Error loading QA data {qa_file}: {e}")
        return []
68
+
69
def get_qa_metadata(qa_file):
    """Return session metadata (video name, timestamp, question count, ...)
    from a QA JSON file, or an empty dict when the file is missing or bad.

    Errors are printed rather than raised so the UI keeps rendering.
    """
    if not qa_file or not os.path.exists(qa_file):
        return {}
    try:
        with open(qa_file, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        return {
            'video_name': payload.get('video_name', ''),
            'timestamp': payload.get('timestamp', ''),
            'model_path': payload.get('model_path', ''),
            'max_num_frames': payload.get('max_num_frames', 0),
            'total_questions': len(payload.get('qa_pairs', [])),
        }
    except Exception as e:
        print(f"Error loading QA metadata {qa_file}: {e}")
        return {}
88
+
89
def load_point_cloud_plotly(pcd_file):
    """Load a point-cloud file via Open3D and build an interactive Plotly
    3D scatter figure.

    Returns a ``plotly.graph_objects.Figure``, or ``None`` when the file is
    missing, contains no points, or any error occurs (errors are printed,
    never raised, so the UI keeps rendering).
    """
    try:
        if not pcd_file or not os.path.exists(pcd_file):
            return None

        pcd = o3d.io.read_point_cloud(pcd_file)
        points = np.asarray(pcd.points)
        # Per-point RGB colors, if the cloud carries them; otherwise None.
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None

        if len(points) == 0:
            return None

        # Subsample points if too many (for performance)
        # NOTE(review): np.random.choice is unseeded, so the subsample (and
        # hence the rendered cloud) differs between runs — confirm intended.
        if len(points) > 10000:
            indices = np.random.choice(len(points), 10000, replace=False)
            points = points[indices]
            if colors is not None:
                colors = colors[indices]

        # Create 3D scatter plot
        if colors is not None and len(colors) > 0:
            # Convert colors to RGB if needed
            # Open3D stores colors as floats in [0, 1]; Plotly 'rgb(...)'
            # strings need 0-255 integer channels.
            if colors.max() <= 1.0:
                colors = (colors * 255).astype(int)
            color_rgb = [f'rgb({r},{g},{b})' for r, g, b in colors]

            fig = go.Figure(data=[go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=1.7,
                    color=color_rgb,
                ),
                text=[f'Point {i}' for i in range(len(points))],
                hovertemplate='<b>Point %{text}</b><br>X: %{x}<br>Y: %{y}<br>Z: %{z}<extra></extra>'
            )])
        else:
            # No stored colors: color each point by its height instead.
            fig = go.Figure(data=[go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=points[:, 2],  # Color by Z coordinate
                    colorscale='Viridis',
                    showscale=True
                ),
                text=[f'Point {i}' for i in range(len(points))],
                hovertemplate='<b>Point %{text}</b><br>X: %{x}<br>Y: %{y}<br>Z: %{z}<extra></extra>'
            )])

        # Dark theme so the cloud stands out against the page background.
        fig.update_layout(
            title=f'3D Point Cloud Visualization - {os.path.basename(pcd_file)}',
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                bgcolor='rgb(10, 10, 10)',
            ),
            margin=dict(l=0, r=0, t=50, b=0),
            paper_bgcolor='rgb(20, 20, 20)',
            plot_bgcolor='rgb(20, 20, 20)',
            font=dict(color='white')
        )

        return fig
    except Exception as e:
        print(f"Error loading point cloud {pcd_file}: {e}")
        return None
165
+
166
def create_sample_gallery(sample_id):
    """Return the artifacts for one sample as a 5-tuple:
    (input video, optical-flow video, YOLO video, VGGT video, point-cloud plot).

    Any element may be None when the corresponding file is missing.
    """
    sample = SAMPLES[sample_id]

    # Bug fix: the "pcd" entry is commented out when SAMPLES is built, so
    # sample["pcd"] raised KeyError. .get() yields None (and therefore no
    # plot) when the key is absent, matching the disabled-feature intent.
    pcd_plot = load_point_cloud_plotly(sample.get("pcd"))

    return (
        sample["input"],         # Input video
        sample["optical_flow"],  # Optical flow video
        sample["yolo"],          # YOLO video
        sample["vggt"],          # VGGT video
        pcd_plot,                # Point cloud plot
    )
180
+
181
def create_overview_gallery():
    """Build an overview gallery: a list of (path, caption) items covering
    every available artifact of every sample.

    Bug fix: the original iterated ``range(1, 6)``, but SAMPLES is keyed
    2..8 (see its construction), so ``SAMPLES[1]`` raised KeyError on the
    first iteration. Iterate the actual keys instead.
    """
    gallery_items = []
    for i in sorted(SAMPLES):
        sample = SAMPLES[i]
        # Only add items whose files were actually found (path is not None).
        for key, caption in (
            ("input", "Input"),
            ("optical_flow", "Optical Flow"),
            ("yolo", "YOLO"),
            ("vggt", "VGGT"),
        ):
            if sample.get(key):
                gallery_items.append((sample[key], f"Sample {i} - {caption}"))
    return gallery_items
196
+
197
# Custom CSS for better styling
# Injected into gr.Blocks(css=...) below. The .gradio-container max-width
# override is intentionally left commented out.
custom_css = """
# .gradio-container {
#     max-width: 1200px !important;
# }
.gallery-item {
    border-radius: 10px;
}
h1 {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 30px;
}
.tab-nav {
    margin-bottom: 20px;
}
.qa-section-header {
    font-size: 1.2em;
    color: #2c3e50;
    margin-top: 20px;
}
.qa-metadata {
    background-color: #f8f9fa;
    padding: 15px;
    border-radius: 8px;
    border-left: 4px solid #007bff;
}
.qa-info {
    background-color: #e7f3ff;
    padding: 10px;
    border-radius: 5px;
    font-style: italic;
}
"""
231
+
232
# Create the Gradio interface
# The whole UI is built eagerly at import time: one tab per sample, each tab
# laid out as (input video | Q&A chat) on top and (VGGT | reserved column,
# then optical flow | YOLO) below. Missing assets render as error markdown.
with gr.Blocks(css=custom_css, title="Anomalous Event Detection") as demo:
    gr.Markdown("# ๐ŸŽฅ Results Gallery")

    with gr.Tabs() as tabs:
        # Individual sample tabs
        # SAMPLES is keyed 2..8 but tabs are labelled 1..7, hence the i-1.
        for i in range(2, 9):
            with gr.Tab(f"๐ŸŽฌ Sample {i-1}"):
                gr.Markdown(f"## Sample {i-1} - Detailed View")

                sample = SAMPLES[i]
                # Top Row: Input Video + Chat History
                with gr.Row():
                    # Left Column: Input Video (narrower)
                    with gr.Column(scale=1):
                        gr.Markdown("### ๐Ÿ“น Input Video")
                        if sample["input"]:
                            input_video = gr.Video(
                                value=sample["input"],
                                label="Original Input",
                                show_label=True
                            )
                        else:
                            gr.Markdown("โŒ Input video not found")

                    # Right Column: Q&A Chat History
                    with gr.Column(scale=1, min_width=400):
                        gr.Markdown("### ๐Ÿ’ฌ Q&A Chat History")

                        if sample["qa"]:
                            # Load QA metadata
                            qa_metadata = get_qa_metadata(sample["qa"])
                            if qa_metadata:
                                gr.Markdown(f"""
                                **๐Ÿ“Š Chat Session Info:**
                                - **Video:** {qa_metadata.get('video_name', 'N/A')}
                                - **Total Questions:** {qa_metadata.get('total_questions', 0)}
                                - **Max Frames:** {qa_metadata.get('max_num_frames', 0)}
                                - **Timestamp:** {qa_metadata.get('timestamp', 'N/A')[:19].replace('T', ' ')}
                                """)

                            # Load and display chat history
                            qa_history = load_qa_data(sample["qa"])
                            if qa_history:
                                # NOTE(review): the first avatar string is
                                # mojibake/replacement characters in the
                                # committed file — restore the intended
                                # user-avatar emoji. Also confirm the
                                # list-of-pairs Chatbot format matches the
                                # installed Gradio version.
                                chatbot = gr.Chatbot(
                                    value=qa_history,
                                    label="Video Analysis Q&A",
                                    show_label=True,
                                    height=500,
                                    avatar_images=["๏ฟฝ๏ฟฝ๏ฟฝ", "๐Ÿค–"]
                                )

                                gr.Markdown("""
                                ๐Ÿ’ก **About this Q&A:** Questions asked by humans about the video content and answers from an AI model trained for video analysis.
                                """)
                            else:
                                gr.Markdown("โŒ No Q&A data available for this sample")
                        else:
                            gr.Markdown("โŒ Q&A file not found for this sample")
                # VGGT and Point Cloud in a row
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### ๐ŸŽฎ VGGT")

                        if sample["vggt"]:
                            vggt_video = gr.Video(
                                value=sample["vggt"],
                                label="VGGT Processing",
                                show_label=True
                            )
                        else:
                            gr.Markdown("โŒ VGGT video not found")

                    with gr.Column():
                        # Placeholder column: the 3D point-cloud panel below is
                        # disabled (its "pcd" asset is also disabled in SAMPLES).
                        pass
                        # gr.Markdown("### โ˜๏ธ 3D Point Cloud")

                        # if sample["pcd"]:
                        #     try:
                        #         pcd_plot = gr.Plot(
                        #             value=load_point_cloud_plotly(sample["pcd"]),
                        #             label="Interactive 3D Point Cloud",
                        #             show_label=True
                        #         )
                        #     except Exception as e:
                        #         gr.Markdown(f"โŒ Error loading point cloud: {str(e)}")
                        # else:
                        #     gr.Markdown("โŒ Point cloud file not found")

                # Bottom Section: Other Analysis Results
                with gr.Row():
                    with gr.Column(scale=2):
                        # Optical Flow and YOLO in a row
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("### ๐ŸŒŠ Optical Flow")
                                if sample["optical_flow"]:
                                    optical_flow_video = gr.Video(
                                        value=sample["optical_flow"],
                                        label="Motion Analysis",
                                        show_label=True
                                    )
                                else:
                                    gr.Markdown("โŒ Optical flow video not found")

                            with gr.Column():
                                gr.Markdown("### ๐ŸŽฏ YOLO Detection")
                                if sample["yolo"]:
                                    yolo_video = gr.Video(
                                        value=sample["yolo"],
                                        label="Object Detection",
                                        show_label=True
                                    )
                                else:
                                    gr.Markdown("โŒ YOLO video not found")
350
+
351
+ # Comparison tab
352
+ # with gr.Tab("๐Ÿ” Compare"):
353
+ # gr.Markdown("## Compare Different Samples")
354
+ # gr.Markdown("Select two samples to compare side by side")
355
+
356
+ # with gr.Row():
357
+ # sample1_dropdown = gr.Dropdown(
358
+ # choices=list(range(1, 6)),
359
+ # value=1,
360
+ # label="Sample 1"
361
+ # )
362
+ # sample2_dropdown = gr.Dropdown(
363
+ # choices=list(range(1, 6)),
364
+ # value=2,
365
+ # label="Sample 2"
366
+ # )
367
+
368
+ # with gr.Row():
369
+ # with gr.Column():
370
+ # gr.Markdown("### Sample 1")
371
+ # comp_input1 = gr.Video(label="Input")
372
+ # comp_optical1 = gr.Video(label="Optical Flow")
373
+ # comp_yolo1 = gr.Video(label="YOLO")
374
+ # comp_vggt1 = gr.Video(label="VGGT")
375
+ # comp_pcd1 = gr.Plot(label="Point Cloud")
376
+
377
+ # with gr.Column():
378
+ # gr.Markdown("### Sample 2")
379
+ # comp_input2 = gr.Video(label="Input")
380
+ # comp_optical2 = gr.Video(label="Optical Flow")
381
+ # comp_yolo2 = gr.Video(label="YOLO")
382
+ # comp_vggt2 = gr.Video(label="VGGT")
383
+ # comp_pcd2 = gr.Plot(label="Point Cloud")
384
+
385
+ # # Update comparison when dropdowns change
386
+ # def update_comparison(sample1_id, sample2_id):
387
+ # try:
388
+ # sample1_results = create_sample_gallery(sample1_id)
389
+ # sample2_results = create_sample_gallery(sample2_id)
390
+ # return sample1_results + sample2_results
391
+ # except Exception as e:
392
+ # print(f"Error updating comparison: {e}")
393
+ # return [None] * 10
394
+
395
+ # for dropdown in [sample1_dropdown, sample2_dropdown]:
396
+ # dropdown.change(
397
+ # update_comparison,
398
+ # inputs=[sample1_dropdown, sample2_dropdown],
399
+ # outputs=[
400
+ # comp_input1, comp_optical1, comp_yolo1, comp_vggt1, comp_pcd1,
401
+ # comp_input2, comp_optical2, comp_yolo2, comp_vggt2, comp_pcd2
402
+ # ]
403
+ # )
404
+
405
+
406
if __name__ == "__main__":
    # Print file status for debugging so missing assets are obvious at startup.
    print("=== File Status Check ===")
    # Iterate the actual SAMPLES keys instead of re-hardcoding range(2, 9),
    # so this stays correct if the sample range changes.
    for i in sorted(SAMPLES):
        print(f"\nSample {i}:")
        for key, path in SAMPLES[i].items():
            status = "โœ… Found" if path else "โŒ Missing"
            print(f"  {key}: {status}")

    # Fixed: the original used an f-string with no placeholders.
    print("\n=== Starting Gradio App ===")
    demo.launch(
        share=True,               # also create a public gradio.live link
        server_name="127.0.0.1",  # bind locally only
        server_port=7861,
        show_error=True,
        inbrowser=True
    )