File size: 15,591 Bytes
b2f3ea1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d64119
 
b2f3ea1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d64119
b2f3ea1
 
 
 
 
 
 
0d64119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2f3ea1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d64119
b2f3ea1
 
 
0d64119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2f3ea1
 
 
 
 
 
 
 
 
da50283
b2f3ea1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import numpy as np
import pickle
import json
import string
import cv2
from tqdm import tqdm
import os
from utils.periodic_detection_helper import *
from utils.plot import *
def _load_trajectories(trajectory_path):
    """Load a trajectory array from a ``.pkl`` or ``.json`` file.

    Returns the loaded data as a NumPy array.

    Raises:
        ValueError: if the file extension is neither ``.pkl`` nor ``.json``.
        Any error raised by ``open``, ``pickle.load`` or ``json.load``.
    """
    file_ext = os.path.splitext(trajectory_path)[1].lower()
    if file_ext == '.pkl':
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(trajectory_path, 'rb') as f:
            return np.asarray(pickle.load(f))
    if file_ext == '.json':
        with open(trajectory_path, 'r') as f:
            return np.array(json.load(f))
    raise ValueError(f"Unsupported trajectory file format: {file_ext}. Use .pkl or .json")


def _fused_periods(sequence, win, buffer_ratio=0.2):
    """Cut ``sequence`` into consecutive windows of length ``win`` and fuse each one.

    Each window is widened on both sides by ``int(win * buffer_ratio)`` tokens
    (clipped to the sequence bounds) before being passed to ``fuse_adjacent``.
    Only ``len(sequence) // win`` full windows are produced, so the result is
    empty when the sequence is shorter than one window.
    """
    num_frames = len(sequence)
    pad = int(win * buffer_ratio)
    return [
        fuse_adjacent(sequence[max(0, win * k - pad):min(num_frames, win * (k + 1) + pad)])
        for k in range(num_frames // win)
    ]


def _trim_alignment(aligned_sequences):
    """Trim trailing gap columns and repeated trailing tokens from an MSA result."""
    aligned = aligned_sequences
    # Repeatedly cut at the index reported by find_dash_end_index while any
    # sequence still ends in a gap '-'. Stop when the helper reports 0 or the cut
    # makes no progress; otherwise the loop would never terminate.
    while all(len(s) > 0 for s in aligned) and any(s[-1] == '-' for s in aligned):
        cut = find_dash_end_index(aligned)
        if cut == 0:
            break
        trimmed = [s[:cut] for s in aligned]
        if all(len(t) == len(s) for t, s in zip(trimmed, aligned)):
            break
        aligned = trimmed

    cut = find_longest_repeated_ends(aligned)
    if cut != 0:
        aligned = [s[:-cut] for s in aligned]
    return aligned


def _build_workflow_paths(aligned_sequences, workflow_str):
    """Collect, per workflow step, the token each aligned sequence places there.

    Returns a NumPy array with one string per aligned sequence, or ``[]`` when the
    per-step columns have different lengths and cannot be stacked.
    """
    n_steps = len(workflow_str)
    workflow = [[] for _ in range(n_steps)]
    underscore_pos = workflow_str.find('_')
    for seq in aligned_sequences:
        anchor_len = seq.find('-')
        if anchor_len == -1:
            anchor_len = n_steps // 2
        # NOTE(review): when workflow_str contains no '_', find() returns -1 and this
        # min/max pair collapses the anchor to a single character -- confirm intended.
        anchor_len = max(min(anchor_len, underscore_pos), 1)
        anchor = workflow_str[:anchor_len]

        pointer = 0
        started = False
        for i, token in enumerate(seq):
            if pointer == n_steps:
                break
            # Start recording once the sequence matches the workflow's opening anchor.
            started = started or seq[i:i + anchor_len] == anchor
            if started:
                workflow[pointer].append(token.replace('-', '_') + f'{pointer:02}')
                pointer += 1

    try:
        columns = np.stack(workflow).T
        return np.stack([''.join(entry[0] for entry in column) for column in columns])
    except (ValueError, IndexError):
        # Ragged or empty per-step lists cannot be stacked into paths.
        return []


def _segment_sequence(sequence, workflow_str, win):
    """Split ``sequence`` into periods by walking it against ``workflow_str``.

    Returns ``{segment_index: {frame_index: token}}``. Frame indices are positions
    in ``sequence``. A new segment starts at ``workflow_str[0]`` once the pointer
    sits on a token equal to the workflow's last token. The current segment must
    also already be longer than ``0.5 * win`` (the first segment is exempt).
    """
    last_step = len(workflow_str) - 1
    seg_labels = {}
    section_len = {}  # per segment: tokens consumed at each workflow step
    seg_ind = -1
    pointer = -1
    for frame_number, token in enumerate(sequence):
        if token == workflow_str[0] and workflow_str[pointer] == workflow_str[-1]:
            if seg_ind == -1 or len(seg_labels[seg_ind]) > 0.5 * win:
                pointer = 0
                seg_ind += 1
                seg_labels[seg_ind] = {}
                section_len[seg_ind] = {pointer: 0}
        if pointer == -1:
            continue  # not yet inside the first period
        if pointer < last_step and token == workflow_str[pointer + 1]:
            pointer += 1
            section_len[seg_ind][pointer] = 0
        if pointer < last_step and workflow_str[pointer + 1] == '_':
            pointer += 1  # '_' steps are skipped automatically
            section_len[seg_ind][pointer] = 0
        # At the final step, drop foreign tokens once the step has some content.
        if (pointer == last_step and section_len[seg_ind][pointer] > 1
                and token != workflow_str[pointer]):
            continue
        seg_labels[seg_ind][frame_number] = token
        section_len[seg_ind][pointer] += 1
    return seg_labels


def _compute_period_boundaries(seg_labels, win, num_frames):
    """Return ``(period_num, boundaries)`` for the detected segments.

    ``period_num`` counts segments longer than ``0.5 * win``. When at least one
    such segment exists, ``boundaries`` is a dict ``{i: [start, end]}`` over all
    segments. Each end is moved to one frame before the next segment's start.
    Otherwise it is a list of fixed ``win``-sized windows.
    """
    period_num = sum(1 for frames in seg_labels.values() if len(frames) > 0.5 * win)
    if period_num == 0:
        period_num = num_frames // win
        return period_num, [[int((i - 1) * win), int(i * win)] for i in range(1, period_num + 1)]

    boundaries = {}
    for p_id, frames in enumerate(seg_labels.values()):
        # Python ints keep the result JSON-serializable.
        start = int(min(frames))
        boundaries[p_id] = [start, int(max(frames))]
        if p_id > 0:
            boundaries[p_id - 1][1] = start - 1
    return period_num, boundaries


def _compose_frame(video_frame, progress_bar, anchor):
    """Place ``video_frame`` next to the progress bar and stack it above ``anchor``.

    The progress bar is a float image scaled by 255 and channel-flipped before
    concatenation; the result is uint8.
    """
    bar = (progress_bar * 255).astype(np.uint8)[:, :, ::-1]
    return np.concatenate([np.concatenate([video_frame, bar], axis=1), anchor], axis=0)


def _render_video(cap, output_video_path, sequence, cluster_labels, seg_labels,
                  period_boundaries, win, n_clusters):
    """Write the per-period visualization video.

    ``cap`` must be an opened ``cv2.VideoCapture``; the caller releases it.
    Returns ``None`` on success, or an error message if no video writer could be
    opened.

    NOTE(review): frame indices come from the subsampled sequence
    (``trajectories[::sampling_rate]``). If trajectory rows are one per video frame,
    seeking is exact only for ``sampling_rate == 1`` -- TODO confirm.
    """
    # Token legend: the first frame assigned to each cluster token.
    images, tokens = [], []
    labels_arr = np.asarray(cluster_labels)
    for token in np.unique(list(sequence)):
        if token == '_':
            continue
        first_frame = np.flatnonzero(labels_arr == alpha_to_number(token))[0]
        cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame)
        ok, frame = cap.read()
        if not ok:
            print(f"Could not read legend frame for token {token}")
            continue  # keep tokens and images aligned
        tokens.append(token)
        images.append(frame[:, :, ::-1])
    if images:
        plot_images_with_token(images, ''.join(tokens))

    W = H = 640
    row_height = 80
    label_width = 300  # left margin of the progress bar reserved for row titles
    video_sampling_rate = 10  # write every 10th frame of each period
    fps = 30

    # At least 15 hues (the original palette size), more if there are more clusters.
    palette = string.ascii_lowercase[:max(15, n_clusters)]
    hues = np.linspace(0, 1, len(palette), endpoint=False)
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9) for char, hue in zip(palette, hues)}

    max_period_len = max((len(v) for v in seg_labels.values()), default=win)
    prog_bar_w = int(max_period_len // video_sampling_rate) + label_width + 50  # 50 px buffer
    progress_bar = np.ones((H, prog_bar_w, 3), dtype=np.float32)

    # Optional banner image; fall back to a white strip when missing or unreadable.
    anchor = cv2.imread("anchors.jpg") if os.path.exists("anchors.jpg") else None
    if anchor is None:
        anchor = np.full((380, W + prog_bar_w, 3), 255, dtype=np.uint8)
    else:
        anchor = cv2.resize(anchor, (W + prog_bar_w, 380))

    # Try codecs in order of preference until a writer opens.
    frame_size = (anchor.shape[1], H + anchor.shape[0])
    fallback_labels = {'mp4v': 'mp4v (less compatible)'}
    previous = None
    for code in ('avc1', 'h264', 'vp80', 'mp4v'):
        if previous is not None:
            print(f"{previous} failed. Trying {fallback_labels.get(code, code)}...")
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*code), fps, frame_size)
        if out.isOpened():
            break
        out.release()
        previous = code
    else:
        print("Error: Could not open video writer with any compatible codec.")
        return "Failed to open video writer"

    last_video_frame = None
    try:
        for idx, seg_key in enumerate(tqdm(list(seg_labels))):
            segment = seg_labels[seg_key]
            if not segment:
                continue

            # period_boundaries is a dict (detected) or list (fallback); both index by position.
            start_text = end_text = "????"
            try:
                boundary = period_boundaries[idx]
                start_text = f"{boundary[0]:04d}"
                end_text = f"{boundary[1]:04d}"
            except (KeyError, IndexError, TypeError, ValueError):
                pass

            row_top = row_height * seg_key
            cv2.putText(progress_bar, f'Period {seg_key+1}', (5, row_top + 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

            labels = list(segment.values())
            frame_ids = list(segment)
            sampled = zip(labels[::video_sampling_rate], frame_ids[::video_sampling_rate])
            for col, (label, frame_id) in enumerate(sampled):
                try:
                    progress_bar[row_top:row_top + row_height, label_width + col, :] = color_map[label.lower()]

                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                    ok, frame = cap.read()
                    if not ok:
                        continue

                    frame = cv2.resize(frame, (W, H))
                    cv2.putText(frame, f"Frame: {frame_id}", (50, 50),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
                    last_video_frame = frame
                    out.write(_compose_frame(frame, progress_bar, anchor))
                except Exception as e:
                    # Best effort: report and keep rendering the remaining frames.
                    print(f"Error in video generation: {str(e)}")

            cv2.putText(progress_bar, f'Frame: {start_text}-{end_text}', (5, row_top + 52),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        # Hold the last frame for 3 s so the final boundaries text is visible.
        if last_video_frame is not None:
            try:
                final_frame = _compose_frame(last_video_frame, progress_bar, anchor)
                for _ in range(3 * fps):
                    out.write(final_frame)
            except Exception as e:
                print(f"Error creating freeze frame: {e}")
    finally:
        out.release()
    return None


def run_periodic_detection(video_path, trajectory_path, output_video_path=None, n_clusters=9, sampling_rate=1, make_video=True):
    """
    Run periodic detection on a video and its associated trajectories.

    Parameters:
    - video_path: Path to the video file
    - trajectory_path: Path to the trajectory file (.pkl or .json). Pickle files are
      unpickled directly, so only pass trusted files.
    - output_video_path: Path where the output video will be saved (default: same path
      as the input with a _periodic.mp4 suffix)
    - n_clusters: Number of clusters for spatiotemporal clustering (default: 9)
    - sampling_rate: Keep every sampling_rate-th trajectory row; must be >= 1 (default: 1)
    - make_video: Whether to create a visualization video (default: True)

    Returns:
    - On success, a dict with "workflow" (multi-path workflow strings, or [] when the
      aligned paths are ragged), "period_boundaries", "window_size", "num_periods" and
      "output_video". The boundaries are a dict of [start, end] per detected period,
      or a list of fixed windows if none was detected, and are indices into the
      subsampled sequence. "error_video" is added if no video writer could be opened.
    - {"error": message} when the input is invalid or analysis fails early.
    - If the video cannot be opened, a reduced dict with "workflow" (the summary
      string), "period_boundaries" and "error_video".
    """
    if output_video_path is None:
        output_video_path = f"{os.path.splitext(video_path)[0]}_periodic.mp4"

    if sampling_rate < 1:
        return {"error": f"sampling_rate must be a positive integer, got {sampling_rate}"}

    try:
        trajectories = _load_trajectories(trajectory_path)
    except Exception as e:
        return {"error": f"Failed to load trajectories: {str(e)}"}

    # Flatten each time step to one feature row, then subsample in time.
    trajectories = trajectories.reshape(trajectories.shape[0], -1)[::sampling_rate, :]
    cluster_labels, _hard_token, soft_token, _centroids = spatiotemporal_clustering(trajectories, n_clusters)
    sequence = number_to_alpha(cluster_labels)
    num_frames = len(sequence)

    window_sizes, _magnitudes = dominant_fourier_frequency_2d(
        soft_token, lbound=10, ubound=max(len(soft_token.T), len(soft_token)) // 2)
    if len(window_sizes) == 0:
        return {"error": "No dominant frequencies found"}

    # Pick, among the top-10 candidate windows, the one whose periods are most similar.
    scores = [calculate_similarity_score(_fused_periods(sequence, w)) for w in window_sizes[:10]]
    if not scores:
        return {"error": "Failed to calculate similarity scores"}

    win = int(window_sizes[np.argmax(scores)])
    print('selected_win:{}'.format(win))

    compressed_periods = _fused_periods(sequence, win)
    if not compressed_periods:
        return {"error": "No complete period found for the selected window size"}
    # Align only the first three periods.
    aligned_sequences = _trim_alignment(msa(compressed_periods[:3]))

    workflow_str = summarize_strings(aligned_sequences)
    if not workflow_str:
        return {"error": "Empty workflow string after summary"}
    workflow_str = workflow_str.strip('_')
    if not workflow_str:
        return {"error": "Empty workflow string"}

    workflow_multi_paths = _build_workflow_paths(aligned_sequences, workflow_str)
    seg_labels = _segment_sequence(sequence, workflow_str, win)

    ### Task 1
    period_num, period_boundaries = _compute_period_boundaries(seg_labels, win, num_frames)

    print(f'Workflow: {workflow_str}')
    if isinstance(period_boundaries, dict):
        numbered = period_boundaries.items()
    else:
        numbered = enumerate(period_boundaries)
    for i, boundary in numbered:
        print(f"Period {i+1}: with boundaries of {boundary} ")

    video_error = None
    if make_video and os.path.exists(video_path):
        print("Generating Video...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error opening video file")
            cap.release()
            return {
                "workflow": workflow_str,
                "period_boundaries": period_boundaries,
                "error_video": "Failed to open video file"
            }
        try:
            video_error = _render_video(cap, output_video_path, sequence, cluster_labels,
                                        seg_labels, period_boundaries, win, n_clusters)
        finally:
            cap.release()

    result = {
        "workflow": workflow_multi_paths.tolist() if isinstance(workflow_multi_paths, np.ndarray) else workflow_multi_paths,
        "period_boundaries": period_boundaries,
        "window_size": int(win),
        # NOTE(review): reports period_num + 1; on the fallback path this is one more
        # than len(period_boundaries) -- confirm the +1 is intended.
        "num_periods": int(period_num + 1),
        "output_video": output_video_path if make_video else None
    }
    if video_error:
        result["error_video"] = video_error
    return result