# LSPW / periodic_detection_function.py
# (Hugging Face file-viewer header preserved as a comment:
#  uploaded by "fanduluhf" — "Update periodic_detection_function.py" — commit da50283, verified)
import numpy as np
import pickle
import json
import string
import cv2
from tqdm import tqdm
import os
from utils.periodic_detection_helper import *
from utils.plot import *
def run_periodic_detection(video_path, trajectory_path, output_video_path=None, n_clusters=9, sampling_rate=1, make_video=True):
    """
    Run periodic detection on a video and its associated trajectories.

    Clusters per-frame trajectory features into tokens, finds the dominant
    period length via Fourier analysis, aligns repeated periods into a
    consensus "workflow" string, and segments the frame sequence into periods.
    Optionally renders a visualization video with a colored progress bar.

    Parameters:
    - video_path: Path to the video file
    - trajectory_path: Path to the trajectory file (pickle or json)
    - output_video_path: Path where the output video will be saved (default: same as input with _periodic suffix)
    - n_clusters: Number of clusters for spatiotemporal clustering (default: 9)
    - sampling_rate: Sampling rate for trajectories (default: 1)
    - make_video: Whether to create a visualization video (default: True)

    Returns:
    - Dictionary containing workflow, period boundaries, window size and
      number of periods, or a dictionary with an "error" key on failure.
    """
    # Setup output video path if not provided
    if output_video_path is None:
        base_name = os.path.splitext(video_path)[0]
        output_video_path = f"{base_name}_periodic.mp4"

    # Load trajectories from either pickle or json
    file_ext = os.path.splitext(trajectory_path)[1].lower()
    try:
        if file_ext == '.pkl':
            # NOTE(security): pickle.load executes arbitrary code on
            # untrusted files — only load trusted trajectory files.
            with open(trajectory_path, 'rb') as f:
                trajectories = pickle.load(f)
        elif file_ext == '.json':
            with open(trajectory_path, 'r') as f:
                trajectories = np.array(json.load(f))
        else:
            raise ValueError(f"Unsupported trajectory file format: {file_ext}. Use .pkl or .json")
    except Exception as e:
        return {"error": f"Failed to load trajectories: {str(e)}"}

    # Flatten each entry to a feature row, then temporally subsample.
    trajectories = trajectories.reshape(trajectories.shape[0], -1)
    trajectories = trajectories[::sampling_rate, :]

    # BUGFIX: the n_clusters parameter was previously ignored (hard-coded 9).
    # The default is 9 so default-argument callers keep the old behavior.
    cluster_labels, hard_token, soft_token, centroids = spatiotemporal_clustering(trajectories, n_clusters)
    sequence = number_to_alpha(cluster_labels)
    num_frames = len(sequence)

    # Candidate period lengths from dominant Fourier frequencies.
    window_sizes, magnitudes = dominant_fourier_frequency_2d(soft_token, lbound=10, ubound=max(len(soft_token.T), len(soft_token))//2)
    if len(window_sizes) == 0:
        return {"error": "No dominant frequencies found"}

    ### Optimize window size: score the top-10 candidates by how similar
    ### the resulting (compressed) periods are to each other.
    scores = []
    for win in window_sizes[:10]:  # select top 10 window sizes
        temporal_buffer = int(win*0.2)  # 20% overlap on each side of a period
        periods = []
        for i in range(num_frames//win):
            clip = sequence[max(0, win*i-temporal_buffer):min(num_frames, win*(i+1)+temporal_buffer)]
            periods.append(clip)
        compressed_periods = [fuse_adjacent(p) for p in periods]
        scores.append(calculate_similarity_score(compressed_periods))
    if not scores:
        return {"error": "Failed to calculate similarity scores"}
    win = window_sizes[np.argmax(scores)]
    print('selected_win:{}'.format(win))

    # Re-chunk the sequence with the chosen window size.
    temporal_buffer = int(win*0.2)
    periods = []
    for i in range(num_frames//win):
        clip = sequence[max(0, win*i-temporal_buffer):min(num_frames, win*(i+1)+temporal_buffer)]
        periods.append(clip)
    compressed_periods = [fuse_adjacent(p) for p in periods]

    # Multiple-sequence-align the first 3 periods, then trim dangling tails.
    aligned_sequences = msa(compressed_periods[:3])
    while '-' in [x[-1] for x in aligned_sequences]:
        i = find_dash_end_index(aligned_sequences)
        if i == 0:
            # BUGFIX: the original spun forever here when no trimmable index
            # was found while a trailing dash remained; bail out instead.
            break
        aligned_sequences = [s[:i] for s in aligned_sequences]
    i = find_longest_repeated_ends(aligned_sequences)
    if i != 0:
        aligned_sequences = [s[:-i] for s in aligned_sequences]

    # Consensus workflow string across the aligned periods.
    workflow_str = summarize_strings(aligned_sequences)
    if not workflow_str:
        return {"error": "Empty workflow string after summary"}
    # Strip leading/trailing wildcard markers.
    workflow_str = workflow_str.strip('_')
    if not workflow_str:
        return {"error": "Empty workflow string"}

    # Collect, per workflow position, the tokens each aligned sequence put there.
    workflow_str_len = len(workflow_str)
    workflow = [[] for _ in range(workflow_str_len)]
    for seq in aligned_sequences:
        pointer = 0
        matched = False
        # Anchor length: how many leading chars must match before we start
        # consuming the sequence into workflow positions.
        pos_skip_sign = seq.find('-')
        if pos_skip_sign == -1:
            pos_skip_sign = workflow_str_len // 2
        # NOTE: find('_') may be -1 (no wildcard); max(..., 1) clamps that.
        pos_skip_sign = min(pos_skip_sign, workflow_str.find('_'))
        pos_skip_sign = max(pos_skip_sign, 1)
        for i in range(len(seq)):
            l = seq[i]
            if pointer == workflow_str_len:
                break
            if seq[i:i+pos_skip_sign] == workflow_str[:pos_skip_sign]:
                matched = True
            if matched:
                workflow[pointer].append(l.replace("-", "_") + '{:02}'.format(pointer))
                pointer += 1

    # Create multi-path workflow: one string per aligned position.
    try:
        workflow_multi_paths = np.stack([''.join([y[0] for y in x]) for x in np.stack(workflow).T])
    except Exception:
        workflow_multi_paths = []

    # Walk the full frame sequence and segment it into periods by following
    # the workflow string with a transcript pointer.
    seg_labels = {}
    seg_ind = -1
    transcript_pointer = -1
    workflow_str_len = len(workflow_str)
    workflow_section_len = {}
    for frame_number, l in enumerate(sequence):
        # Start a new segment when the workflow's first token reappears and
        # the previous segment has reached the workflow's last token.
        # (transcript_pointer == -1 indexes the last char — intentional, it
        # lets the very first segment start.)
        if l == workflow_str[0] and workflow_str[transcript_pointer] == workflow_str[-1]:
            # Only start a new segment if the current one is long enough
            # (approx half a window) or it's the first one.
            if seg_ind == -1 or len(seg_labels[seg_ind]) > 0.5 * win:
                transcript_pointer = 0
                seg_ind += 1
                seg_labels[seg_ind] = {}
                workflow_section_len[seg_ind] = {}
                workflow_section_len[seg_ind][transcript_pointer] = 0
        if transcript_pointer == -1:
            continue  # no segment started yet
        # Advance the pointer when the next expected token appears.
        if transcript_pointer < workflow_str_len - 1:
            if l == workflow_str[transcript_pointer+1]:
                transcript_pointer += 1
                workflow_section_len[seg_ind][transcript_pointer] = 0
        # Wildcard positions are consumed automatically.
        if transcript_pointer < workflow_str_len - 1:
            if workflow_str[transcript_pointer+1] == '_':
                transcript_pointer += 1
                workflow_section_len[seg_ind][transcript_pointer] = 0
        # At the final position, stop absorbing frames that no longer match.
        if transcript_pointer == workflow_str_len - 1 and workflow_section_len[seg_ind][transcript_pointer] > 1 and l != workflow_str[transcript_pointer]:
            continue
        seg_labels[seg_ind][frame_number] = l
        workflow_section_len[seg_ind][transcript_pointer] += 1

    # Keep only segments covering enough of the workflow, then take the
    # per-section median length over segments of the common maximal length.
    workflow_section_len = [v for v in workflow_section_len.values() if len(v) > workflow_str_len * 0.3]
    workflow_section_len_array = [list(d.values()) for d in workflow_section_len]
    if len(workflow_section_len_array) > 0:
        sublist_max_len = max(len(sublist) for sublist in workflow_section_len_array)
        workflow_section_len_array = [sublist for sublist in workflow_section_len_array if len(sublist) == sublist_max_len]
        workflow_section_len = np.median(np.stack(workflow_section_len_array), 0)
    else:
        workflow_section_len = np.zeros(workflow_str_len)

    ### Task 1: period boundaries
    period_num = len([x for x in seg_labels.values() if len(x) > 0.5 * win])
    if period_num > 0:
        period_boundaries = {}
        for p_id, (k, v) in enumerate(seg_labels.items()):
            frame_list = np.sort(list(v.keys()))
            # Convert to python int for JSON serialization
            period_boundaries[p_id] = [int(frame_list[0]), int(frame_list[-1])]
            # Close the previous period one frame before this one starts.
            if p_id > 0:
                period_boundaries[p_id-1][1] = int(frame_list[0] - 1)
    else:
        # Fallback: evenly sized periods of the selected window length.
        # BUGFIX: the original built a *list* here, which crashed on the
        # .items() loop below; build a dict keyed like the detection path.
        period_num = num_frames // win
        period_boundaries = {i: [int(i*win), int((i+1)*win)] for i in range(period_num)}
    print(f'Workflow: {workflow_str}')
    for i, boundary in period_boundaries.items():
        # BUGFIX: typo "Priod" -> "Period"
        print(f"Period {i+1}: with boundaries of {boundary} ")

    # Make visualization video if requested
    if make_video and os.path.exists(video_path):
        print("Generating Video...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error opening video file")
            cap.release()
            return {
                "workflow": workflow_str,
                "period_boundaries": period_boundaries,
                "error_video": "Failed to open video file"
            }
        # Token legend: one representative frame per cluster token.
        images = []
        tokens = []
        for c in np.unique(list(sequence)):
            if c == '_':
                continue
            c_num = alpha_to_number(c)
            frame_number = np.where(cluster_labels == c_num)[0][0]
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = cap.read()
            # BUGFIX: skip unreadable frames instead of crashing on None.
            if not ret or frame is None:
                continue
            tokens.append(c)
            images.append(frame[:, :, ::-1])  # BGR -> RGB for plotting
        plot_images_with_token(images, ''.join(tokens))

        W = 640                   # rendered video frame width
        H = 640                   # rendered video frame height
        height = 80               # progress-bar row height per period
        video_sampling_rate = 10  # write every 10th frame of a period
        unique_chars = sorted(set(string.ascii_lowercase))[:15]
        hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
        color_map = {char: hsv_to_rgb(hue, 0.8, 0.9) for char, hue in zip(unique_chars, hues)}
        max_period_len = max((len(v) for v in seg_labels.values()), default=win)
        prog_bar_w = int(max_period_len // video_sampling_rate) + 300 + 50  # 300 px label gutter + 50 px buffer
        progress_bar = np.ones((H, prog_bar_w, 3), dtype=np.float32)

        # Load the anchor (legend) strip, or fall back to a blank one.
        try:
            if os.path.exists("anchors.jpg"):
                anchor = cv2.imread("anchors.jpg")
                anchor = cv2.resize(anchor, (W + prog_bar_w, 380))
            else:
                anchor = np.ones((380, W + prog_bar_w, 3), dtype=np.uint8) * 255
        except Exception:
            anchor = np.ones((380, W + prog_bar_w, 3), dtype=np.uint8) * 255

        # Setup video writer with robust codec handling: try H.264 (avc1)
        # first, then progressively less compatible codecs.
        frame_size = (anchor.shape[1], H + anchor.shape[0])
        fourcc_code = 'avc1'
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*fourcc_code), 30, frame_size)
        if not out.isOpened():
            print(f"{fourcc_code} failed. Trying h264...")
            fourcc_code = 'h264'
            out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*fourcc_code), 30, frame_size)
        if not out.isOpened():
            print(f"{fourcc_code} failed. Trying vp80...")
            fourcc_code = 'vp80'
            out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*fourcc_code), 30, frame_size)
        if not out.isOpened():
            print(f"{fourcc_code} failed. Trying mp4v (less compatible)...")
            fourcc_code = 'mp4v'
            out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*fourcc_code), 30, frame_size)
        if not out.isOpened():
            # BUGFIX: bail out instead of writing to a closed writer.
            print("Error: Could not open video writer with any compatible codec.")
            cap.release()
            out.release()
            return {
                "workflow": workflow_multi_paths.tolist() if isinstance(workflow_multi_paths, np.ndarray) else workflow_multi_paths,
                "period_boundaries": period_boundaries,
                "window_size": int(win),
                "num_periods": int(period_num+1),
                "error_video": "Could not open video writer with any compatible codec",
                "output_video": None
            }

        last_frame = None  # BUGFIX: 'frame' could be unbound in the freeze-frame step
        for idx, k in enumerate(tqdm(list(seg_labels.keys()))):
            if not seg_labels[k]:  # Skip empty segments
                continue
            labels = list(seg_labels[k].values())
            frame_ids = list(seg_labels[k].keys())
            # Boundary text from period_boundaries (dict from detection or
            # dict fallback); keys follow the segment iteration order.
            start_frame_text = "????"
            end_frame_text = "????"
            try:
                boundary = period_boundaries[idx]
                start_frame_text = f"{boundary[0]:04d}"
                end_frame_text = f"{boundary[1]:04d}"
            except (KeyError, IndexError, TypeError):
                pass
            cv2.putText(progress_bar, f'Period {k+1}', (5, height*k+30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
            for m, (l, frame_id) in enumerate(zip(labels[::video_sampling_rate], frame_ids[::video_sampling_rate])):
                try:
                    # Color one progress-bar column, then composite
                    # video frame + progress bar (top) over the anchor strip.
                    progress_bar[height*k:height*(k+1), 300+m, :] = color_map[l.lower()]
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                    ret, frame = cap.read()
                    if not ret:
                        continue
                    frame = cv2.resize(frame, (W, H))
                    cv2.putText(frame, f"Frame: {frame_id}", (50, 50),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
                    frame = np.concatenate([frame, (progress_bar*255).astype(np.uint8)[:, :, ::-1]], axis=1)
                    frame = np.concatenate([frame, anchor], axis=0)
                    out.write(frame)
                    last_frame = frame
                except Exception as e:
                    print(f"Error in video generation: {str(e)}")
                    continue
            # Annotate boundaries after the period has been rendered.
            cv2.putText(progress_bar, f'Frame: {start_frame_text}-{end_frame_text}', (5, height*k+52),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        # Freeze frame: hold the final composite (with the last period's
        # boundary text) for ~3 seconds at 30 fps.
        try:
            if last_frame is not None:
                last_video_part = last_frame[:H, :W, :]
                top_part = np.concatenate([last_video_part, (progress_bar*255).astype(np.uint8)[:, :, ::-1]], axis=1)
                final_frame = np.concatenate([top_part, anchor], axis=0)
                for _ in range(90):  # 3 seconds pause
                    out.write(final_frame)
        except Exception as e:
            print(f"Error creating freeze frame: {e}")

        # Release resources
        cap.release()
        out.release()

    # Return results
    return {
        "workflow": workflow_multi_paths.tolist() if isinstance(workflow_multi_paths, np.ndarray) else workflow_multi_paths,
        "period_boundaries": period_boundaries,
        "window_size": int(win),
        # NOTE(review): period_num+1 looks off by one relative to the
        # detected count — preserved as-is; confirm against callers.
        "num_periods": int(period_num+1),
        "output_video": output_video_path if make_video else None
    }