# LSPW / app.py
# fanduluhf's picture
# Update app.py
# 3038d19 verified
import gradio as gr
import numpy as np
import os
import glob
import pickle
import json
from utils.render import render_smpl
from periodic_detection_function import run_periodic_detection
# Directory scanned for candidate "*.pkl" files and pre-rendered "*.mp4" videos.
DATA_DIR = "data"
# Directory that receives rendered and analysis result videos.
OUTPUT_DIR = "outputs"
# Create the output directory at import time; exist_ok avoids an error when it
# is already present.
os.makedirs(OUTPUT_DIR, exist_ok=True)
def get_candidates():
    """Return the base names of all pickle files in ``DATA_DIR``, sorted.

    ``glob.glob`` yields matches in arbitrary, filesystem-dependent order, so
    the names are sorted to keep the dropdown choices stable across runs.

    Returns:
        list[str]: File names (directory part removed) matching ``*.pkl``;
            an empty list when nothing matches.
    """
    pattern = os.path.join(DATA_DIR, "*.pkl")
    return sorted(os.path.basename(path) for path in glob.glob(pattern))
def load_and_render(candidate_file):
    """Return the path of a video visualising the selected pickle file.

    A pre-rendered ``<stem>.mp4`` located next to the pickle in ``DATA_DIR``
    is returned directly when it exists. Otherwise the pickle is loaded,
    its shape is validated, and it is rendered with ``render_smpl`` into
    ``OUTPUT_DIR``.

    Args:
        candidate_file: Base name of a pickle file inside ``DATA_DIR``. A
            falsy value (``None`` or ``""``) means nothing is selected.

    Returns:
        The path of the video to display, or ``None`` when nothing is
        selected or when loading/rendering fails (the error is printed).
    """
    if not candidate_file:
        return None

    # Strip only the extension. ``str.replace('.pkl', '')`` would also remove
    # a ".pkl" that happens to occur in the middle of the file name.
    stem = os.path.splitext(candidate_file)[0]
    pkl_path = os.path.join(DATA_DIR, candidate_file)
    output_video_path = os.path.join(OUTPUT_DIR, f"{stem}_rendered.mp4")

    # Prefer a pre-rendered video shipped alongside the pickle in data/.
    pre_rendered_path = os.path.join(DATA_DIR, f"{stem}.mp4")
    if os.path.exists(pre_rendered_path):
        print(f"Using pre-rendered video: {pre_rendered_path}")
        return pre_rendered_path

    # Fallback: render the trajectory ourselves.
    try:
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only files
        # that are trusted (shipped in DATA_DIR) should ever be loaded here.
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)

        # The renderer is fed data of shape (Frames, 24, 3); reject anything
        # else with a clear message instead of failing inside the renderer.
        shape = getattr(data, "shape", None)
        if shape is None:
            raise ValueError(
                f"Expected an array with a shape, got {type(data).__name__}"
            )
        if len(shape) != 3 or shape[1] != 24 or shape[2] != 3:
            raise ValueError(f"Unexpected data shape: {shape}. Expected (Frames, 24, 3)")

        print(f"Rendering {candidate_file}...")
        render_smpl(data, output_video_path, fps=30)
        return output_video_path
    except Exception as e:
        # UI callback boundary: report the problem and show no video rather
        # than propagating the exception.
        print(f"Error rendering {candidate_file}: {e}")
        return None
def _json_default(obj):
    """Convert array-like objects (anything with ``tolist``) for ``json.dumps``."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(
        f"Object of type {type(obj).__name__} is not JSON serializable"
    )


def _error_payload(message):
    """Wrap *message* in a JSON object string of the form ``{"error": ...}``."""
    return json.dumps({"error": message}, indent=2)


def run_analysis(candidate_file, rendered_video_path):
    """Run periodic detection on the rendered video and trajectory data.

    Args:
        candidate_file: Base name of the selected pickle file in ``DATA_DIR``.
        rendered_video_path: Path of the input video produced for that file.

    Returns:
        A ``(video_path, json_text)`` tuple. ``video_path`` is the
        ``"output_video"`` entry of the detection results, or ``None`` on
        failure. ``json_text`` is always a valid JSON string: either the
        selected result fields or an object with an ``"error"`` key. It must
        be valid JSON because it is routed to a ``gr.JSON`` component, which
        parses string values as JSON.
    """
    if not candidate_file or not rendered_video_path:
        return None, _error_payload(
            "Please select a candidate and wait for rendering first."
        )

    # Strip only the extension (``str.replace`` would also remove a ".pkl"
    # occurring in the middle of the name).
    stem = os.path.splitext(candidate_file)[0]
    pkl_path = os.path.join(DATA_DIR, candidate_file)
    output_video_path = os.path.join(OUTPUT_DIR, f"{stem}_result.mp4")

    try:
        print(f"Running detection on {candidate_file}...")
        # The pickle holds (Frames, 24, 3) data. Per the original author's
        # note, run_periodic_detection reshapes trajectories to
        # (Frames, -1) itself, so no flattening is done here.
        # NOTE(review): that helper is not visible in this file - confirm.
        results = run_periodic_detection(
            video_path=rendered_video_path,
            trajectory_path=pkl_path,
            output_video_path=output_video_path,
            n_clusters=9,
            sampling_rate=1,
            make_video=True
        )

        if "error" in results:
            return None, json.dumps(results, indent=2, default=_json_default)

        # Keep only the fields meant for display.
        display_results = {
            "workflow branches": results.get("workflow"),
            "period_boundaries": results.get("period_boundaries"),
            "num_periods": results.get("num_periods"),
            "window_size": results.get("window_size")
        }
        # ``default`` lets array-like values serialize instead of raising
        # TypeError and discarding an otherwise successful analysis.
        return results.get("output_video"), json.dumps(
            display_results, indent=2, default=_json_default
        )
    except Exception as e:
        # UI callback boundary: log the traceback and surface the message.
        import traceback
        traceback.print_exc()
        return None, _error_payload(f"Error during analysis: {str(e)}")
def reset_all():
    """Return a 4-tuple of ``None`` values used to clear four UI outputs."""
    cleared_outputs = 4
    return (None,) * cleared_outputs
# Gradio Interface
# NOTE(review): the original indentation of this section was lost; the nesting
# below (two columns inside one row, Reset button and event wiring at Blocks
# level) is a reconstruction - confirm against the deployed layout.
with gr.Blocks(title="Periodic Workflow Detection Demo") as demo:
    gr.Markdown("# Periodic Workflow Detection Demo")
    with gr.Row():
        # Left column: candidate selection and the input video.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Input")
            # Choices are computed once, when the UI is built.
            candidate_dropdown = gr.Dropdown(
                choices=get_candidates(),
                label="Select Candidates",
                value=None
            )
            gr.Markdown("### Input Visualization")
            input_video = gr.Video(label="Spatiotemporal Sequence", interactive=False)
        # Right column: analysis trigger and its outputs.
        with gr.Column(scale=1):
            gr.Markdown("### 2. Run Detection")
            run_btn = gr.Button("Run Analysis", variant="primary")
            gr.Markdown("### Results")
            text_output = gr.JSON(label="Numerical Results")
            result_video = gr.Video(label="Detection Visualization", interactive=False)
    reset_btn = gr.Button("Reset", variant="secondary")
    # Interactions
    # Changing the selection passes it to load_and_render; the returned path
    # is shown in input_video.
    candidate_dropdown.change(
        fn=load_and_render,
        inputs=[candidate_dropdown],
        outputs=[input_video]
    )
    # run_analysis receives the selected file name and the current input
    # video, and fills the result video and the JSON panel (in that order).
    run_btn.click(
        fn=run_analysis,
        inputs=[candidate_dropdown, input_video],
        outputs=[result_video, text_output]
    )
    # reset_all returns one value per component listed in outputs.
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[candidate_dropdown, input_video, result_video, text_output]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # Previous default launch, kept by the author for reference:
    #demo.launch()
    # NOTE(review): ssr_mode=False presumably turns off Gradio's server-side
    # rendering - confirm against the Gradio version in use.
    demo.launch(ssr_mode=False)