File size: 5,228 Bytes
3038d19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import numpy as np
import os
import glob
import pickle
import json
from utils.render import render_smpl
from periodic_detection_function import run_periodic_detection
DATA_DIR = "data"  # input .pkl trajectories (and optional pre-rendered .mp4s) live here
OUTPUT_DIR = "outputs"  # rendered and analysis result videos are written here
os.makedirs(OUTPUT_DIR, exist_ok=True)  # ensure output dir exists before any render
def get_candidates():
    """Return the basenames of every pickle file found in DATA_DIR."""
    pattern = os.path.join(DATA_DIR, "*.pkl")
    return [os.path.basename(path) for path in glob.glob(pattern)]
def load_and_render(candidate_file):
    """Load the selected pickle file and return a playable video path.

    Prefers a pre-rendered ``<name>.mp4`` sitting next to the pickle in
    DATA_DIR; otherwise falls back to rendering the trajectory with
    ``render_smpl``.

    Args:
        candidate_file: Basename of a ``.pkl`` file in DATA_DIR, or a
            falsy value when nothing is selected yet.

    Returns:
        Path to the video to display, or None on no selection / failure.
    """
    if not candidate_file:
        return None
    # splitext strips only the trailing extension; str.replace would also
    # mangle a ".pkl" appearing elsewhere in the filename.
    stem = os.path.splitext(candidate_file)[0]
    pkl_path = os.path.join(DATA_DIR, candidate_file)
    output_video_path = os.path.join(OUTPUT_DIR, f"{stem}_rendered.mp4")
    # Check for a pre-rendered video shipped alongside the pickle in data/.
    pre_rendered_path = os.path.join(DATA_DIR, f"{stem}.mp4")
    if os.path.exists(pre_rendered_path):
        print(f"Using pre-rendered video: {pre_rendered_path}")
        return pre_rendered_path
    # Fall back to rendering from the raw trajectory data.
    try:
        with open(pkl_path, 'rb') as f:
            # SECURITY: pickle.load executes arbitrary code on load; only
            # files from the trusted local data/ directory should land here.
            data = pickle.load(f)
        # Expect SMPL joint trajectories shaped (frames, 24 joints, xyz).
        if len(data.shape) != 3 or data.shape[1] != 24 or data.shape[2] != 3:
            raise ValueError(f"Unexpected data shape: {data.shape}. Expected (Frames, 24, 3)")
        print(f"Rendering {candidate_file}...")
        render_smpl(data, output_video_path, fps=30)
        return output_video_path
    except Exception as e:
        # Boundary handler: log the failure and let the UI show an empty
        # video slot instead of crashing the Gradio callback.
        print(f"Error rendering {candidate_file}: {e}")
        return None
def run_analysis(candidate_file, rendered_video_path):
    """Run periodic workflow detection on the selected sequence.

    Args:
        candidate_file: Basename of the ``.pkl`` trajectory file in DATA_DIR.
        rendered_video_path: Path of the already-rendered input video.

    Returns:
        Tuple of (result video path or None, JSON / plain-text status string).
    """
    if not candidate_file or not rendered_video_path:
        return None, "Please select a candidate and wait for rendering first."
    # splitext strips only the trailing extension (str.replace would also hit
    # a ".pkl" occurring mid-name).
    stem = os.path.splitext(candidate_file)[0]
    pkl_path = os.path.join(DATA_DIR, candidate_file)
    output_video_path = os.path.join(OUTPUT_DIR, f"{stem}_result.mp4")
    try:
        print(f"Running detection on {candidate_file}...")
        # run_periodic_detection flattens the (Frames, 24, 3) trajectory to
        # (Frames, 72) internally, so the raw pickle path can be passed as-is.
        results = run_periodic_detection(
            video_path=rendered_video_path,
            trajectory_path=pkl_path,
            output_video_path=output_video_path,
            n_clusters=9,
            sampling_rate=1,
            make_video=True
        )
        if "error" in results:
            return None, json.dumps(results, indent=2)
        # Keep only the fields worth surfacing in the UI.
        display_results = {
            "workflow branches": results.get("workflow"),
            "period_boundaries": results.get("period_boundaries"),
            "num_periods": results.get("num_periods"),
            "window_size": results.get("window_size")
        }
        return results.get("output_video"), json.dumps(display_results, indent=2)
    except Exception as e:
        # Boundary handler: log the full traceback server-side and report the
        # error text through the UI instead of crashing the Gradio callback.
        import traceback
        traceback.print_exc()
        return None, f"Error during analysis: {str(e)}"
def reset_all():
    """Clear all four UI components (dropdown, input video, result video, JSON)."""
    return (None,) * 4
# Gradio Interface: two-column layout — input selection/preview on the left,
# detection trigger and results on the right.
with gr.Blocks(title="Periodic Workflow Detection Demo") as demo:
    gr.Markdown("# Periodic Workflow Detection Demo")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Input")
            # Dropdown choices are snapshotted at app start; new .pkl files
            # added to data/ afterwards require a restart to appear.
            candidate_dropdown = gr.Dropdown(
                choices=get_candidates(),
                label="Select Candidates",
                value=None
            )
            gr.Markdown("### Input Visualization")
            input_video = gr.Video(label="Spatiotemporal Sequence", interactive=False)
        with gr.Column(scale=1):
            gr.Markdown("### 2. Run Detection")
            run_btn = gr.Button("Run Analysis", variant="primary")
            gr.Markdown("### Results")
            text_output = gr.JSON(label="Numerical Results")
            result_video = gr.Video(label="Detection Visualization", interactive=False)
            reset_btn = gr.Button("Reset", variant="secondary")
    # Interactions
    # Selecting a candidate immediately renders (or reuses) its input video.
    candidate_dropdown.change(
        fn=load_and_render,
        inputs=[candidate_dropdown],
        outputs=[input_video]
    )
    # Analysis needs both the selection and the rendered video path; the
    # callback itself guards against either being missing.
    run_btn.click(
        fn=run_analysis,
        inputs=[candidate_dropdown, input_video],
        outputs=[result_video, text_output]
    )
    # Reset clears all four components back to empty.
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[candidate_dropdown, input_video, result_video, text_output]
    )
if __name__ == "__main__":
    # ssr_mode=False disables Gradio's server-side rendering — presumably a
    # workaround for a deployment environment where SSR fails; confirm before
    # removing.
    demo.launch(ssr_mode=False)
|