fanduluhf commited on
Commit
3038d19
·
verified ·
1 Parent(s): f460dc5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -149
app.py CHANGED
@@ -1,149 +1,150 @@
1
- import gradio as gr
2
- import numpy as np
3
- import os
4
- import glob
5
- import pickle
6
- import json
7
- from utils.render import render_smpl
8
- from periodic_detection_function import run_periodic_detection
9
-
10
- DATA_DIR = "data"
11
- OUTPUT_DIR = "outputs"
12
-
13
- os.makedirs(OUTPUT_DIR, exist_ok=True)
14
-
15
- def get_candidates():
16
- """List all pickle files in data directory."""
17
- files = glob.glob(os.path.join(DATA_DIR, "*.pkl"))
18
- return [os.path.basename(f) for f in files]
19
-
20
- def load_and_render(candidate_file):
21
- """
22
- Load the selected pickle file, render it to a video, and return the video path.
23
- """
24
- if not candidate_file:
25
- return None
26
-
27
- pkl_path = os.path.join(DATA_DIR, candidate_file)
28
- output_video_path = os.path.join(OUTPUT_DIR, f"{candidate_file.replace('.pkl', '')}_rendered.mp4")
29
-
30
- # Check for pre-rendered video in data/
31
- pre_rendered_path = os.path.join(DATA_DIR, candidate_file.replace('.pkl', '.mp4'))
32
- if os.path.exists(pre_rendered_path):
33
- print(f"Using pre-rendered video: {pre_rendered_path}")
34
- return pre_rendered_path
35
-
36
- # If not found, fall back to rendering (or re-render if desired, but user wants direct use)
37
- # Keeping fallback just in case
38
- try:
39
- with open(pkl_path, 'rb') as f:
40
- data = pickle.load(f)
41
-
42
- # Data shape check
43
- if len(data.shape) != 3 or data.shape[1] != 24 or data.shape[2] != 3:
44
- raise ValueError(f"Unexpected data shape: {data.shape}. Expected (Frames, 24, 3)")
45
-
46
- print(f"Rendering {candidate_file}...")
47
- render_smpl(data, output_video_path, fps=30)
48
- return output_video_path
49
-
50
- except Exception as e:
51
- print(f"Error rendering {candidate_file}: {e}")
52
- return None
53
-
54
- def run_analysis(candidate_file, rendered_video_path):
55
- """
56
- Run periodic detection on the rendered video and trajectory data.
57
- """
58
- if not candidate_file or not rendered_video_path:
59
- return None, "Please select a candidate and wait for rendering first."
60
-
61
- pkl_path = os.path.join(DATA_DIR, candidate_file)
62
- output_video_path = os.path.join(OUTPUT_DIR, f"{candidate_file.replace('.pkl', '')}_result.mp4")
63
-
64
- try:
65
- print(f"Running detection on {candidate_file}...")
66
- # Note: run_periodic_detection expects [Frames, N_feats] usually or generic trajectory.
67
- # The pickle contains (Frames, 24, 3).
68
- # The spatiotemporal_clustering in helper seems to handle reshaping or expects specific shape.
69
- # Looking at periodic_detection_function.py line 46:
70
- # trajectories = trajectories.reshape(trajectories.shape[0],-1)
71
- # So it flattens (Frames, 24, 3) to (Frames, 72), which is fine.
72
-
73
- results = run_periodic_detection(
74
- video_path=rendered_video_path,
75
- trajectory_path=pkl_path,
76
- output_video_path=output_video_path,
77
- n_clusters=9,
78
- sampling_rate=1,
79
- make_video=True
80
- )
81
-
82
- if "error" in results:
83
- return None, json.dumps(results, indent=2)
84
-
85
- # Format results for display
86
- display_results = {
87
- "workflow branches": results.get("workflow"),
88
- "period_boundaries": results.get("period_boundaries"),
89
- "num_periods": results.get("num_periods"),
90
- "window_size": results.get("window_size")
91
- }
92
-
93
- return results.get("output_video"), json.dumps(display_results, indent=2)
94
-
95
- except Exception as e:
96
- import traceback
97
- traceback.print_exc()
98
- return None, f"Error during analysis: {str(e)}"
99
-
100
- def reset_all():
101
- return None, None, None, None
102
-
103
- # Gradio Interface
104
- with gr.Blocks(title="Periodic Workflow Detection Demo") as demo:
105
- gr.Markdown("# Periodic Workflow Detection Demo")
106
-
107
- with gr.Row():
108
- with gr.Column(scale=1):
109
- gr.Markdown("### 1. Select Input")
110
- candidate_dropdown = gr.Dropdown(
111
- choices=get_candidates(),
112
- label="Select Candidates",
113
- value=None
114
- )
115
-
116
- gr.Markdown("### Input Visualization")
117
- input_video = gr.Video(label="Spatiotemporal Sequence", interactive=False)
118
-
119
- with gr.Column(scale=1):
120
- gr.Markdown("### 2. Run Detection")
121
- run_btn = gr.Button("Run Analysis", variant="primary")
122
-
123
- gr.Markdown("### Results")
124
- text_output = gr.JSON(label="Numerical Results")
125
- result_video = gr.Video(label="Detection Visualization", interactive=False)
126
-
127
- reset_btn = gr.Button("Reset", variant="secondary")
128
-
129
- # Interactions
130
- candidate_dropdown.change(
131
- fn=load_and_render,
132
- inputs=[candidate_dropdown],
133
- outputs=[input_video]
134
- )
135
-
136
- run_btn.click(
137
- fn=run_analysis,
138
- inputs=[candidate_dropdown, input_video],
139
- outputs=[result_video, text_output]
140
- )
141
-
142
- reset_btn.click(
143
- fn=reset_all,
144
- inputs=[],
145
- outputs=[candidate_dropdown, input_video, result_video, text_output]
146
- )
147
-
148
- if __name__ == "__main__":
149
- demo.launch()
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import os
4
+ import glob
5
+ import pickle
6
+ import json
7
+ from utils.render import render_smpl
8
+ from periodic_detection_function import run_periodic_detection
9
+
10
+ DATA_DIR = "data"
11
+ OUTPUT_DIR = "outputs"
12
+
13
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
14
+
15
+ def get_candidates():
16
+ """List all pickle files in data directory."""
17
+ files = glob.glob(os.path.join(DATA_DIR, "*.pkl"))
18
+ return [os.path.basename(f) for f in files]
19
+
20
+ def load_and_render(candidate_file):
21
+ """
22
+ Load the selected pickle file, render it to a video, and return the video path.
23
+ """
24
+ if not candidate_file:
25
+ return None
26
+
27
+ pkl_path = os.path.join(DATA_DIR, candidate_file)
28
+ output_video_path = os.path.join(OUTPUT_DIR, f"{candidate_file.replace('.pkl', '')}_rendered.mp4")
29
+
30
+ # Check for pre-rendered video in data/
31
+ pre_rendered_path = os.path.join(DATA_DIR, candidate_file.replace('.pkl', '.mp4'))
32
+ if os.path.exists(pre_rendered_path):
33
+ print(f"Using pre-rendered video: {pre_rendered_path}")
34
+ return pre_rendered_path
35
+
36
+ # If not found, fall back to rendering (or re-render if desired, but user wants direct use)
37
+ # Keeping fallback just in case
38
+ try:
39
+ with open(pkl_path, 'rb') as f:
40
+ data = pickle.load(f)
41
+
42
+ # Data shape check
43
+ if len(data.shape) != 3 or data.shape[1] != 24 or data.shape[2] != 3:
44
+ raise ValueError(f"Unexpected data shape: {data.shape}. Expected (Frames, 24, 3)")
45
+
46
+ print(f"Rendering {candidate_file}...")
47
+ render_smpl(data, output_video_path, fps=30)
48
+ return output_video_path
49
+
50
+ except Exception as e:
51
+ print(f"Error rendering {candidate_file}: {e}")
52
+ return None
53
+
54
+ def run_analysis(candidate_file, rendered_video_path):
55
+ """
56
+ Run periodic detection on the rendered video and trajectory data.
57
+ """
58
+ if not candidate_file or not rendered_video_path:
59
+ return None, "Please select a candidate and wait for rendering first."
60
+
61
+ pkl_path = os.path.join(DATA_DIR, candidate_file)
62
+ output_video_path = os.path.join(OUTPUT_DIR, f"{candidate_file.replace('.pkl', '')}_result.mp4")
63
+
64
+ try:
65
+ print(f"Running detection on {candidate_file}...")
66
+ # Note: run_periodic_detection expects [Frames, N_feats] usually or generic trajectory.
67
+ # The pickle contains (Frames, 24, 3).
68
+ # The spatiotemporal_clustering in helper seems to handle reshaping or expects specific shape.
69
+ # Looking at periodic_detection_function.py line 46:
70
+ # trajectories = trajectories.reshape(trajectories.shape[0],-1)
71
+ # So it flattens (Frames, 24, 3) to (Frames, 72), which is fine.
72
+
73
+ results = run_periodic_detection(
74
+ video_path=rendered_video_path,
75
+ trajectory_path=pkl_path,
76
+ output_video_path=output_video_path,
77
+ n_clusters=9,
78
+ sampling_rate=1,
79
+ make_video=True
80
+ )
81
+
82
+ if "error" in results:
83
+ return None, json.dumps(results, indent=2)
84
+
85
+ # Format results for display
86
+ display_results = {
87
+ "workflow branches": results.get("workflow"),
88
+ "period_boundaries": results.get("period_boundaries"),
89
+ "num_periods": results.get("num_periods"),
90
+ "window_size": results.get("window_size")
91
+ }
92
+
93
+ return results.get("output_video"), json.dumps(display_results, indent=2)
94
+
95
+ except Exception as e:
96
+ import traceback
97
+ traceback.print_exc()
98
+ return None, f"Error during analysis: {str(e)}"
99
+
100
+ def reset_all():
101
+ return None, None, None, None
102
+
103
+ # Gradio Interface
104
+ with gr.Blocks(title="Periodic Workflow Detection Demo") as demo:
105
+ gr.Markdown("# Periodic Workflow Detection Demo")
106
+
107
+ with gr.Row():
108
+ with gr.Column(scale=1):
109
+ gr.Markdown("### 1. Select Input")
110
+ candidate_dropdown = gr.Dropdown(
111
+ choices=get_candidates(),
112
+ label="Select Candidates",
113
+ value=None
114
+ )
115
+
116
+ gr.Markdown("### Input Visualization")
117
+ input_video = gr.Video(label="Spatiotemporal Sequence", interactive=False)
118
+
119
+ with gr.Column(scale=1):
120
+ gr.Markdown("### 2. Run Detection")
121
+ run_btn = gr.Button("Run Analysis", variant="primary")
122
+
123
+ gr.Markdown("### Results")
124
+ text_output = gr.JSON(label="Numerical Results")
125
+ result_video = gr.Video(label="Detection Visualization", interactive=False)
126
+
127
+ reset_btn = gr.Button("Reset", variant="secondary")
128
+
129
+ # Interactions
130
+ candidate_dropdown.change(
131
+ fn=load_and_render,
132
+ inputs=[candidate_dropdown],
133
+ outputs=[input_video]
134
+ )
135
+
136
+ run_btn.click(
137
+ fn=run_analysis,
138
+ inputs=[candidate_dropdown, input_video],
139
+ outputs=[result_video, text_output]
140
+ )
141
+
142
+ reset_btn.click(
143
+ fn=reset_all,
144
+ inputs=[],
145
+ outputs=[candidate_dropdown, input_video, result_video, text_output]
146
+ )
147
+
148
+ if __name__ == "__main__":
149
+ #demo.launch()
150
+ demo.launch(ssr_mode=False)