Upload 4 files
Browse files- utils/eval.py +180 -0
- utils/periodic_detection_helper.py +641 -0
- utils/plot.py +532 -0
- utils/render.py +242 -0
utils/eval.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.optimize import linear_sum_assignment
|
| 3 |
+
|
| 4 |
+
def temporal_iou(pred_span, gt_span):
    """
    1D Intersection-over-Union between two temporal spans.

    Args:
        pred_span (tuple/list): predicted span as (start, end).
        gt_span (tuple/list): ground-truth span as (start, end).

    Returns:
        float: IoU in [0, 1]; 0.0 when the spans do not overlap.

    Raises:
        ValueError: if either span ends before it starts.
    """
    p0, p1 = pred_span
    g0, g1 = gt_span

    # Reject degenerate (reversed) spans up front.
    if p1 < p0 or g1 < g0:
        raise ValueError("End time cannot be before start time")

    overlap = min(p1, g1) - max(p0, g0)
    if overlap <= 0:
        return 0.0

    union = (p1 - p0) + (g1 - g0) - overlap
    return float(overlap / union)
|
| 38 |
+
|
| 39 |
+
def match_temporal_iou(preds, gts):
    """
    Find an optimal one-to-one matching between predicted and ground-truth
    temporal spans using the Hungarian algorithm, maximizing total IoU.

    Args:
        preds (list): predicted temporal spans, each [start, end].
        gts (list): ground-truth temporal spans, each [start, end].

    Returns:
        tuple: (matched_pairs, avg_iou)
            - matched_pairs: list of (pred_idx, gt_idx) pairs.
            - avg_iou: total IoU of the matched pairs divided by len(gts)
              (a recall-style average over ground truths — the original
              docstring incorrectly described this as the raw sum).
    """
    if not preds or not gts:
        return [], 0.0

    # Negative IoU: linear_sum_assignment minimizes cost, we maximize IoU.
    cost_matrix = np.zeros((len(preds), len(gts)))
    for i, pred in enumerate(preds):
        for j, gt in enumerate(gts):
            cost_matrix[i, j] = -temporal_iou(pred, gt)

    pred_indices, gt_indices = linear_sum_assignment(cost_matrix)

    # Plain Python ints are friendlier for downstream indexing/serialization
    # than the numpy ints linear_sum_assignment yields.
    matched_pairs = [(int(i), int(j)) for i, j in zip(pred_indices, gt_indices)]
    total_iou = -cost_matrix[pred_indices, gt_indices].sum()  # back to positive
    avg_iou = total_iou / len(gts)  # averaged over GTs, not matched pairs

    return matched_pairs, avg_iou
|
| 70 |
+
'''
|
| 71 |
+
# Example usage:
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
# Example predictions and ground truths
|
| 74 |
+
predictions = [[10, 20], [25, 35], [40, 50], [50, 55]]
|
| 75 |
+
ground_truths = [[15, 25], [30, 40], [45, 55]]
|
| 76 |
+
|
| 77 |
+
# Find optimal matching
|
| 78 |
+
matches, avg_iou = match_temporal_iou(predictions, ground_truths)
|
| 79 |
+
|
| 80 |
+
print("Matched pairs (pred_idx, gt_idx):", matches)
|
| 81 |
+
print("Avg IoU:", avg_iou)
|
| 82 |
+
|
| 83 |
+
# Print individual IoUs for matched pairs
|
| 84 |
+
print("\nIndividual IoUs:")
|
| 85 |
+
for pred_idx, gt_idx in matches:
|
| 86 |
+
iou = temporal_iou(predictions[pred_idx], ground_truths[gt_idx])
|
| 87 |
+
print(f"Pred {pred_idx} - GT {gt_idx}: {iou:.3f}")
|
| 88 |
+
'''
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def find_difference_range(s1, s2):
    """
    Locate the span where two equal-length strings differ, ignoring each
    string's first and last character.

    Returns:
        list | None: [start, end] inclusive indices into the ORIGINAL
        strings bounding the differing region, or None when the interior
        lengths differ or the interiors are identical.
    """
    inner1, inner2 = s1[1:-1], s2[1:-1]

    length = len(inner1)
    if length != len(inner2):
        return None  # position-wise comparison is impossible

    # Walk forward past the common prefix.
    lo = 0
    while lo < length and inner1[lo] == inner2[lo]:
        lo += 1

    # Walk backward past the common suffix.
    hi = length - 1
    while hi >= lo and inner1[hi] == inner2[hi]:
        hi -= 1

    # +1 maps interior indices back to original-string coordinates.
    return [lo + 1, hi + 1] if lo <= hi else None
|
| 112 |
+
|
| 113 |
+
'''
|
| 114 |
+
# Test with your example
|
| 115 |
+
s1 = "GIBJBIGCHEHCGIBFAD-"
|
| 116 |
+
s2 = "GIBJBIGCHED----FADG"
|
| 117 |
+
result = find_difference_range(s1, s2)
|
| 118 |
+
print(f"Different substrings: '{s1[result[0]:result[1]+1]}' and '{s2[result[0]:result[1]+1]}'")
|
| 119 |
+
'''
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_overlapping_substring(s1, s2, best_offset, max_matches):
    """
    Given an alignment offset (as produced by align_strings), return the
    substring of s1 of length ``max_matches`` starting at the first position
    where s1 and the shifted s2 agree; "" when no position matches.
    """
    for i, ch in enumerate(s1):
        j = i - best_offset
        if 0 <= j < len(s2) and ch == s2[j]:
            # First aligned match anchors the overlap window.
            return s1[i:i + max_matches]
    return ""
|
| 140 |
+
|
| 141 |
+
'''
|
| 142 |
+
string1 = 'JBHKHBJGCEID'
|
| 143 |
+
string2 = 'BJGCEIDIALFKCGJ'
|
| 144 |
+
best_offset, max_matches = align_strings(string1, string2)
|
| 145 |
+
overlapping_part = get_overlapping_substring(string1, string2, best_offset, max_matches)
|
| 146 |
+
print("String 1:", string1)
|
| 147 |
+
print("String 2:", string2)
|
| 148 |
+
print("\nOverlapping part:", overlapping_part)
|
| 149 |
+
'''
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def find_difference_range(s1, s2):
    """
    Locate the differing span of two equal-length strings, ignoring each
    string's first and last character.

    NOTE(review): this is a byte-identical duplicate of the earlier
    find_difference_range in this module; being defined later, this copy is
    the one actually used. Consider deleting one of them.

    Returns:
        list | None: [start, end] inclusive indices into the original
        strings, or None when interior lengths differ or interiors match.
    """
    # Ignore first and last chars by slicing [1:-1]
    s1_mid = s1[1:-1]
    s2_mid = s2[1:-1]

    n = len(s1_mid)
    if n != len(s2_mid):
        return None  # Strings of different lengths

    # Find start of difference
    start = 0
    while start < n and s1_mid[start] == s2_mid[start]:
        start += 1

    # Find end of difference (going backwards)
    end = n - 1
    while end >= start and s1_mid[end] == s2_mid[end]:
        end -= 1

    # Adjust indices to account for ignored first character
    return [start + 1, end + 1] if start <= end else None
|
| 173 |
+
'''
|
| 174 |
+
# Test with your example
|
| 175 |
+
s1 = "GIBJBIGCHEHCGIBFAD-"
|
| 176 |
+
s2 = "GIBJBIGCHED----FADG"
|
| 177 |
+
result = find_difference_range(s1, s2)
|
| 178 |
+
print(f"Different substrings: '{s1[result[0]:result[1]+1]}' and '{s2[result[0]:result[1]+1]}'")
|
| 179 |
+
print(f"Range index: {result}")
|
| 180 |
+
'''
|
utils/periodic_detection_helper.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
from itertools import product
|
| 4 |
+
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from sklearn.cluster import KMeans, MeanShift
|
| 7 |
+
from sklearn.preprocessing import StandardScaler
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 11 |
+
from sklearn.cluster import KMeans
|
| 12 |
+
|
| 13 |
+
import copy
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def smooth(period_labels, gap = 1):
    """
    Majority-vote smoothing of a 1D integer label sequence.

    Each interior position i is replaced by the most frequent label within
    the symmetric window [i - gap, i + gap] of the ORIGINAL sequence; the
    first and last ``gap`` entries are left untouched.

    Args:
        period_labels: sequence of non-negative integer labels
            (np.bincount requires non-negative ints).
        gap (int): half-width of the smoothing window.

    Returns:
        A deep copy of ``period_labels`` with interior values smoothed.
    """
    smoothed = copy.deepcopy(period_labels)
    for i in range(gap, len(period_labels) - gap):
        # Fix: slice end is i + gap + 1 so the window is symmetric around i;
        # the original i-gap:i+gap excluded the right neighbor.
        counts = np.bincount(period_labels[i - gap:i + gap + 1])
        smoothed[i] = np.argmax(counts)
    return smoothed
|
| 24 |
+
|
| 25 |
+
def spatiotemporal_clustering(spatiotemporal_data: np.ndarray, n_clusters: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Cluster per-frame feature vectors with KMeans and tokenize the trajectory.

    NOTE(review): the original docstring said DBSCAN, but the code uses
    KMeans; the annotation also declared a 3-tuple while four values are
    returned. Both are corrected here.

    Args:
        spatiotemporal_data: array of shape [frame, n_feats].
        n_clusters: number of KMeans clusters (token vocabulary size).

    Returns:
        A tuple containing:
        - cluster_labels: per-frame cluster labels after majority smoothing.
        - hard_tokenized_trajectory: the hard token sequence (same data as
          cluster_labels; .T is a no-op on a 1D array).
        - soft_tokenized_trajectory: [n_clusters, frame] array of
          exp(-distance) similarities to each centroid, normalized to sum
          to 1 per frame.
        - centroids: the KMeans cluster centers.
    """

    kmeans = KMeans(n_clusters=n_clusters, random_state=20, n_init='auto')
    cluster_labels = kmeans.fit_predict(spatiotemporal_data)

    # Majority-vote smoothing over a +-1 frame temporal window.
    cluster_labels = smooth(cluster_labels, gap = 1)
    # Hard-encoded tokenization for the trajectory using cluster labels.
    hard_tokenized_trajectory = cluster_labels

    # Get cluster centroids
    centroids = kmeans.cluster_centers_
    n_clusters = len(centroids)
    n_points = len(spatiotemporal_data)

    # Initialize array for soft tokenization
    soft_tokenized_trajectory = np.zeros((n_points, n_clusters))

    # Compute Euclidean distances to all centroids for each point
    for i in tqdm(range(n_points)):
        point = spatiotemporal_data[i]
        distances = np.array([np.linalg.norm(point - centroid) for centroid in centroids])
        # Convert distances to similarities using exponential decay
        similarities = np.exp(-distances)

        # Normalize similarities to sum to 1
        soft_tokenized_trajectory[i] = similarities / np.sum(similarities)

    return cluster_labels, hard_tokenized_trajectory.T, soft_tokenized_trajectory.T, centroids
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def create_path(nodes):
    """
    Render a list of node groups as a comma-separated chain of edges.

    Consecutive groups are connected by edges; when a group holds several
    nodes the parallel edges are wrapped in parentheses. Groups with
    multiple sources AND multiple targets are paired element-wise, so those
    two groups are assumed to have equal length.

    Args:
        nodes (list): list of lists of node IDs.

    Returns:
        str: the path description string.
    """
    # Seed with the edge between the heads of the first two groups.
    pieces = [f"{nodes[0][0]}->{nodes[1][0]}"]

    for idx in range(1, len(nodes) - 1):
        src, dst = nodes[idx], nodes[idx + 1]

        if len(src) == 1 and len(dst) == 1:
            pieces.append(f"{src[0]}->{dst[0]}")
        elif len(src) == 1:
            # Fan-out: one source feeding several targets.
            edges = ", ".join(f"{src[0]}->{t}" for t in dst)
            pieces.append(f"({edges})")
        elif len(dst) == 1:
            # Fan-in: several sources feeding one target.
            edges = ", ".join(f"{s}->{dst[0]}" for s in src)
            pieces.append(f"({edges})")
        else:
            # Parallel edges paired element-wise.
            edges = ", ".join(f"{src[k]}->{dst[k]}" for k in range(len(src)))
            pieces.append(f"({edges})")

    return ', '.join(pieces)
|
| 110 |
+
|
| 111 |
+
def summarize_strings(strings):
    """
    Summarize a list of strings position by position.

    Positions (up to the shortest string's length) where every string agrees
    keep that character; positions where they disagree become '_'.

    Args:
        strings (list): strings to summarize.

    Returns:
        str: the positional summary ("" for an empty list).
    """
    if not strings:
        return ""

    shortest = min(len(s) for s in strings)
    first = strings[0]

    summary = [
        first[i] if all(s[i] == first[i] for s in strings) else "_"
        for i in range(shortest)
    ]
    return "".join(summary)
|
| 139 |
+
|
| 140 |
+
def find_dash_end_index(strings):
    """
    Find the right-most column index at which some string contains a dash
    ('-') immediately preceded by a letter in that same string.

    Args:
        strings (list): non-empty list of equal-length strings.

    Returns:
        int: the column index of the last such dash, or -1 if none exists.

    Raises:
        ValueError: if the strings are not all the same length.
    """
    # Ensure all strings have same length
    if not all(len(s) == len(strings[0]) for s in strings):
        raise ValueError("Strings must be of equal length")

    # Scan columns right-to-left so the first hit is the right-most one.
    for i in range(len(strings[0]) - 1, -1, -1):
        for s in strings:
            # Removed a dead `elif not s[i].isalpha(): continue` branch from
            # the original — it was a no-op since the loop advances anyway.
            if s[i] == '-' and i > 0 and s[i - 1].isalpha():
                return i

    return -1  # no letter-then-dash pattern found
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def find_longest_repeated_ends(strings):
    """
    Length of the longest string that is simultaneously a common prefix and
    common suffix of every string in the list (candidates are taken from the
    first string, up to half its length).

    Args:
        strings (list): strings to inspect.

    Returns:
        int: the longest such length; 0 for an empty list or no match.
    """
    if not strings:
        return 0

    reference = strings[0]
    best = 0

    # Try every candidate length; keep the largest that works (a shorter
    # failure does not preclude a longer success, so no early exit).
    for length in range(1, len(reference) // 2 + 1):
        head = reference[:length]
        tail = reference[-length:]
        if head != tail:
            continue
        if all(s.startswith(head) and s.endswith(tail) for s in strings):
            best = length

    return best
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def create_path(nodes):
    """
    Render a list of node groups as a comma-separated chain of edges.

    NOTE(review): this definition shadows an identical earlier create_path
    in this module except that its loop started at index 0, which re-emitted
    the "nodes[0]->nodes[1]" edge already appended just above. Starting at
    1 removes the duplicate edge and matches the earlier definition.

    Args:
        nodes (list): list of lists of node IDs; multi-source/multi-target
            group pairs are matched element-wise (assumed equal-length).

    Returns:
        str: the path description string.
    """
    result = []
    # Seed with the edge between the heads of the first two groups.
    result.append(f"{nodes[0][0]}->{nodes[1][0]}")

    current_idx = 1  # fix: was 0, which duplicated the first edge
    while current_idx < len(nodes) - 1:
        sources = nodes[current_idx]
        targets = nodes[current_idx + 1]

        if len(sources) == 1 and len(targets) == 1:
            result.append(f"{sources[0]}->{targets[0]}")
        elif len(sources) == 1:
            # One source fanning out to several targets.
            paths = [f"{sources[0]}->{target}" for target in targets]
            result.append(f"({', '.join(paths)})")
        elif len(targets) == 1:
            # Several sources fanning in to one target.
            paths = [f"{source}->{targets[0]}" for source in sources]
            result.append(f"({', '.join(paths)})")
        else:
            # Element-wise pairing of sources and targets.
            paths = [f"{sources[i]}->{targets[i]}" for i in range(len(sources))]
            result.append(f"({', '.join(paths)})")

        current_idx += 1

    return ', '.join(result)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def dominant_fourier_frequency_2d(matrix, lbound=10, ubound=1000):
    """
    Rank candidate period lengths (window sizes) of a 2D matrix along its
    second (column/time) axis by Fourier magnitude.

    Parameters
    ----------
    matrix : array-like
        The input 2D matrix.
    lbound : int, optional
        Lower bound (inclusive) of the candidate window size. Default 10.
    ubound : int, optional
        Upper bound (exclusive) of the candidate window size. Default 1000.

    Returns
    -------
    tuple
        (period_candidates, period_candidate_magnitudes): window sizes
        sorted by descending magnitude and the magnitudes in that order.
        NOTE(review): both arrays are empty when no frequency falls inside
        [lbound, ubound) — callers indexing [0] should guard for that.
    """
    # Compute 2D FFT
    fourier = np.fft.fft2(matrix)

    # Frequencies along the horizontal (column) axis, unit sample spacing.
    freq_x = np.fft.fftfreq(matrix.shape[1], 1)

    magnitudes_x = []
    window_sizes_x = []

    # Analyze horizontal frequencies (x-axis)
    for j, freq in enumerate(freq_x):
        if freq > 0:  # Only consider positive frequencies
            window_size = int(1 / freq)
            if window_size >= lbound and window_size < ubound:
                # Sum this frequency's magnitude across all rows.
                mag = 0
                for i in range(matrix.shape[0]):
                    coef = fourier[i, j]
                    mag += math.sqrt(coef.real * coef.real + coef.imag * coef.imag)
                window_sizes_x.append(window_size)
                magnitudes_x.append(mag)

    '''
    # Handle cases where no valid frequencies are found
    if len(magnitudes_x) == 0:
        warnings.warn(f"Could not extract valid horizontal frequencies. Using window_size={lbound}.")
        period_x = lbound
    else:
        period_x = window_sizes_x[np.argmax(magnitudes_x)]
    '''

    # Window sizes ordered by descending magnitude, plus the sorted magnitudes.
    return np.array(window_sizes_x)[np.argsort(magnitudes_x)[::-1]], np.sort(magnitudes_x)[::-1]
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def dominant_fourier_frequency_1d(time_series, lbound=10, ubound=1000):
    """
    Rank candidate period lengths (window sizes) of a 1D time series by
    Fourier magnitude.

    Parameters
    ----------
    time_series : np.ndarray
        The input time series.
    lbound : int, optional
        Lower bound (inclusive) of the candidate window size. Default 10.
    ubound : int, optional
        Upper bound (exclusive) of the candidate window size. Default 1000.

    Returns
    -------
    tuple of np.ndarray
        (period_candidates, period_candidate_magnitudes): window sizes
        sorted by descending Fourier magnitude, and the magnitudes in that
        order. Degenerate inputs return a single int window size instead
        (kept for backward compatibility with existing callers).
    """
    # Fix: `warnings` was used below but never imported at module level,
    # raising NameError whenever a warning path was hit.
    import warnings

    if time_series.shape[0] < 2 * lbound:
        # Too short to observe even two periods at the smallest window.
        warnings.warn(
            f"Time series must at least have 2*lbound much data points. Using window_size={time_series.shape[0]}.")
        return time_series.shape[0]

    fourier = np.fft.fft(time_series)
    freqs = np.fft.fftfreq(time_series.shape[0], 1)

    magnitudes = []
    window_sizes = []

    # Fix: renamed the loop variable — the original reused `freq`,
    # shadowing the frequency array it was iterating over.
    for coef, f in zip(fourier, freqs):
        if coef and f > 0:  # nonzero coefficient, positive frequency only
            window_size = int(1 / f)
            mag = math.sqrt(coef.real * coef.real + coef.imag * coef.imag)

            if window_size >= lbound and window_size < ubound:
                window_sizes.append(window_size)
                magnitudes.append(mag)

    if len(magnitudes) == 0:
        warnings.warn(f"Could not extract valid frequencies. Using window_size={lbound}.")
        return lbound

    return np.array(window_sizes)[np.argsort(magnitudes)[::-1]], np.sort(magnitudes)[::-1]
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
from difflib import SequenceMatcher
|
| 329 |
+
from collections import Counter
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def calculate_similarity_score(strings_list):
    """
    Overall similarity score for a list of strings, in [0, 1].

    Averages three pairwise metrics over all unordered pairs and combines
    them with fixed weights: 0.5 * difflib sequence ratio +
    0.3 * character-set Jaccard + 0.2 * length ratio.

    Args:
        strings_list: List of strings to compare.

    Returns:
        float: Weighted overall similarity score between 0 and 1.
    """
    if not strings_list or len(strings_list) < 2:
        return 1.0  # a single string or empty list is trivially self-similar

    n = len(strings_list)
    # (Removed the unused `total_comparisons` variable from the original.)

    sequence_scores = []
    jaccard_scores = []
    length_ratio_scores = []

    # Compare every unordered pair of strings.
    for i in range(n):
        for j in range(i + 1, n):
            str1 = strings_list[i]
            str2 = strings_list[j]

            # difflib ratio: order-sensitive subsequence similarity.
            sequence_scores.append(SequenceMatcher(None, str1, str2).ratio())

            # Jaccard similarity on character sets (order-insensitive).
            set1, set2 = set(str1), set(str2)
            jaccard_scores.append(
                len(set1.intersection(set2)) / len(set1.union(set2)) if set1 or set2 else 1.0)

            # Length ratio (shorter / longer).
            longest = max(len(str1), len(str2))
            length_ratio_scores.append(
                min(len(str1), len(str2)) / longest if longest > 0 else 1.0)

    avg_sequence = np.mean(sequence_scores)
    avg_jaccard = np.mean(jaccard_scores)
    avg_length_ratio = np.mean(length_ratio_scores)

    # Weighted blend of the three averaged metrics.
    return 0.5 * avg_sequence + 0.3 * avg_jaccard + 0.2 * avg_length_ratio
|
| 382 |
+
|
| 383 |
+
def fuse_adjacent(s):
    """Collapse runs of identical adjacent characters ("aabba" -> "aba")."""
    out = []
    for ch in s:
        # Keep a character only when it differs from the last one kept.
        if not out or out[-1] != ch:
            out.append(ch)
    return ''.join(out)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def find_longest_identical_pair(s):
    """
    Find the pair of identical characters in ``s`` with the widest index gap.

    Returns:
        tuple | None: (character, left_index, right_index) of the widest
        pair; the earliest pair found wins ties. None when nothing repeats.
    """
    best = (None, -1, -1)
    n = len(s)
    for lo in range(n - 1):
        for hi in range(lo + 1, n):
            # Strict '>' keeps the first pair seen on equal gaps.
            if s[lo] == s[hi] and hi - lo > best[2] - best[1]:
                best = (s[lo], lo, hi)
    return best if best[0] is not None else None
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
'''def number_to_alpha(numbers):
|
| 410 |
+
# Create a mapping of numbers to alphabetic characters
|
| 411 |
+
alpha_map = {i: chr(97 + i) for i in range(26)} # a-z
|
| 412 |
+
alpha_map.update({i + 26: chr(65 + i) for i in range(26)}) # A-Z
|
| 413 |
+
|
| 414 |
+
# Convert numbers to characters
|
| 415 |
+
result = ''
|
| 416 |
+
for num in numbers:
|
| 417 |
+
if num in alpha_map:
|
| 418 |
+
result += alpha_map[num]
|
| 419 |
+
else:
|
| 420 |
+
result += '?' # For numbers outside the range 0-51
|
| 421 |
+
return result'''
|
| 422 |
+
|
| 423 |
+
def number_to_alpha(numbers):
    """Map integers 0-25 to uppercase letters 'A'-'Z'; anything else becomes '?'."""
    table = {code: chr(ord('A') + code) for code in range(26)}
    # dict.get supplies '?' for values outside 0-25.
    return ''.join(table.get(num, '?') for num in numbers)
|
| 433 |
+
|
| 434 |
+
def alpha_to_number(sequence):
    """Map letters to 0-based alphabet indices ('A'/'a' -> 0, ..., 'Z'/'z' -> 25)."""
    base = ord('A')
    return [ord(ch.upper()) - base for ch in sequence]
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def score_match(chars):
    """
    Score one aligned MSA column: a column containing gaps scores minus the
    gap count; otherwise it scores the pairwise matches over the ordered
    cross product, minus the trivial self-matches on the diagonal.
    """
    gap_count = chars.count('-')
    if gap_count:
        return -gap_count
    matches = sum(a == b for a, b in product(chars, chars))
    return matches - len(chars)
|
| 446 |
+
|
| 447 |
+
def initialize_matrix(sequences):
    """
    Build the N-dimensional dynamic-programming matrices for MSA.

    Args:
        sequences: list of sequences to align; each DP axis has length
            len(seq) + 1 (the extra slot is the empty prefix).

    Returns:
        tuple: (F, P) where
            - F is the float score matrix with linearly increasing gap
              penalties pre-filled along each axis' leading edge,
            - P is an object matrix of (initially empty) traceback lists.
    """
    dims = [len(seq) + 1 for seq in sequences]

    # Score matrix F and pointer matrix P.
    F = np.zeros(dims)
    # P holds, per cell, the list of best predecessor positions.
    P = np.empty(dims, dtype=object)
    for idx in np.ndindex(*dims):
        P[idx] = []

    # Gap penalties along each leading edge (all-gap prefixes).
    # (Removed an unused `indices` variable from the original.)
    for axis, dim in enumerate(dims):
        edge = tuple(slice(None) if i == axis else 0 for i in range(len(dims)))
        F[edge] = np.linspace(0, -len(sequences) * dim, dim)

    return F, P
|
| 467 |
+
|
| 468 |
+
def get_neighbors(current_pos, dims):
    """
    Enumerate every valid predecessor cell of ``current_pos`` in the DP
    matrix: all positions reached by decrementing a non-empty subset of
    coordinates by one. Coordinates already at 0 cannot be decremented, so
    subsets touching them are skipped.
    """
    n_axes = len(dims)
    predecessors = []
    # Each bitmask selects which coordinates to decrement.
    for mask in range(2 ** n_axes):
        candidate = []
        valid = True
        for axis, coord in enumerate(current_pos):
            if mask & (1 << axis):
                if coord > 0:
                    candidate.append(coord - 1)
                else:
                    valid = False  # would step outside the matrix
                    break
            else:
                candidate.append(coord)
        if valid:
            predecessors.append(tuple(candidate))
    # Mask 0 is the current position itself — drop it.
    return predecessors[1:]
|
| 484 |
+
|
| 485 |
+
def msa(sequences, gap_penalty=-1):
    """
    Multiple sequence alignment via N-dimensional Needleman-Wunsch.

    Args:
        sequences: list of strings to align.
        gap_penalty: NOTE(review) — currently unused; gap costs come from
            score_match and the edge initialization in initialize_matrix.

    Returns:
        list[str]: the aligned sequences (with '-' gaps), one per input.
    """
    # Initialize score matrix F and traceback matrix P.
    F, P = initialize_matrix(sequences)
    dims = F.shape

    # Fill the DP matrix over all interior positions, in product order.
    for pos in product(*[range(1, dim) for dim in dims]):

        # Characters at the current position (1-based DP index -> 0-based).
        # NOTE(review): `chars` is computed but not used below; scoring uses
        # `aligned_chars` per neighbor instead.
        chars = [sequences[i][pos[i]-1] for i in range(len(sequences))]

        # All valid predecessor cells (subset-decrement neighbors).
        neighbors = get_neighbors(pos, dims)

        # Find the best-scoring transition(s) into this cell.
        max_score = float('-inf')
        best_moves = []

        for neighbor in neighbors:
            # A sequence consumes a character only along axes that moved;
            # unmoved axes contribute a gap for this column.
            aligned_chars = []
            for i, (curr, prev) in enumerate(zip(pos, neighbor)):
                if curr != prev:
                    aligned_chars.append(sequences[i][curr-1])
                else:
                    aligned_chars.append('-')

            score = F[neighbor] + score_match(aligned_chars)

            if score > max_score:
                max_score = score
                best_moves = [neighbor]
            elif score == max_score:
                best_moves.append(neighbor)

        F[pos] = max_score
        P[pos] = best_moves  # all equally-good predecessors for traceback

    # Traceback from the far corner to the origin, emitting columns in
    # reverse order.
    aligned_sequences = [[] for _ in sequences]
    current_pos = tuple(dim-1 for dim in dims)

    while any(pos > 0 for pos in current_pos):
        # Edge cells were never filled in the loop above, so they may have
        # no stored moves; stop rather than index into an empty list.
        if not P[current_pos]:
            break

        prev_pos = P[current_pos][0]  # deterministic: first best move

        # Emit a character along moved axes, a gap along unmoved ones.
        for i, (curr, prev) in enumerate(zip(current_pos, prev_pos)):
            if curr != prev:
                aligned_sequences[i].append(sequences[i][curr-1])
            else:
                aligned_sequences[i].append('-')

        current_pos = prev_pos

    # Columns were collected end-to-start; reverse to restore order.
    return [(''.join(seq))[::-1] for seq in aligned_sequences]
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# Example usage
|
| 549 |
+
#sequences = ['ACBAFDECBAECFACBA', 'CFACBAFDECBAECFA', 'ECFACBAFBECBAECFA']
|
| 550 |
+
#aligned_sequences = msa(sequences)
|
| 551 |
+
#print('\n'.join(aligned_sequences))
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def align_strings(s1, s2):
    """Find the shift of ``s2`` relative to ``s1`` that maximizes the number
    of positions where the two strings agree.

    The offset is applied as ``j = i - offset`` when comparing ``s1[i]``
    against ``s2[j]``. Ties keep the smallest (earliest) offset; if nothing
    matches the result is ``(0, 0)``.

    Args:
        s1 (str): Reference string.
        s2 (str): String slid across ``s1``.

    Returns:
        tuple: ``(best_offset, max_matches)``.
    """
    n, m = len(s1), len(s2)
    best_offset, max_matches = 0, 0

    for offset in range(1 - m, n):
        hits = sum(
            1
            for i, ch in enumerate(s1)
            if 0 <= i - offset < m and ch == s2[i - offset]
        )
        if hits > max_matches:
            best_offset, max_matches = offset, hits

    return best_offset, max_matches
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def get_overlapping_substring(s1, s2, best_offset, max_matches):
    """Return the substring of ``s1`` that overlaps ``s2`` under a given offset.

    Scans for the first position where ``s1[i] == s2[i - best_offset]`` and
    returns the ``max_matches``-long slice of ``s1`` starting there, or an
    empty string when no position matches.

    NOTE(review): the returned slice is contiguous in ``s1``; if the matched
    positions found by ``align_strings`` are not contiguous, the slice may
    contain mismatched characters — presumably acceptable for the caller.
    """
    m = len(s2)
    for i, ch in enumerate(s1):
        j = i - best_offset
        if 0 <= j < m and ch == s2[j]:
            # First index of agreement anchors the returned window.
            return s1[i:i + max_matches]
    return ""
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
'''
|
| 591 |
+
def align_multiple_strings(strings):
|
| 592 |
+
if not strings:
|
| 593 |
+
return []
|
| 594 |
+
reference = strings[0]
|
| 595 |
+
for next_string in strings[1:]:
|
| 596 |
+
ref_align, _ = align_two_strings(reference, next_string)
|
| 597 |
+
# Collapse underscores to match specification
|
| 598 |
+
reference = ''.join(char for i, char in enumerate(ref_align) if char != '_' or (i > 0 and ref_align[i - 1] != '_'))
|
| 599 |
+
return reference
|
| 600 |
+
|
| 601 |
+
def align_two_strings(str1, str2):
|
| 602 |
+
# Alignment code (based on previously defined)
|
| 603 |
+
n, m = len(str1), len(str2)
|
| 604 |
+
dp = [[0] * (m + 1) for _ in range(n + 1)]
|
| 605 |
+
for i in range(n + 1): dp[i][0] = i
|
| 606 |
+
for j in range(m + 1): dp[0][j] = j
|
| 607 |
+
for i in range(1, n + 1):
|
| 608 |
+
for j in range(1, m + 1):
|
| 609 |
+
if str1[i - 1] == str2[j - 1]:
|
| 610 |
+
dp[i][j] = dp[i - 1][j - 1]
|
| 611 |
+
else:
|
| 612 |
+
dp[i][j] = min(dp[i - 1][j - 1], dp[i][j - 1], dp[i - 1][j]) + 1
|
| 613 |
+
aligned1, aligned2 = [], []
|
| 614 |
+
i, j = n, m
|
| 615 |
+
while i > 0 and j > 0:
|
| 616 |
+
if str1[i - 1] == str2[j - 1]:
|
| 617 |
+
aligned1.append(str1[i - 1])
|
| 618 |
+
aligned2.append(str2[j - 1])
|
| 619 |
+
i -= 1
|
| 620 |
+
j -= 1
|
| 621 |
+
elif dp[i][j] == dp[i - 1][j - 1] + 1:
|
| 622 |
+
aligned1.append('_')
|
| 623 |
+
aligned2.append('_')
|
| 624 |
+
i -= 1
|
| 625 |
+
j -= 1
|
| 626 |
+
elif dp[i][j] == dp[i - 1][j]+1:
|
| 627 |
+
aligned1.append('_')
|
| 628 |
+
i -= 1
|
| 629 |
+
else:
|
| 630 |
+
aligned2.append('_')
|
| 631 |
+
j -= 1
|
| 632 |
+
while i > 0:
|
| 633 |
+
aligned1.append('_')
|
| 634 |
+
i -= 1
|
| 635 |
+
while j > 0:
|
| 636 |
+
aligned2.append('_')
|
| 637 |
+
j -= 1
|
| 638 |
+
aligned1.reverse()
|
| 639 |
+
aligned2.reverse()
|
| 640 |
+
return ''.join(aligned1), ''.join(aligned2)
|
| 641 |
+
'''
|
utils/plot.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
from colorsys import hsv_to_rgb
|
| 4 |
+
import string
|
| 5 |
+
import networkx as nx
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
# import osmnx as ox
|
| 9 |
+
import matplotlib.animation as animation
|
| 10 |
+
from itertools import combinations
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
#################plot for transcripts############################
|
| 15 |
+
def plot_string(text, figsize=(18, 4)):
    """
    Plot a string with identical colors for identical letters and 0.1x
    letter-width spacing.

    Letters a-j get saturated distinct hues, the remaining letters get
    desaturated colors, and the gap characters '-' / '_' are gray
    ('_' is rendered as a short underline below the baseline).

    Args:
        text (str): String to render (letters plus '-' / '_').
        figsize (tuple): Matplotlib figure size in inches.

    Side effects:
        Saves the figure to 'transcript.svg' in the working directory.
    """
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.figure(figsize=figsize)

    # First ten lowercase letters get fully saturated, distinct hues.
    primary = sorted(set(string.ascii_lowercase))[:10]
    hues = np.linspace(0, 1, len(primary), endpoint=False)
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
                 for char, hue in zip(primary, hues)}

    # Gap markers are gray.  (BUG FIX: the original rebuilt color_map from
    # scratch after adding '-', discarding both the gray entry and the first
    # ten colors, so '-', '_' and letters beyond 'j' raised KeyError.)
    color_map['-'] = (0.5, 0.5, 0.5)
    color_map['_'] = (0.5, 0.5, 0.5)

    # Remaining letters reuse the same hues, desaturated.  The hue list is
    # recycled so all 16 remaining letters (k-z) get a color.
    secondary = sorted(set(string.ascii_lowercase))[10:]
    color_map.update({char: hsv_to_rgb(hue, 0.3, 0.8)
                      for char, hue in zip(secondary, list(hues) * 2)})

    spacing = 0.1  # Space between letters relative to letter width
    width = 1.0    # Width of each letter
    total_width = len(text) * width * (1 + spacing) - spacing

    for i, char in enumerate(text):
        x_pos = i * width * (1 + spacing)

        if char == '_':
            # Draw underscore as a line slightly below the baseline
            plt.plot([x_pos - width/3, x_pos + width/3],
                     [-0.2, -0.2],
                     color=color_map['_'],
                     linewidth=2)
        else:
            # Regular character plotting (case-insensitive color lookup)
            color = color_map[char.lower()]
            plt.text(x_pos, 0, char, fontsize=14, color=color,
                     ha='center', va='center')

    plt.xlim(-width/2, total_width - width/2)
    plt.ylim(-0.5, 0.5)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('transcript.svg', format='svg')
    #plt.show()
|
| 61 |
+
|
| 62 |
+
# Example usage
|
| 63 |
+
#text = "AACCCBBAAFFDDDEEEECCCCBBAAAEEECCCFFFFAAACCBBAAFFDDDDEEEEECCCCCBBBBAAAEECCCCCFFFAAAACCCBBBBAAAFFFBBBEECCBBBAAEECCCFFAA"
|
| 64 |
+
#plot_string(text)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def parse_sequence(sequence):
    """Split an edge-list string into a main path and parenthesised branches.

    Edges look like ``"A1->C2"``; a ``"(...)"`` group opens a branch.  Node
    pairs seen while a '(' is open accumulate into the current branch; the
    pair carrying the closing ')' is treated as part of the main path and
    flushes the open branch into ``branches``.

    Returns:
        tuple: ``(main_path, branches)`` where ``main_path`` is a flat list
        of node labels and ``branches`` is a list of such flat lists.
    """
    main_path, branches = [], []
    pending_branch = []
    inside = False

    for raw in sequence.split(','):
        piece = raw.strip()
        if '(' in piece:
            inside = True
            piece = piece.replace('(', '')
        if ')' in piece:
            inside = False
            piece = piece.replace(')', '')

        # A node label is a letter or underscore followed by its index.
        pair = re.findall(r'([A-Za-z_]\d+)', piece)
        if len(pair) == 2:
            if inside:
                pending_branch.extend(pair)
            else:
                if pending_branch:
                    branches.append(pending_branch)
                    pending_branch = []
                main_path.extend(pair)

    # Flush a branch left open at the end of the sequence.
    if pending_branch:
        branches.append(pending_branch)

    return main_path, branches
|
| 99 |
+
|
| 100 |
+
def get_node_number(node):
    """Return the integer index embedded in a node label, e.g. 'A10' -> 10."""
    first_digit_run = re.findall(r'\d+', node)[0]
    return int(first_digit_run)
|
| 103 |
+
|
| 104 |
+
def plot_sequence(sequence, figsize=(8, 3)):
    """
    Draw a node/edge diagram for an edge-list string such as
    "A1->C2, C2->B3, (B3->D4, B3->E4)".

    Nodes are placed on a horizontal axis by their numeric suffix; branch
    nodes sharing the same index are stacked vertically around the axis.
    Node colors are keyed on the leading letter of the label.

    Args:
        sequence (str): Edge list accepted by ``parse_sequence``.
        figsize (tuple): Matplotlib figure size in inches.

    Side effects:
        Saves the figure to 'transcript.svg' in the working directory.
    """
    plt.figure(figsize=figsize)

    # Assign colors: first ten uppercase letters get saturated hues.
    unique_chars = sorted(set(string.ascii_uppercase))[:10]
    hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
                 for char, hue in zip(unique_chars, hues)}

    unique_chars += ['_']
    # Add special handling for underscore
    color_map['_'] = (0.5, 0.5, 0.5)  # Gray color for underscore

    # Remaining letters get desaturated colors.  (BUG FIX: the original
    # zipped 16 letters (K-Z) against only 10 hues, leaving U-Z unmapped and
    # raising KeyError when such a node was drawn; the hue list is recycled
    # so every letter has a color, and K-T keep their original colors.)
    remaining = sorted(set(string.ascii_uppercase))[10:]
    color_map.update({char: hsv_to_rgb(hue, 0.3, 0.8)
                      for char, hue in zip(remaining, list(hues) * 2)})

    main_path, branches = parse_sequence(sequence)
    G = nx.DiGraph()

    # Calculate positions
    pos = {}
    x_spacing = 1
    y_spacing = 0.5

    # Group nodes by their numeric index (shared index => vertical stack)
    nodes_by_number = {}
    all_nodes = set(main_path)
    for branch in branches:
        all_nodes.update(branch)

    for node in all_nodes:
        num = get_node_number(node)
        if num not in nodes_by_number:
            nodes_by_number[num] = []
        nodes_by_number[num].append(node)

    # Position nodes: x from the numeric index, y centered around 0
    for num in sorted(nodes_by_number.keys()):
        nodes = nodes_by_number[num]
        x = (num - 1) * x_spacing

        if len(nodes) == 1:
            pos[nodes[0]] = (x, 0)
        else:
            # Center branching nodes vertically
            total_height = (len(nodes) - 1) * y_spacing
            start_y = -total_height / 2
            for i, node in enumerate(sorted(nodes)):
                pos[node] = (x, start_y + i * y_spacing)

    # Add edges; node lists are flat (src, dst, src, dst, ...) pairs
    for i in range(0, len(main_path) - 1, 2):
        G.add_edge(main_path[i], main_path[i + 1])

    for branch in branches:
        for i in range(0, len(branch) - 1, 2):
            G.add_edge(branch[i], branch[i + 1])

    # Draw arrows
    nx.draw_networkx_edges(G, pos, edge_color='gray',
                           arrowsize=10, width=1.5,
                           arrowstyle='->')

    # Draw nodes: a translucent disc plus the letter, both in the letter's color
    for node in G.nodes():
        letter = node[0]
        color = color_map[letter]

        circle = plt.Circle(pos[node], 0.1,
                            color=color, alpha=0.3)
        plt.gca().add_patch(circle)

        plt.text(pos[node][0], pos[node][1], node[0],
                 color=color, fontsize=8,
                 ha='center', va='center',
                 fontweight='bold')

    plt.axis('equal')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('transcript.svg', format='svg')
|
| 186 |
+
|
| 187 |
+
# Example usage
|
| 188 |
+
#sequence = "A1->C2, C2->B3, B3->A4, A4->F5, (F5->B6, F5->D6), (B6->E7, D6->E7), E7->C8, C8->B9, B9->A10, A10->E11, E11->C12, C12->F13"
|
| 189 |
+
#plot_sequence(sequence, (5.4,2))
|
| 190 |
+
|
| 191 |
+
#################plot for transcripts############################
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
#################plot for 2D trajectories############################
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def plot_routes_animation(G, routes, colors, output_file, fps=20, duration_sec=10):
    """
    Create an animation showing routes appearing dynamically.

    Routes are revealed one after another: route i starts being drawn once
    overall progress exceeds i / len(routes), and each route grows point by
    point until complete.

    Args:
        G (networkx.MultiDiGraph): Street network graph whose nodes carry
            'x'/'y' coordinate attributes (as in osmnx graphs).
        routes (list): List of routes (each route is a list of node ids).
        colors (list): List of colors, one per route.
        output_file (str): Output filename (should end with .gif or .mp4).
        fps (int): Frames per second.
        duration_sec (int): Total animation duration in seconds.
    """

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(10, 8))
    plt.rcParams['font.family'] = 'Times New Roman'

    # Plot the base map (requires osmnx; intentionally disabled)
    # ox.plot_graph(G, ax=ax, show=False, close=False,
    #               edge_color='gray', edge_alpha=0.2, node_size=0)

    # One initially-empty line artist per route.
    # (CLEANUP: removed the unused `route_points` list from the original.)
    route_lines = []
    for color in colors:
        line, = ax.plot([], [], marker='D', color=color, linewidth=2, alpha=0.8, zorder=2)
        route_lines.append(line)

    # Extract coordinates for all routes up front so `update` is cheap
    all_route_coords = []
    for route in routes:
        coords = [(G.nodes[node]['x'], G.nodes[node]['y']) for node in route]
        all_route_coords.append(coords)

    # Calculate total number of frames
    total_frames = fps * duration_sec

    # Animation update function
    def update(frame):
        # Overall progress through the animation (0 to 1)
        progress = frame / total_frames

        # Update each route
        for i, coords in enumerate(all_route_coords):
            # Per-route progress: route i is revealed during the i-th
            # 1/len(routes) slice of the timeline, clamped to [.., 1].
            route_progress = min(1.0, progress * len(routes) - i)

            if route_progress <= 0:
                # Route hasn't started yet
                route_lines[i].set_data([], [])
                continue

            # Show at least two points so a visible segment is drawn
            num_points = max(2, int(route_progress * len(coords)))

            # Get coordinates to display
            visible_coords = coords[:num_points]
            xs, ys = zip(*visible_coords) if visible_coords else ([], [])

            # Update line data
            route_lines[i].set_data(xs, ys)

        # Legend lists only the routes that have started appearing
        visible_routes = [i for i, line in enumerate(route_lines)
                          if len(line.get_xdata()) > 0]

        if visible_routes:
            ax.legend([route_lines[i] for i in visible_routes],
                      [f'Period {i+1}' for i in visible_routes],
                      loc='upper right', prop={'size': 14},
                      bbox_to_anchor=(1, 1))

        return route_lines

    # Create animation
    ani = animation.FuncAnimation(
        fig, update, frames=total_frames,
        interval=1000/fps, blit=True
    )

    # Tight layout
    plt.tight_layout()

    # Save animation
    if output_file.endswith('.gif'):
        ani.save(output_file, writer='pillow', fps=fps, dpi=150)
    else:
        # For MP4, use ffmpeg
        writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
        ani.save(output_file, writer=writer, dpi=150)

    plt.close()

    print(f"Animation saved to {output_file}")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# Example usage:
|
| 306 |
+
# plot_routes_animation(G, routes, colors, "route_animation.gif")
|
| 307 |
+
# For MP4: plot_routes_animation(G, routes, colors, "route_animation.mp4")
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def plot_routes(G, routes, colors, output_file):
    """
    Plot multiple routes on the same map.

    Args:
        G (networkx.MultiDiGraph): Street network graph
        routes (list): List of routes (each route is a list of nodes)
        colors (list): List of colors for each route
        output_file (str): Output filename

    Side effects:
        Saves the figure to ``output_file`` and shows it interactively.
    """
    # Create figure and axis
    fig, ax = plt.subplots(figsize=(8, 6))
    plt.rcParams['font.family'] = 'Times New Roman'

    # Plot the base map (requires osmnx; intentionally disabled)
    # ox.plot_graph(G, ax=ax, show=False, close=False,
    #               edge_color='gray', edge_alpha=0.2, node_size=0)

    # Create empty list to store route lines for legend
    route_lines = []

    # Plot each route
    for route, color in zip(routes, colors):
        # Extract the coordinates for each node in the route
        xs = []
        ys = []
        for node in route:
            # Get node coordinates (assumes nodes carry 'x'/'y' attributes,
            # as in osmnx graphs — TODO confirm against callers)
            x = G.nodes[node]['x']
            y = G.nodes[node]['y']
            xs.append(x)
            ys.append(y)

        # Plot the route (low alpha so overlapping periods remain visible)
        line = ax.plot(xs, ys, marker='D', color=color, linewidth=2, alpha=0.2, zorder=2)[0]
        route_lines.append(line)

    # Add legend, anchored outside the axes on the right
    ax.legend(route_lines,
              [f'Period {i+1}' for i in range(len(routes))],
              loc='upper right', prop={'size': 14},
              bbox_to_anchor=(1.25, 0.85))

    # Adjust layout and save
    plt.tight_layout()


    plt.savefig(output_file, dpi=100, bbox_inches='tight')
    plt.show()
    plt.close()
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
#################plot for 2D trajectories############################
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def plot_task_2(obs_len, gt_seq_len, pred_seq_len, figsize_w=10, title=None):
    """
    Plot both GT and Pred timelines in the same figure with aligned scales.

    Each row shows a gray "observed" segment of length ``obs_len`` followed
    by a colored segment (green for GT, blue for Pred) up to the sequence end.

    Args:
        obs_len: Length of observation period
        gt_seq_len: Total sequence length for ground truth
        pred_seq_len: Total sequence length for prediction
        figsize_w: Width of the figure
        title: Optional title for the figure

    Returns:
        matplotlib.figure.Figure: The assembled two-row figure.
    """
    # Use the maximum sequence length to determine the x-axis limits
    max_seq_len = max(gt_seq_len, pred_seq_len)

    # Create figure with two subplots, one for GT and one for Pred
    fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})
    plt.rcParams['font.family'] = 'Times New Roman'

    # Add title if provided
    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')

    # Create consistent bar heights and label offset
    bar_height = 0.5
    label_offset = max_seq_len * 0.15  # Proportional offset based on sequence length

    # GT plot (top).  The +1 extends the colored bar so it visually reaches
    # the inclusive sequence end.
    y_position = 0
    axes[0].barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
    axes[0].barh(y_position, gt_seq_len+1-obs_len, height=bar_height, left=obs_len, color='lightgreen')
    axes[0].text(-label_offset, y_position, "GT:", fontsize=12, fontweight='bold', verticalalignment='center')

    # Pred plot (bottom), same layout with a blue predicted segment
    axes[1].barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
    axes[1].barh(y_position, pred_seq_len+1-obs_len, height=bar_height, left=obs_len, color='lightblue')
    axes[1].text(-label_offset, y_position, "Pred:", fontsize=12, fontweight='bold', verticalalignment='center')

    # Configure both axes consistently
    for ax in axes:
        # Set consistent x-limits for alignment
        ax.set_xlim(-label_offset, max_seq_len+1)
        ax.set_ylim(-0.5, 0.5)
        ax.set_yticks([])

        # Remove the box/frame
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Add a thin line below the bar for better visibility
        ax.axhline(y_position - bar_height/2, color='black', linewidth=0.5)

    # Set tick marks for each plot
    axes[0].set_xticks([0, obs_len, gt_seq_len])
    axes[1].set_xticks([0, obs_len, pred_seq_len])

    # Leave headroom for the suptitle when present
    plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])
    return fig
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def plot_task_3(gt_seq_len, GT_start, GT_end, pred_start, pred_end, figsize_w=10, title=None):
    """
    Plot both GT and Pred timelines in the same figure with aligned scales.

    Both rows span the full ``gt_seq_len`` in gray; a highlighted interval
    (green for GT, blue for Pred) marks the respective [start, end) span.

    Args:
        gt_seq_len: Total sequence length for ground truth
        GT_start: Start position of GT highlight bar
        GT_end: End position of GT highlight bar
        pred_start: Start position of prediction highlight bar
        pred_end: End position of prediction highlight bar
        figsize_w: Width of the figure
        title: Optional title for the figure

    Returns:
        matplotlib.figure.Figure: The assembled two-row figure.
    """
    # Use the maximum sequence length to determine the x-axis limits
    max_seq_len = gt_seq_len

    # Create figure with two subplots, one for GT and one for Pred
    fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})
    plt.rcParams['font.family'] = 'Times New Roman'

    # Add title if provided
    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')

    # Create consistent bar heights and label offset
    bar_height = 0.5
    label_offset = max_seq_len * 0.15  # Proportional offset based on sequence length

    # GT plot (top)
    y_position = 0
    # Plot full lightgray bar for GT
    axes[0].barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
    # Plot lightgreen bar within GT from GT_start to GT_end
    axes[0].barh(y_position, GT_end - GT_start, height=bar_height, left=GT_start, color='lightgreen')
    axes[0].text(-label_offset, y_position, "GT:", fontsize=12, fontweight='bold', verticalalignment='center')

    # Pred plot (bottom)
    # Plot full lightgray bar for Pred
    axes[1].barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
    # Plot lightblue bar within Pred from pred_start to pred_end
    axes[1].barh(y_position, pred_end - pred_start, height=bar_height, left=pred_start, color='lightblue')
    axes[1].text(-label_offset, y_position, "Pred:", fontsize=12, fontweight='bold', verticalalignment='center')

    # Configure both axes consistently
    for ax in axes:
        # Set consistent x-limits for alignment
        ax.set_xlim(-label_offset, max_seq_len+1)
        ax.set_ylim(-0.5, 0.5)
        ax.set_yticks([])

        # Remove the box/frame
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Add a thin line below the bar for better visibility
        ax.axhline(y_position - bar_height/2, color='black', linewidth=0.5)

    # Set tick marks for each plot
    axes[0].set_xticks([0, GT_start, GT_end, gt_seq_len])
    axes[1].set_xticks([0, pred_start, pred_end, gt_seq_len])

    # Leave headroom for the suptitle when present
    plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])
    return fig
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
import string
|
| 492 |
+
def plot_images_with_token(images, tokens, n_rows=2):
    """
    Display images in a grid, each titled with its (colored) token.

    Args:
        images (list): Images accepted by ``plt.imshow``.
        tokens (list): One token per image; the title color is keyed on the
            lowercased token ('a'-'j' are the supported keys).
        n_rows (int): Number of grid rows.

    Side effects:
        Saves the figure to 'anchors.jpg' in the working directory.
    """
    assert len(images) == len(tokens), "Each image must have a corresponding token"

    n_images = len(images)
    # Ceiling division so every image gets a cell for any n_rows.
    # (BUG FIX: the original used (n_images + 1) // n_rows, which is only a
    # correct ceiling when n_rows == 2.)
    n_cols = (n_images + n_rows - 1) // n_rows

    # Create a figure to display the images
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    plt.rcParams['font.family'] = 'Times New Roman'
    unique_chars = sorted(set(string.ascii_lowercase))[:10]
    hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
                 for char, hue in zip(unique_chars, hues)}

    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid;
    # np.atleast_1d covers that case and flatten handles 1D/2D uniformly.
    axes_flat = np.atleast_1d(axes).flatten()

    last_used = -1  # Last populated cell; stays -1 when `images` is empty
    for i, (image, token) in enumerate(zip(images, tokens)):
        if i < len(axes_flat):
            color = color_map[token.lower()]
            axes_flat[i].imshow(image)
            axes_flat[i].set_title(token, color=color, size=50)
            axes_flat[i].axis('off')  # Hide axes
            last_used = i

    # Hide any unused subplots.  (BUG FIX: the original referenced the loop
    # variable `i` here, which is undefined when `images` is empty.)
    for j in range(last_used + 1, n_rows * n_cols):
        if j < len(axes_flat):
            axes_flat[j].axis('off')
            fig.delaxes(axes_flat[j])

    plt.tight_layout()
    plt.savefig('anchors.jpg', bbox_inches='tight', pad_inches=0)
    #plt.show()
|
| 531 |
+
|
| 532 |
+
|
utils/render.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import matplotlib
|
| 3 |
+
matplotlib.use('Agg') # Non-interactive backend
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.animation as animation
|
| 6 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sklearn.decomposition import PCA
|
| 9 |
+
from scipy.spatial.transform import Rotation as R
|
| 10 |
+
|
| 11 |
+
def render_smpl(pose_data, output_path, fps=30):
|
| 12 |
+
"""
|
| 13 |
+
Render SMPL 3D pose data to a video file.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
pose_data (np.ndarray): Shape (Frames, 24, 3)
|
| 17 |
+
output_path (str): Path to save the MP4 video.
|
| 18 |
+
fps (int): Frames per second.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# SMPL kinematic tree (approximate for visualization)
|
| 22 |
+
# 0: Pelvis
|
| 23 |
+
# 1: L_Hip, 2: R_Hip, 3: Spine1
|
| 24 |
+
# 4: L_Knee, 5: R_Knee, 6: Spine2
|
| 25 |
+
# 7: L_Ankle, 8: R_Ankle, 9: Spine3
|
| 26 |
+
# 10: L_Foot, 11: R_Foot, 12: Neck
|
| 27 |
+
# 13: L_Collar, 14: R_Collar, 15: Head
|
| 28 |
+
# 16: L_Shoulder, 17: R_Shoulder
|
| 29 |
+
# 18: L_Elbow, 19: R_Elbow
|
| 30 |
+
# 20: L_Wrist, 21: R_Wrist
|
| 31 |
+
# 22: L_Hand, 23: R_Hand
|
| 32 |
+
|
| 33 |
+
# Connectivity for drawing bones
|
| 34 |
+
connections = [
|
| 35 |
+
(0, 1), (0, 2), (0, 3),
|
| 36 |
+
(1, 4), (2, 5), (3, 6),
|
| 37 |
+
(4, 7), (5, 8), (6, 9),
|
| 38 |
+
(7, 10), (8, 11), (9, 12),
|
| 39 |
+
(9, 13), (9, 14), (12, 15),
|
| 40 |
+
(13, 16), (14, 17),
|
| 41 |
+
(16, 18), (17, 19),
|
| 42 |
+
(18, 20), (19, 21),
|
| 43 |
+
(20, 22), (21, 23)
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
fig = plt.figure(figsize=(10, 10))
|
| 47 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 48 |
+
|
| 49 |
+
# --- Alignment & Centering ---
|
| 50 |
+
# 1. Fit plane to feet to find ground orientation
|
| 51 |
+
feet_indices = [10, 11] # L_Foot, R_Foot
|
| 52 |
+
feet_points = pose_data[:, feet_indices, :].reshape(-1, 3)
|
| 53 |
+
|
| 54 |
+
pca = PCA(n_components=3)
|
| 55 |
+
pca.fit(feet_points)
|
| 56 |
+
normal = pca.components_[2] # Component with least variance is the normal
|
| 57 |
+
|
| 58 |
+
# Calculate Body Up vector (Pelvis to Head) to determine correct up direction
|
| 59 |
+
# Pelvis is 0, Head is 15
|
| 60 |
+
pelvis_head_vector = pose_data[:, 15, :] - pose_data[:, 0, :]
|
| 61 |
+
avg_body_up = np.mean(pelvis_head_vector, axis=0)
|
| 62 |
+
|
| 63 |
+
# Ensure normal points in same direction as body up
|
| 64 |
+
if np.dot(normal, avg_body_up) < 0:
|
| 65 |
+
normal = -normal
|
| 66 |
+
|
| 67 |
+
# 2. Compute rotation to align normal to Z-axis [0, 0, 1]
|
| 68 |
+
target_normal = np.array([0, 0, 1])
|
| 69 |
+
|
| 70 |
+
# Use scipy to find rotation
|
| 71 |
+
# We want R such that R * normal = target_normal
|
| 72 |
+
# align_vectors finds rotation that maps vectors_b to vectors_a.
|
| 73 |
+
# So we map normal (b) to target (a).
|
| 74 |
+
rot, rssd = R.align_vectors([target_normal], [normal])
|
| 75 |
+
rot_matrix = rot.as_matrix()
|
| 76 |
+
|
| 77 |
+
# Apply rotation to all points
|
| 78 |
+
# Points are (Frames, Joints, 3). Flatten for transform
|
| 79 |
+
original_shape = pose_data.shape
|
| 80 |
+
flat_data = pose_data.reshape(-1, 3)
|
| 81 |
+
# Apply rotation: (R @ v.T).T = v @ R.T
|
| 82 |
+
# Scipy apply: rot.apply(vectors) handles the broadcasting
|
| 83 |
+
pose_data_rotated = rot.apply(flat_data)
|
| 84 |
+
pose_data = pose_data_rotated.reshape(original_shape)
|
| 85 |
+
|
| 86 |
+
# 3. Center trajectory
|
| 87 |
+
# Center X/Y at 0
|
| 88 |
+
all_x = pose_data[:, :, 0]
|
| 89 |
+
all_y = pose_data[:, :, 1]
|
| 90 |
+
all_z = pose_data[:, :, 2]
|
| 91 |
+
|
| 92 |
+
# Mean of all points as center (or could use root joint mean)
|
| 93 |
+
center_x = np.mean(all_x)
|
| 94 |
+
center_y = np.mean(all_y)
|
| 95 |
+
|
| 96 |
+
pose_data[:, :, 0] -= center_x
|
| 97 |
+
pose_data[:, :, 1] -= center_y
|
| 98 |
+
|
| 99 |
+
# Shift Z so min is 0 (Ground level)
|
| 100 |
+
min_z = np.min(all_z)
|
| 101 |
+
pose_data[:, :, 2] -= min_z
|
| 102 |
+
|
| 103 |
+
# Update bounds variables for plotting
|
| 104 |
+
all_x = pose_data[:, :, 0]
|
| 105 |
+
all_y = pose_data[:, :, 1]
|
| 106 |
+
all_z = pose_data[:, :, 2]
|
| 107 |
+
|
| 108 |
+
mid_x = (np.min(all_x) + np.max(all_x)) / 2
|
| 109 |
+
mid_y = (np.min(all_y) + np.max(all_y)) / 2
|
| 110 |
+
mid_z = (np.min(all_z) + np.max(all_z)) / 2
|
| 111 |
+
|
| 112 |
+
max_range = np.array([np.ptp(all_x), np.ptp(all_y), np.ptp(all_z)]).max() / 2.0
|
| 113 |
+
|
| 114 |
+
# Recalculate bounds after shift
|
| 115 |
+
all_x = pose_data[:, :, 0]
|
| 116 |
+
all_y = pose_data[:, :, 1]
|
| 117 |
+
all_z = pose_data[:, :, 2]
|
| 118 |
+
|
| 119 |
+
# Use (min+max)/2 for center to ensure bounding box is centered
|
| 120 |
+
mid_x = (np.min(all_x) + np.max(all_x)) / 2
|
| 121 |
+
mid_y = (np.min(all_y) + np.max(all_y)) / 2
|
| 122 |
+
mid_z = (np.min(all_z) + np.max(all_z)) / 2
|
| 123 |
+
|
| 124 |
+
# Dynamic ground plane bounds covering all trajectory
|
| 125 |
+
padding = 1.0 # Increase padding
|
| 126 |
+
gp_min_x = np.min(all_x) - padding
|
| 127 |
+
gp_max_x = np.max(all_x) + padding
|
| 128 |
+
gp_min_y = np.min(all_y) - padding
|
| 129 |
+
gp_max_y = np.max(all_y) + padding
|
| 130 |
+
|
| 131 |
+
def update(frame):
    """Draw one animation frame: ground plane, joint scatter, bone segments.

    Reads the enclosing scope's pose_data, connections, ground-plane and
    axis-bound variables; mutates the shared 3D axes in place.
    """
    ax.clear()
    ax.set_axis_off()

    # Transparent gray ground plane at z=0
    gx = np.linspace(gp_min_x, gp_max_x, 2)
    gy = np.linspace(gp_min_y, gp_max_y, 2)
    plane_x, plane_y = np.meshgrid(gx, gy)
    plane_z = np.zeros_like(plane_x)  # ground sits at z=0
    ax.plot_surface(plane_x, plane_y, plane_z, color='gray', alpha=0.2, shade=False)

    joints = pose_data[frame]

    # Joints as scatter points
    ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], c='blue', s=20)

    # Bones as line segments between connected joint indices
    for a, b in connections:
        seg_x = [joints[a, 0], joints[b, 0]]
        seg_y = [joints[a, 1], joints[b, 1]]
        seg_z = [joints[a, 2], joints[b, 2]]
        ax.plot(seg_x, seg_y, seg_z, c='red')

    # Cubic axis limits centered on the trajectory bounding box
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    ax.set_title(f"Frame {frame}")
|
| 166 |
+
|
| 167 |
+
ani = animation.FuncAnimation(fig, update, frames=len(pose_data), interval=1000/fps)
|
| 168 |
+
|
| 169 |
+
# Save using ffmpeg writer
|
| 170 |
+
print(f"Saving video to {output_path}...")
|
| 171 |
+
try:
|
| 172 |
+
if animation.writers.is_available('ffmpeg'):
|
| 173 |
+
writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
|
| 174 |
+
ani.save(output_path, writer=writer)
|
| 175 |
+
else:
|
| 176 |
+
raise RuntimeError("ffmpeg not available")
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f"ffmpeg failed or not found ({e}). Using OpenCV fallback...")
|
| 179 |
+
try:
|
| 180 |
+
import cv2
|
| 181 |
+
plt.close(fig) # Close the animation fig
|
| 182 |
+
|
| 183 |
+
# Re-setup figure for opencv loop
|
| 184 |
+
fig = plt.figure(figsize=(10, 10))
|
| 185 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 186 |
+
|
| 187 |
+
# Figure size in pixels approx (10*100 = 1000x1000 usually dpi=100)
|
| 188 |
+
fig.canvas.draw()
|
| 189 |
+
width, height = fig.canvas.get_width_height()
|
| 190 |
+
|
| 191 |
+
# Setup video writer - Try H.264 (avc1) first
|
| 192 |
+
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
| 193 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 194 |
+
|
| 195 |
+
if not out.isOpened():
|
| 196 |
+
print("avc1 failed. Trying h264...")
|
| 197 |
+
fourcc = cv2.VideoWriter_fourcc(*'h264')
|
| 198 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 199 |
+
|
| 200 |
+
if not out.isOpened():
|
| 201 |
+
print("h264 failed. Trying vp80...")
|
| 202 |
+
fourcc = cv2.VideoWriter_fourcc(*'vp80')
|
| 203 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 204 |
+
|
| 205 |
+
if not out.isOpened():
|
| 206 |
+
print("vp80 failed. Trying mp4v (less compatible)...")
|
| 207 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 208 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 209 |
+
|
| 210 |
+
if not out.isOpened():
|
| 211 |
+
raise RuntimeError("Failed to open VideoWriter with any compatible codec.")
|
| 212 |
+
|
| 213 |
+
print("Rendering frames directly to OpenCV VideoWriter...")
|
| 214 |
+
for frame in range(len(pose_data)):
|
| 215 |
+
update(frame)
|
| 216 |
+
fig.canvas.draw()
|
| 217 |
+
|
| 218 |
+
# Convert canvas to image
|
| 219 |
+
# Check for buffer_rgba support (matplotlib 3.x)
|
| 220 |
+
try:
|
| 221 |
+
img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
| 222 |
+
img = img.reshape(height, width, 4)[:, :, :3] # RGBA -> RGB
|
| 223 |
+
except AttributeError:
|
| 224 |
+
# Fallback for older matplotlib or different backend
|
| 225 |
+
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 226 |
+
img = img.reshape(height, width, 3)
|
| 227 |
+
|
| 228 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 229 |
+
|
| 230 |
+
out.write(img)
|
| 231 |
+
|
| 232 |
+
out.release()
|
| 233 |
+
plt.close(fig)
|
| 234 |
+
print("OpenCV fallback rendering complete.")
|
| 235 |
+
|
| 236 |
+
except Exception as cv_e:
|
| 237 |
+
print(f"OpenCV fallback also failed: {cv_e}")
|
| 238 |
+
raise cv_e
|
| 239 |
+
|
| 240 |
+
return output_path
|
| 241 |
+
|
| 242 |
+
|