File size: 5,228 Bytes
3038d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import numpy as np
import os
import glob
import pickle
import json
from utils.render import render_smpl
from periodic_detection_function import run_periodic_detection

# Input directory: candidate ``*.pkl`` trajectories, optionally accompanied
# by pre-rendered ``*.mp4`` videos sharing the same stem.
DATA_DIR = "data"
# Output directory: videos rendered or produced by the detection step.
OUTPUT_DIR = "outputs"

# Make sure the output directory is present before any handler writes to it.
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

def get_candidates(data_dir=None):
    """
    List the candidate trajectory pickles available for selection.

    Args:
        data_dir: Directory to scan. Defaults to ``DATA_DIR``.

    Returns:
        Sorted basenames of the ``*.pkl`` files directly inside the directory
        (an empty list if it is missing). Sorting gives the dropdown a stable
        order, because ``glob`` yields files in filesystem-dependent order.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    # Escape the directory so glob metacharacters in its name match literally.
    pattern = os.path.join(glob.escape(data_dir), "*.pkl")
    return sorted(os.path.basename(path) for path in glob.glob(pattern))

def load_and_render(candidate_file):
    """
    Return a playable video path for the selected candidate pickle.

    A pre-rendered ``<stem>.mp4`` next to the pickle in ``DATA_DIR`` is used
    as-is when present. Otherwise the pickle is loaded, its shape validated,
    and it is rendered with ``render_smpl`` into
    ``OUTPUT_DIR/<stem>_rendered.mp4``.

    Args:
        candidate_file: Basename of a ``.pkl`` file inside ``DATA_DIR`` (as
            offered by the dropdown), or ``None``/``""`` when nothing is
            selected.

    Returns:
        Path to the video file, or ``None`` when nothing is selected or when
        loading/rendering fails. Failures are printed rather than raised so
        the UI keeps running.
    """
    if not candidate_file:
        return None

    # Strip only the trailing extension; str.replace('.pkl', '') would also
    # remove '.pkl' appearing elsewhere in the file name.
    stem = os.path.splitext(candidate_file)[0]
    pkl_path = os.path.join(DATA_DIR, candidate_file)

    # Prefer a pre-rendered video shipped alongside the pickle.
    pre_rendered_path = os.path.join(DATA_DIR, f"{stem}.mp4")
    if os.path.exists(pre_rendered_path):
        print(f"Using pre-rendered video: {pre_rendered_path}")
        return pre_rendered_path

    # Fallback: render the trajectory ourselves.
    output_video_path = os.path.join(OUTPUT_DIR, f"{stem}_rendered.mp4")
    try:
        # NOTE(review): pickle.load can execute arbitrary code; DATA_DIR must
        # only contain trusted files.
        with open(pkl_path, 'rb') as f:
            # asarray makes non-array payloads fail the shape check below with
            # a clear message instead of an AttributeError on `.shape`.
            data = np.asarray(pickle.load(f))

        # Expected layout is (Frames, 24, 3) — presumably 24 SMPL joints x xyz.
        if data.ndim != 3 or data.shape[1:] != (24, 3):
            raise ValueError(f"Unexpected data shape: {data.shape}. Expected (Frames, 24, 3)")

        print(f"Rendering {candidate_file}...")
        render_smpl(data, output_video_path, fps=30)
        return output_video_path

    except Exception as e:
        print(f"Error rendering {candidate_file}: {e}")
        return None

def _json_default(obj):
    """Convert NumPy values for ``json.dumps(default=...)``.

    Arrays become (nested) lists and NumPy scalars become Python scalars;
    any other type is rejected with ``TypeError``, as ``json`` expects.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def run_analysis(candidate_file, rendered_video_path):
    """
    Run periodic detection on the rendered video and its trajectory pickle.

    Args:
        candidate_file: Basename of the selected ``.pkl`` file in ``DATA_DIR``.
        rendered_video_path: Path of the input video currently shown in the
            UI, or ``None`` if nothing has been rendered yet.

    Returns:
        ``(result_video_path, json_text)``. ``json_text`` is always a valid
        JSON document because it is displayed by a ``gr.JSON`` component,
        which parses string values as JSON. On any failure the video is
        ``None`` and the JSON holds an ``"error"`` entry.
    """
    if not candidate_file or not rendered_video_path:
        return None, json.dumps(
            {"error": "Please select a candidate and wait for rendering first."},
            indent=2,
        )

    pkl_path = os.path.join(DATA_DIR, candidate_file)
    stem = os.path.splitext(candidate_file)[0]
    output_video_path = os.path.join(OUTPUT_DIR, f"{stem}_result.mp4")

    try:
        print(f"Running detection on {candidate_file}...")
        # The pickle holds a (Frames, 24, 3) array. run_periodic_detection
        # reshapes trajectories to (Frames, -1), i.e. (Frames, 72), itself,
        # so no reshaping is done here — re-check if that helper changes.
        results = run_periodic_detection(
            video_path=rendered_video_path,
            trajectory_path=pkl_path,
            output_video_path=output_video_path,
            n_clusters=9,
            sampling_rate=1,
            make_video=True
        )

        if "error" in results:
            return None, json.dumps(results, indent=2, default=_json_default)

        # Subset of the detection results shown in the UI. Values may be
        # NumPy arrays/scalars, hence the custom `default` converter.
        display_results = {
            "workflow branches": results.get("workflow"),
            "period_boundaries": results.get("period_boundaries"),
            "num_periods": results.get("num_periods"),
            "window_size": results.get("window_size")
        }

        return results.get("output_video"), json.dumps(
            display_results, indent=2, default=_json_default
        )

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, json.dumps({"error": f"Error during analysis: {str(e)}"}, indent=2)

def reset_all():
    """Clear every component wired to the Reset button.

    Returns one ``None`` per output: dropdown, input video, result video
    and numerical results.
    """
    cleared = (None,) * 4
    return cleared

# Gradio interface. Left column: pick a candidate pickle and view its input
# video. Right column: run periodic detection, then view the numerical
# results and the detection visualization.
with gr.Blocks(title="Periodic Workflow Detection Demo") as demo:
    gr.Markdown("# Periodic Workflow Detection Demo")
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Input")
            # Choices are collected once, when the UI is built; pickles added
            # to DATA_DIR later only appear after a restart.
            candidate_dropdown = gr.Dropdown(
                choices=get_candidates(),
                label="Select Candidates",
                value=None
            )
            
            gr.Markdown("### Input Visualization")
            # Filled by load_and_render (pre-rendered or freshly rendered video).
            input_video = gr.Video(label="Spatiotemporal Sequence", interactive=False)
            
        with gr.Column(scale=1):
            gr.Markdown("### 2. Run Detection")
            run_btn = gr.Button("Run Analysis", variant="primary")
            
            gr.Markdown("### Results")
            # Receives the second value returned by run_analysis.
            text_output = gr.JSON(label="Numerical Results")
            # Receives the result video path returned by run_analysis.
            result_video = gr.Video(label="Detection Visualization", interactive=False)
            
            reset_btn = gr.Button("Reset", variant="secondary")

    # Interactions
    # Selecting a candidate loads (or renders) its input video.
    candidate_dropdown.change(
        fn=load_and_render,
        inputs=[candidate_dropdown],
        outputs=[input_video]
    )
    
    # The current input video is passed along so run_analysis can refuse to
    # run until a candidate is selected and its video is available.
    run_btn.click(
        fn=run_analysis,
        inputs=[candidate_dropdown, input_video],
        outputs=[result_video, text_output]
    )
    
    # Clear all four components. NOTE(review): setting the dropdown to None
    # likely re-fires its .change handler (load_and_render(None) -> None),
    # which is harmless but redundant — confirm against the Gradio version.
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[candidate_dropdown, input_video, result_video, text_output]
    )

if __name__ == "__main__":
    # Launch the demo with Gradio's server-side rendering explicitly disabled.
    demo.launch(ssr_mode=False)