fanduluhf commited on
Commit
f460dc5
·
verified ·
1 Parent(s): 6936c46

Upload 4 files

Browse files
Files changed (4) hide show
  1. utils/eval.py +180 -0
  2. utils/periodic_detection_helper.py +641 -0
  3. utils/plot.py +532 -0
  4. utils/render.py +242 -0
utils/eval.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+
4
+ def temporal_iou(pred_span, gt_span):
5
+ """
6
+ Calculate 1D Intersection over Union (IoU) between two temporal spans.
7
+ Args:
8
+ pred_span (tuple/list): Predicted temporal span (start, end)
9
+ gt_span (tuple/list): Ground truth temporal span (start, end)
10
+ Returns:
11
+ float: IoU score between 0 and 1
12
+ """
13
+ pred_start, pred_end = pred_span
14
+ gt_start, gt_end = gt_span
15
+
16
+ # Ensure valid spans
17
+ if pred_end < pred_start or gt_end < gt_start:
18
+ raise ValueError("End time cannot be before start time")
19
+
20
+ # Calculate intersection
21
+ intersection_start = max(pred_start, gt_start)
22
+ intersection_end = min(pred_end, gt_end)
23
+
24
+ if intersection_end <= intersection_start:
25
+ return 0.0
26
+
27
+ intersection = intersection_end - intersection_start
28
+
29
+ # Calculate union
30
+ pred_duration = pred_end - pred_start
31
+ gt_duration = gt_end - gt_start
32
+ union = pred_duration + gt_duration - intersection
33
+
34
+ # Calculate IoU
35
+ iou = intersection / union
36
+
37
+ return float(iou)
38
+
39
+ def match_temporal_iou(preds, gts):
40
+ """
41
+ Find optimal matching between predicted and ground truth temporal spans using Hungarian algorithm.
42
+
43
+ Args:
44
+ preds (list): List of predicted temporal spans, each span is [start, end]
45
+ gts (list): List of ground truth temporal spans, each span is [start, end]
46
+
47
+ Returns:
48
+ tuple: (matched_indices, total_iou)
49
+ - matched_indices: List of (pred_idx, gt_idx) pairs
50
+ - total_iou: Sum of IoUs for the matched pairs
51
+ """
52
+ if not preds or not gts:
53
+ return [], 0.0
54
+
55
+ # Calculate cost matrix (negative IoU since Hungarian algorithm minimizes cost)
56
+ cost_matrix = np.zeros((len(preds), len(gts)))
57
+ for i, pred in enumerate(preds):
58
+ for j, gt in enumerate(gts):
59
+ cost_matrix[i, j] = -temporal_iou(pred, gt) # Negative since we want to maximize IoU
60
+
61
+ # Apply Hungarian algorithm
62
+ pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
63
+
64
+ # Get matched pairs and total IoU
65
+ matched_pairs = list(zip(pred_indices, gt_indices))
66
+ total_iou = -cost_matrix[pred_indices, gt_indices].sum() # Convert back to positive
67
+ avg_iou = total_iou / len(gts)
68
+
69
+ return matched_pairs, avg_iou
70
+ '''
71
+ # Example usage:
72
+ if __name__ == "__main__":
73
+ # Example predictions and ground truths
74
+ predictions = [[10, 20], [25, 35], [40, 50], [50, 55]]
75
+ ground_truths = [[15, 25], [30, 40], [45, 55]]
76
+
77
+ # Find optimal matching
78
+ matches, avg_iou = match_temporal_iou(predictions, ground_truths)
79
+
80
+ print("Matched pairs (pred_idx, gt_idx):", matches)
81
+ print("Avg IoU:", avg_iou)
82
+
83
+ # Print individual IoUs for matched pairs
84
+ print("\nIndividual IoUs:")
85
+ for pred_idx, gt_idx in matches:
86
+ iou = temporal_iou(predictions[pred_idx], ground_truths[gt_idx])
87
+ print(f"Pred {pred_idx} - GT {gt_idx}: {iou:.3f}")
88
+ '''
89
+
90
+
91
+ def find_difference_range(s1, s2):
92
+ # Ignore first and last chars by slicing [1:-1]
93
+ s1_mid = s1[1:-1]
94
+ s2_mid = s2[1:-1]
95
+
96
+ n = len(s1_mid)
97
+ if n != len(s2_mid):
98
+ return None # Strings of different lengths
99
+
100
+ # Find start of difference
101
+ start = 0
102
+ while start < n and s1_mid[start] == s2_mid[start]:
103
+ start += 1
104
+
105
+ # Find end of difference (going backwards)
106
+ end = n - 1
107
+ while end >= start and s1_mid[end] == s2_mid[end]:
108
+ end -= 1
109
+
110
+ # Adjust indices to account for ignored first character
111
+ return [start + 1, end + 1] if start <= end else None
112
+
113
+ '''
114
+ # Test with your example
115
+ s1 = "GIBJBIGCHEHCGIBFAD-"
116
+ s2 = "GIBJBIGCHED----FADG"
117
+ result = find_difference_range(s1, s2)
118
+ print(f"Different substrings: '{s1[result[0]:result[1]+1]}' and '{s2[result[0]:result[1]+1]}'")
119
+ '''
120
+
121
+
122
+
123
+ def get_overlapping_substring(s1, s2, best_offset, max_matches):
124
+ len1 = len(s1)
125
+ len2 = len(s2)
126
+ start_index_s1 = -1
127
+ start_index_s2 = -1
128
+
129
+ for i in range(len1):
130
+ j = i - best_offset
131
+ if 0 <= j < len2 and s1[i] == s2[j]:
132
+ start_index_s1 = i
133
+ start_index_s2 = j
134
+ break # Find the first index of match
135
+
136
+ if start_index_s1 != -1:
137
+ return s1[start_index_s1 : start_index_s1 + max_matches]
138
+ else:
139
+ return ""
140
+
141
+ '''
142
+ string1 = 'JBHKHBJGCEID'
143
+ string2 = 'BJGCEIDIALFKCGJ'
144
+ best_offset, max_matches = align_strings(string1, string2)
145
+ overlapping_part = get_overlapping_substring(string1, string2, best_offset, max_matches)
146
+ print("String 1:", string1)
147
+ print("String 2:", string2)
148
+ print("\nOverlapping part:", overlapping_part)
149
+ '''
150
+
151
+
152
+ def find_difference_range(s1, s2):
153
+ # Ignore first and last chars by slicing [1:-1]
154
+ s1_mid = s1[1:-1]
155
+ s2_mid = s2[1:-1]
156
+
157
+ n = len(s1_mid)
158
+ if n != len(s2_mid):
159
+ return None # Strings of different lengths
160
+
161
+ # Find start of difference
162
+ start = 0
163
+ while start < n and s1_mid[start] == s2_mid[start]:
164
+ start += 1
165
+
166
+ # Find end of difference (going backwards)
167
+ end = n - 1
168
+ while end >= start and s1_mid[end] == s2_mid[end]:
169
+ end -= 1
170
+
171
+ # Adjust indices to account for ignored first character
172
+ return [start + 1, end + 1] if start <= end else None
173
+ '''
174
+ # Test with your example
175
+ s1 = "GIBJBIGCHEHCGIBFAD-"
176
+ s2 = "GIBJBIGCHED----FADG"
177
+ result = find_difference_range(s1, s2)
178
+ print(f"Different substrings: '{s1[result[0]:result[1]+1]}' and '{s2[result[0]:result[1]+1]}'")
179
+ print(f"Range index: {result}")
180
+ '''
utils/periodic_detection_helper.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ from itertools import product
4
+
5
+ from tqdm import tqdm
6
+ from sklearn.cluster import KMeans, MeanShift
7
+ from sklearn.preprocessing import StandardScaler
8
+ from typing import List, Tuple
9
+ import matplotlib.pyplot as plt
10
+ from mpl_toolkits.mplot3d import Axes3D
11
+ from sklearn.cluster import KMeans
12
+
13
+ import copy
14
+
15
+
16
+
17
+ def smooth(period_labels, gap = 1):
18
+ period_labels_copy = copy.deepcopy(period_labels)
19
+ for i in range(gap,len(period_labels)-gap):
20
+ counts = np.bincount(period_labels[i-gap:i+gap])
21
+ value = np.argmax(counts)
22
+ period_labels_copy[i] = value
23
+ return period_labels_copy
24
+
25
+ def spatiotemporal_clustering(spatiotemporal_data: np.ndarray, n_clusters: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
26
+ """
27
+ Clusters a 3D spatial trajectory (ignoring timestamps) using DBSCAN and tokenizes it.
28
+
29
+ Args:
30
+ spatiotemporal_data: An array of [frame, n_feats].
31
+ Returns:
32
+ A tuple containing:
33
+ - cluster_labels: A numpy array of cluster labels.
34
+ - hard_tokenized_trajectory: A numpy array representing the hard-encoded tokenized trajectory (cluster labels)
35
+ - soft_tokenized_trajectory: A numpy array representing the soft-encoded tokenized trajectory (vector of normalized distance to all centroids)
36
+ """
37
+
38
+ kmeans = KMeans(n_clusters=n_clusters, random_state=20, n_init='auto')
39
+ cluster_labels = kmeans.fit_predict(spatiotemporal_data)
40
+
41
+ cluster_labels = smooth(cluster_labels, gap = 1)
42
+ # Hard-encoded tokenization for the trajectory using cluster labels.
43
+ hard_tokenized_trajectory = cluster_labels
44
+
45
+ # Get cluster centroids
46
+ centroids = kmeans.cluster_centers_
47
+ n_clusters = len(centroids)
48
+ n_points = len(spatiotemporal_data)
49
+
50
+ # Initialize array for soft tokenization
51
+ soft_tokenized_trajectory = np.zeros((n_points, n_clusters))
52
+
53
+ # Compute Euclidean distances to all centroids for each point
54
+ for i in tqdm(range(n_points)):
55
+ point = spatiotemporal_data[i]
56
+ distances = np.array([np.linalg.norm(point - centroid) for centroid in centroids])
57
+ #'''
58
+ # Convert distances to similarities using exponential decay
59
+ similarities = np.exp(-distances)
60
+
61
+ # Normalize similarities to sum to 1
62
+ soft_tokenized_trajectory[i] = similarities / np.sum(similarities)
63
+
64
+ return cluster_labels, hard_tokenized_trajectory.T, soft_tokenized_trajectory.T, centroids
65
+
66
+
67
+
68
+ def create_path(nodes):
69
+ """
70
+ Create a string representing a path from a list of node sets.
71
+
72
+ Args:
73
+ nodes (list): List of lists of node IDs. Each list of nodes is connected by
74
+ an edge.
75
+
76
+ Returns:
77
+ str: String representing the path.
78
+ """
79
+ result = []
80
+ # Initial edge
81
+ result.append(f"{nodes[0][0]}->{nodes[1][0]}")
82
+
83
+ current_idx = 1
84
+ # Loop until all edges are processed
85
+ while current_idx < len(nodes) - 1:
86
+ sources = nodes[current_idx]
87
+ targets = nodes[current_idx + 1]
88
+
89
+ if len(sources) == 1 and len(targets) == 1:
90
+ # One source, one target
91
+ result.append(f"{sources[0]}->{targets[0]}")
92
+ elif len(sources) == 1:
93
+ # One source, multiple targets
94
+ paths = [f"{sources[0]}->{target}" for target in targets]
95
+ result.append(f"({', '.join(paths)})")
96
+ elif len(targets) == 1:
97
+ # Multiple sources, one target
98
+ paths = [f"{source}->{targets[0]}" for source in sources]
99
+ result.append(f"({', '.join(paths)})")
100
+ else:
101
+ # Multiple sources, multiple targets
102
+ paths = []
103
+ for i in range(len(sources)):
104
+ paths.append(f"{sources[i]}->{targets[i]}")
105
+ result.append(f"({', '.join(paths)})")
106
+
107
+ current_idx += 1
108
+
109
+ return ', '.join(result)
110
+
111
+ def summarize_strings(strings):
112
+ """
113
+ Summarize a list of strings by comparing characters at each position.
114
+
115
+ If all strings have the same character at a position, that character is
116
+ included in the result. If not, an underscore is included.
117
+
118
+ Args:
119
+ strings (list): List of strings to summarize
120
+
121
+ Returns:
122
+ str: Summary of the strings
123
+ """
124
+ if not strings:
125
+ return ""
126
+
127
+ # Get length of shortest string
128
+ min_len = min(len(s) for s in strings)
129
+
130
+ # Compare characters at each position
131
+ result = []
132
+ for i in range(min_len):
133
+ chars = set(s[i] for s in strings)
134
+ # If all strings have the same character at this position, use that
135
+ # character. Otherwise, use an underscore.
136
+ result.append("_" if len(chars) > 1 else strings[0][i])
137
+
138
+ return "".join(result)
139
+
140
+ def find_dash_end_index(strings):
141
+ """
142
+ Find the index of the last dash in the strings that is
143
+ immediately preceded by a letter.
144
+
145
+ Args:
146
+ strings (list): List of strings with same length
147
+
148
+ Returns:
149
+ int: Index of the last dash (if found) or -1
150
+ """
151
+ # Ensure all strings have same length
152
+ if not all(len(s) == len(strings[0]) for s in strings):
153
+ raise ValueError("Strings must be of equal length")
154
+
155
+ # Iterate from the right
156
+ for i in range(len(strings[0])-1, -1, -1):
157
+ for s in strings:
158
+ if s[i] == '-':
159
+ # Check if previous char is letter
160
+ if i > 0 and s[i-1].isalpha():
161
+ return i
162
+ elif not s[i].isalpha(): # Skip if not dash or letter
163
+ continue
164
+
165
+ return -1 # No matching pattern found
166
+
167
+
168
+
169
+ def find_longest_repeated_ends(strings):
170
+ """
171
+ Find the longest prefix and suffix that are identical across all strings.
172
+
173
+ Args:
174
+ strings (list): List of strings to check.
175
+
176
+ Returns:
177
+ int: Length of the longest common prefix and suffix.
178
+ """
179
+ if not strings:
180
+ return 0
181
+
182
+ # Use the first string as a reference
183
+ s = strings[0]
184
+ n = len(s)
185
+ max_len = 0
186
+
187
+ # Iterate over possible prefix/suffix lengths
188
+ for i in range(1, n // 2 + 1):
189
+ prefix = s[:i]
190
+ suffix = s[-i:]
191
+
192
+ # Check if prefix equals suffix and appears in all strings
193
+ if prefix == suffix and all(st.startswith(prefix) and st.endswith(suffix) for st in strings):
194
+ max_len = i
195
+
196
+ return max_len
197
+
198
+
199
+ def create_path(nodes):
200
+ result = []
201
+ result.append(f"{nodes[0][0]}->{nodes[1][0]}")
202
+
203
+ current_idx = 0
204
+ while current_idx < len(nodes) - 1:
205
+ sources = nodes[current_idx]
206
+ targets = nodes[current_idx + 1]
207
+
208
+ if len(sources) == 1 and len(targets) == 1:
209
+ result.append(f"{sources[0]}->{targets[0]}")
210
+ elif len(sources) == 1:
211
+ paths = [f"{sources[0]}->{target}" for target in targets]
212
+ result.append(f"({', '.join(paths)})")
213
+ elif len(targets) == 1:
214
+ paths = [f"{source}->{targets[0]}" for source in sources]
215
+ result.append(f"({', '.join(paths)})")
216
+ else:
217
+ paths = []
218
+ for i in range(len(sources)):
219
+ paths.append(f"{sources[i]}->{targets[i]}")
220
+ result.append(f"({', '.join(paths)})")
221
+
222
+ current_idx += 1
223
+
224
+ return ', '.join(result)
225
+
226
+
227
+ def dominant_fourier_frequency_2d(matrix, lbound=10, ubound=1000):
228
+ """
229
+ Find the dominant Fourier frequencies of a 2D matrix within a window size range.
230
+
231
+ Parameters
232
+ ----------
233
+ matrix : array-like
234
+ The input 2D matrix
235
+ lbound : int, optional
236
+ The lower bound of the window size range. Default is 10.
237
+ ubound : int, optional
238
+ The upper bound of the window size range. Default is 1000.
239
+
240
+ Returns
241
+ -------L
242
+ tuple
243
+ period_condidates
244
+ period_condidates_magnitudes
245
+ """
246
+ # Compute 2D FFT
247
+ fourier = np.fft.fft2(matrix)
248
+
249
+ # Get frequency components for temporal dimensions
250
+ freq_x = np.fft.fftfreq(matrix.shape[1], 1)
251
+
252
+ magnitudes_x = []
253
+ window_sizes_x = []
254
+
255
+ # Analyze horizontal frequencies (x-axis)
256
+ for j, freq in enumerate(freq_x):
257
+ if freq > 0: # Only consider positive frequencies
258
+ window_size = int(1 / freq)
259
+ if window_size >= lbound and window_size < ubound:
260
+ # Sum magnitudes across columns for this frequency
261
+ mag = 0
262
+ for i in range(matrix.shape[0]):
263
+ coef = fourier[i, j]
264
+ mag += math.sqrt(coef.real * coef.real + coef.imag * coef.imag)
265
+ window_sizes_x.append(window_size)
266
+ magnitudes_x.append(mag)
267
+
268
+ '''
269
+ # Handle cases where no valid frequencies are found
270
+ if len(magnitudes_x) == 0:
271
+ warnings.warn(f"Could not extract valid horizontal frequencies. Using window_size={lbound}.")
272
+ period_x = lbound
273
+ else:
274
+ period_x = window_sizes_x[np.argmax(magnitudes_x)]
275
+ '''
276
+
277
+ return np.array(window_sizes_x)[np.argsort(magnitudes_x)[::-1]], np.sort(magnitudes_x)[::-1]
278
+
279
+
280
+
281
+ def dominant_fourier_frequency_1d(time_series, lbound=10, ubound=1000):
282
+ """
283
+ Find the dominant Fourier frequency of the time series within a window size range.
284
+
285
+ Parameters
286
+ ----------
287
+ time_series : array-like
288
+ The input time series.
289
+ lbound : int, optional
290
+ The lower bound of the window size range. Default is 10.
291
+ ubound : int, optional
292
+ The upper bound of the window size range. Default is 1000.
293
+
294
+ Returns
295
+ -------
296
+ The dominant Fourier frequency's corresponding window size within the specified range.
297
+ period_condidates
298
+ period_condidates_magnitudes
299
+ """
300
+
301
+ if time_series.shape[0] < 2 * lbound:
302
+ warnings.warn(
303
+ f"Time series must at least have 2*lbound much data points. Using window_size={time_series.shape[0]}.")
304
+ return time_series.shape[0]
305
+
306
+ fourier = np.fft.fft(time_series)
307
+ freq = np.fft.fftfreq(time_series.shape[0], 1)
308
+
309
+ magnitudes = []
310
+ window_sizes = []
311
+
312
+ for coef, freq in zip(fourier, freq):
313
+ if coef and freq > 0:
314
+ window_size = int(1 / freq)
315
+ mag = math.sqrt(coef.real * coef.real + coef.imag * coef.imag)
316
+
317
+ if window_size >= lbound and window_size < ubound:
318
+ window_sizes.append(window_size)
319
+ magnitudes.append(mag)
320
+
321
+ if len(magnitudes) == 0:
322
+ warnings.warn(f"Could not extract valid frequencies. Using window_size={lbound}.")
323
+ return lbound
324
+
325
+ return np.array(window_sizes)[np.argsort(magnitudes)[::-1]], np.sort(magnitudes)[::-1]
326
+
327
+
328
+ from difflib import SequenceMatcher
329
+ from collections import Counter
330
+
331
+
332
+ def calculate_similarity_score(strings_list):
333
+ """
334
+ Calculate an overall similarity score for a list of strings.
335
+ The score is based on multiple similarity metrics.
336
+
337
+ Args:
338
+ strings_list: List of strings to compare
339
+
340
+ Returns:
341
+ float: Overall similarity score between 0 and 1
342
+ """
343
+ if not strings_list or len(strings_list) < 2:
344
+ return 1.0 # A single string or empty list is perfectly similar to itself
345
+
346
+ n = len(strings_list)
347
+ total_comparisons = n * (n - 1) // 2
348
+
349
+ # Initialize scores for different metrics
350
+ sequence_scores = []
351
+ jaccard_scores = []
352
+ length_ratio_scores = []
353
+
354
+ # Compare each pair of strings
355
+ for i in range(n):
356
+ for j in range(i + 1, n):
357
+ str1 = strings_list[i]
358
+ str2 = strings_list[j]
359
+
360
+ # Sequence Matcher (difflib) score
361
+ sequence_score = SequenceMatcher(None, str1, str2).ratio()
362
+ sequence_scores.append(sequence_score)
363
+
364
+ # Jaccard similarity (character-based)
365
+ set1, set2 = set(str1), set(str2)
366
+ jaccard_score = len(set1.intersection(set2)) / len(set1.union(set2)) if set1 or set2 else 1.0
367
+ jaccard_scores.append(jaccard_score)
368
+
369
+ # Length ratio (shorter/longer)
370
+ length_ratio = min(len(str1), len(str2)) / max(len(str1), len(str2)) if max(len(str1), len(str2)) > 0 else 1.0
371
+ length_ratio_scores.append(length_ratio)
372
+
373
+ # Calculate average scores
374
+ avg_sequence = np.mean(sequence_scores)
375
+ avg_jaccard = np.mean(jaccard_scores)
376
+ avg_length_ratio = np.mean(length_ratio_scores)
377
+
378
+ # Calculate overall score (weighted average of the three metrics)
379
+ overall_score = 0.5 * avg_sequence + 0.3 * avg_jaccard + 0.2 * avg_length_ratio
380
+
381
+ return overall_score
382
+
383
+ def fuse_adjacent(s):
384
+ if not s:
385
+ return ''
386
+ result = s[0]
387
+ for c in s[1:]:
388
+ if c != result[-1]:
389
+ result += c
390
+ return result
391
+
392
+
393
+ def find_longest_identical_pair(s):
394
+ left = 0
395
+ right = len(s) - 1
396
+ id_pair = (None, -1, -1)
397
+ while left < right:
398
+ for i in range(left+1, right+1):
399
+ if s[left] == s[i]:
400
+ if id_pair[2] - id_pair[1] < i - left:
401
+ id_pair = (s[left], left, i)
402
+ left += 1
403
+ if id_pair[0] is None:
404
+ return None # If no identical pair is found
405
+ else:
406
+ return id_pair
407
+
408
+
409
+ '''def number_to_alpha(numbers):
410
+ # Create a mapping of numbers to alphabetic characters
411
+ alpha_map = {i: chr(97 + i) for i in range(26)} # a-z
412
+ alpha_map.update({i + 26: chr(65 + i) for i in range(26)}) # A-Z
413
+
414
+ # Convert numbers to characters
415
+ result = ''
416
+ for num in numbers:
417
+ if num in alpha_map:
418
+ result += alpha_map[num]
419
+ else:
420
+ result += '?' # For numbers outside the range 0-51
421
+ return result'''
422
+
423
+ def number_to_alpha(numbers):
424
+ alpha_map = {i: chr(65 + i) for i in range(26)} # A-Z
425
+
426
+ result = ''
427
+ for num in numbers:
428
+ if num in alpha_map:
429
+ result += alpha_map[num]
430
+ else:
431
+ result += '?' # For numbers outside the range 0-25
432
+ return result
433
+
434
+ def alpha_to_number(sequence):
435
+ return [ord(c.upper()) - ord('A') for c in sequence]
436
+
437
+
438
+
439
+
440
+
441
+ def score_match(chars):
442
+ """Score a column of aligned characters"""
443
+ if '-' in chars:
444
+ return -len([c for c in chars if c == '-']) # Gap penalty
445
+ return sum(1 for i, j in product(chars, chars) if i == j) - len(chars) # Sum of pairwise matches
446
+
447
+ def initialize_matrix(sequences):
448
+ """Initialize the N-dimensional DP matrix and pointers"""
449
+ # Get dimensions for each sequence
450
+ dims = [len(seq) + 1 for seq in sequences]
451
+
452
+ # Create score matrix F and pointer matrix P
453
+ F = np.zeros(dims)
454
+ # Initialize P with lists instead of zeros
455
+ P = np.empty(dims, dtype=object)
456
+ for idx in np.ndindex(*dims):
457
+ P[idx] = []
458
+
459
+ # Initialize edges with gap penalties
460
+ for idx, dim in enumerate(dims):
461
+ # Create slice objects for each dimension
462
+ slices = [slice(None) if i == idx else 0 for i in range(len(dims))]
463
+ indices = range(1, dim)
464
+ F[tuple(slices)] = np.linspace(0, -len(sequences) * dim, dim)
465
+
466
+ return F, P
467
+
468
+ def get_neighbors(current_pos, dims):
469
+ """Get all possible previous positions in the DP matrix"""
470
+ neighbors = []
471
+ for i in range(2 ** len(dims)):
472
+ neighbor = []
473
+ for j, pos in enumerate(current_pos):
474
+ if i & (1 << j):
475
+ if pos > 0: # Check boundary
476
+ neighbor.append(pos - 1)
477
+ else:
478
+ break
479
+ else:
480
+ neighbor.append(pos)
481
+ if len(neighbor) == len(dims):
482
+ neighbors.append(tuple(neighbor))
483
+ return neighbors[1:] # Exclude current position
484
+
485
+ def msa(sequences, gap_penalty=-1):
486
+ """Perform multiple sequence alignment using N-dimensional Needleman-Wunsch"""
487
+ # Initialize matrices
488
+ F, P = initialize_matrix(sequences)
489
+ dims = F.shape
490
+
491
+ # Fill the DP matrix
492
+ for pos in product(*[range(1, dim) for dim in dims]):
493
+
494
+ # Get characters at current position
495
+ chars = [sequences[i][pos[i]-1] for i in range(len(sequences))]
496
+
497
+ # Get all possible previous positions
498
+ neighbors = get_neighbors(pos, dims)
499
+
500
+ # Calculate scores for all possible alignments
501
+ max_score = float('-inf')
502
+ best_moves = []
503
+
504
+ for neighbor in neighbors:
505
+ # Calculate score based on which sequences are aligned
506
+ aligned_chars = []
507
+ for i, (curr, prev) in enumerate(zip(pos, neighbor)):
508
+ if curr != prev:
509
+ aligned_chars.append(sequences[i][curr-1])
510
+ else:
511
+ aligned_chars.append('-')
512
+
513
+ score = F[neighbor] + score_match(aligned_chars)
514
+
515
+ if score > max_score:
516
+ max_score = score
517
+ best_moves = [neighbor]
518
+ elif score == max_score:
519
+ best_moves.append(neighbor)
520
+
521
+ F[pos] = max_score
522
+ P[pos] = best_moves # Store list of best moves
523
+
524
+ # Traceback
525
+ aligned_sequences = [[] for _ in sequences]
526
+ current_pos = tuple(dim-1 for dim in dims)
527
+
528
+ while any(pos > 0 for pos in current_pos):
529
+ # Ensure P[current_pos] contains valid moves
530
+ if not P[current_pos]: # If no moves stored, break
531
+ break
532
+
533
+ prev_pos = P[current_pos][0] # Take first best move
534
+
535
+ # Add characters or gaps based on moves
536
+ for i, (curr, prev) in enumerate(zip(current_pos, prev_pos)):
537
+ if curr != prev:
538
+ aligned_sequences[i].append(sequences[i][curr-1])
539
+ else:
540
+ aligned_sequences[i].append('-')
541
+
542
+ current_pos = prev_pos
543
+
544
+ # Reverse and join sequences
545
+ return [(''.join(seq))[::-1] for seq in aligned_sequences]
546
+
547
+
548
+ # Example usage
549
+ #sequences = ['ACBAFDECBAECFACBA', 'CFACBAFDECBAECFA', 'ECFACBAFBECBAECFA']
550
+ #aligned_sequences = msa(sequences)
551
+ #print('\n'.join(aligned_sequences))
552
+
553
+
554
+ def align_strings(s1, s2):
555
+ len1 = len(s1)
556
+ len2 = len(s2)
557
+ max_matches = 0
558
+ best_offset = 0
559
+ for offset in range(-len2 + 1, len1):
560
+ match_count = 0
561
+ for i in range(len1):
562
+ j = i - offset
563
+ if 0 <= j < len2 and s1[i] == s2[j]:
564
+ match_count += 1
565
+ if match_count > max_matches:
566
+ max_matches = match_count
567
+ best_offset = offset
568
+ return best_offset, max_matches
569
+
570
+
571
+ def get_overlapping_substring(s1, s2, best_offset, max_matches):
572
+ len1 = len(s1)
573
+ len2 = len(s2)
574
+ start_index_s1 = -1
575
+ start_index_s2 = -1
576
+
577
+ for i in range(len1):
578
+ j = i - best_offset
579
+ if 0 <= j < len2 and s1[i] == s2[j]:
580
+ start_index_s1 = i
581
+ start_index_s2 = j
582
+ break # Find the first index of match
583
+
584
+ if start_index_s1 != -1:
585
+ return s1[start_index_s1 : start_index_s1 + max_matches]
586
+ else:
587
+ return ""
588
+
589
+
590
+ '''
591
+ def align_multiple_strings(strings):
592
+ if not strings:
593
+ return []
594
+ reference = strings[0]
595
+ for next_string in strings[1:]:
596
+ ref_align, _ = align_two_strings(reference, next_string)
597
+ # Collapse underscores to match specification
598
+ reference = ''.join(char for i, char in enumerate(ref_align) if char != '_' or (i > 0 and ref_align[i - 1] != '_'))
599
+ return reference
600
+
601
+ def align_two_strings(str1, str2):
602
+ # Alignment code (based on previously defined)
603
+ n, m = len(str1), len(str2)
604
+ dp = [[0] * (m + 1) for _ in range(n + 1)]
605
+ for i in range(n + 1): dp[i][0] = i
606
+ for j in range(m + 1): dp[0][j] = j
607
+ for i in range(1, n + 1):
608
+ for j in range(1, m + 1):
609
+ if str1[i - 1] == str2[j - 1]:
610
+ dp[i][j] = dp[i - 1][j - 1]
611
+ else:
612
+ dp[i][j] = min(dp[i - 1][j - 1], dp[i][j - 1], dp[i - 1][j]) + 1
613
+ aligned1, aligned2 = [], []
614
+ i, j = n, m
615
+ while i > 0 and j > 0:
616
+ if str1[i - 1] == str2[j - 1]:
617
+ aligned1.append(str1[i - 1])
618
+ aligned2.append(str2[j - 1])
619
+ i -= 1
620
+ j -= 1
621
+ elif dp[i][j] == dp[i - 1][j - 1] + 1:
622
+ aligned1.append('_')
623
+ aligned2.append('_')
624
+ i -= 1
625
+ j -= 1
626
+ elif dp[i][j] == dp[i - 1][j]+1:
627
+ aligned1.append('_')
628
+ i -= 1
629
+ else:
630
+ aligned2.append('_')
631
+ j -= 1
632
+ while i > 0:
633
+ aligned1.append('_')
634
+ i -= 1
635
+ while j > 0:
636
+ aligned2.append('_')
637
+ j -= 1
638
+ aligned1.reverse()
639
+ aligned2.reverse()
640
+ return ''.join(aligned1), ''.join(aligned2)
641
+ '''
utils/plot.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from colorsys import hsv_to_rgb
4
+ import string
5
+ import networkx as nx
6
+ import re
7
+
8
+ # import osmnx as ox
9
+ import matplotlib.animation as animation
10
+ from itertools import combinations
11
+ import random
12
+
13
+
14
+ #################plot for transcripts############################
15
+ def plot_string(text, figsize=(18, 4)):
16
+ """
17
+ Plot string with identical colors for identical letters and 0.2x letter width spacing.
18
+ Supports lowercase letters, uppercase letters, and underscores.
19
+ """
20
+ plt.rcParams['font.family'] = 'Times New Roman'
21
+ plt.figure(figsize=figsize)
22
+
23
+ # Include underscore in the character set
24
+ unique_chars = (sorted(set(string.ascii_lowercase))[:10])
25
+ hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
26
+ color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
27
+ for char, hue in zip(unique_chars, hues)}
28
+ unique_chars += ['-']
29
+ # Add special handling for underscore
30
+ color_map['-'] = (0.5, 0.5, 0.5) # Gray color for underscore
31
+
32
+ unique_chars += (sorted(set(string.ascii_lowercase))[10:])
33
+ color_map = {char: hsv_to_rgb(hue, 0.3, 0.8)
34
+ for char, hue in zip(unique_chars, hues)}
35
+
36
+ spacing = 0.1 # Space between letters relative to letter width
37
+ width = 1.0 # Width of each letter
38
+ total_width = len(text) * width * (1 + spacing) - spacing
39
+
40
+ for i, char in enumerate(text):
41
+ x_pos = i * width * (1 + spacing)
42
+
43
+ if char == '_':
44
+ # Draw underscore as a line slightly below the baseline
45
+ plt.plot([x_pos - width/3, x_pos + width/3],
46
+ [-0.2, -0.2],
47
+ color=color_map['_'],
48
+ linewidth=2)
49
+ else:
50
+ # Regular character plotting
51
+ color = color_map[char.lower()]
52
+ plt.text(x_pos, 0, char, fontsize=14, color=color,
53
+ ha='center', va='center')
54
+
55
+ plt.xlim(-width/2, total_width - width/2)
56
+ plt.ylim(-0.5, 0.5)
57
+ plt.axis('off')
58
+ plt.tight_layout()
59
+ plt.savefig('transcript.svg', format='svg')
60
+ #plt.show()
61
+
62
+ # Example usage
63
+ #text = "AACCCBBAAFFDDDEEEECCCCBBAAAEEECCCFFFFAAACCBBAAFFDDDDEEEEECCCCCBBBBAAAEECCCCCFFFAAAACCCBBBBAAAFFFBBBEECCBBBAAEECCCFFAA"
64
+ #plot_string(text)
65
+
66
+
67
+
68
+
69
+ def parse_sequence(sequence):
70
+ """Parse sequence into main path and branches"""
71
+ branches = []
72
+ main_path = []
73
+ current_branch = []
74
+ in_branch = False
75
+
76
+ for part in sequence.split(','):
77
+ part = part.strip()
78
+ if '(' in part:
79
+ in_branch = True
80
+ part = part.replace('(', '')
81
+ if ')' in part:
82
+ in_branch = False
83
+ part = part.replace(')', '')
84
+
85
+ nodes = re.findall(r'([A-Za-z_]\d+)', part)
86
+ if len(nodes) == 2:
87
+ if in_branch:
88
+ current_branch.extend(nodes)
89
+ else:
90
+ if current_branch:
91
+ branches.append(current_branch)
92
+ current_branch = []
93
+ main_path.extend(nodes)
94
+
95
+ if current_branch:
96
+ branches.append(current_branch)
97
+
98
+ return main_path, branches
99
+
100
+ def get_node_number(node):
101
+ """Extract number from node label"""
102
+ return int(re.findall(r'\d+', node)[0])
103
+
104
+ def plot_sequence(sequence, figsize=(8, 3)):
105
+ plt.figure(figsize=figsize)
106
+
107
+ # Assign colors
108
+ unique_chars = sorted(set(string.ascii_uppercase))[:10]
109
+ hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
110
+ color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
111
+ for char, hue in zip(unique_chars, hues)}
112
+
113
+ unique_chars += ['_']
114
+ # Add special handling for underscore
115
+ color_map['_'] = (0.5, 0.5, 0.5) # Gray color for underscore
116
+
117
+ color_map.update({char: hsv_to_rgb(hue, 0.3, 0.8)
118
+ for char, hue in zip(sorted(set(string.ascii_uppercase))[10:], hues)})
119
+
120
+ main_path, branches = parse_sequence(sequence)
121
+ G = nx.DiGraph()
122
+
123
+ # Calculate positions
124
+ pos = {}
125
+ x_spacing = 1
126
+ y_spacing = 0.5
127
+
128
+ # Group nodes by their number
129
+ nodes_by_number = {}
130
+ all_nodes = set(main_path)
131
+ for branch in branches:
132
+ all_nodes.update(branch)
133
+
134
+ for node in all_nodes:
135
+ num = get_node_number(node)
136
+ if num not in nodes_by_number:
137
+ nodes_by_number[num] = []
138
+ nodes_by_number[num].append(node)
139
+
140
+ # Position nodes
141
+ for num in sorted(nodes_by_number.keys()):
142
+ nodes = nodes_by_number[num]
143
+ x = (num - 1) * x_spacing
144
+
145
+ if len(nodes) == 1:
146
+ pos[nodes[0]] = (x, 0)
147
+ else:
148
+ # Center branching nodes vertically
149
+ total_height = (len(nodes) - 1) * y_spacing
150
+ start_y = -total_height / 2
151
+ for i, node in enumerate(sorted(nodes)):
152
+ pos[node] = (x, start_y + i * y_spacing)
153
+
154
+ # Add edges
155
+ for i in range(0, len(main_path)-1, 2):
156
+ G.add_edge(main_path[i], main_path[i+1])
157
+
158
+ for branch in branches:
159
+ for i in range(0, len(branch)-1, 2):
160
+ G.add_edge(branch[i], branch[i+1])
161
+
162
+
163
+ # Draw arrows
164
+ nx.draw_networkx_edges(G, pos, edge_color='gray',
165
+ arrowsize=10, width=1.5,
166
+ arrowstyle='->')
167
+
168
+ # Draw nodes
169
+ for node in G.nodes():
170
+ letter = node[0]
171
+ color = color_map[letter]
172
+
173
+ circle = plt.Circle(pos[node], 0.1,
174
+ color=color, alpha=0.3)
175
+ plt.gca().add_patch(circle)
176
+
177
+ plt.text(pos[node][0], pos[node][1], node[0],
178
+ color=color, fontsize=8,
179
+ ha='center', va='center',
180
+ fontweight='bold')
181
+
182
+ plt.axis('equal')
183
+ plt.axis('off')
184
+ plt.tight_layout()
185
+ plt.savefig('transcript.svg', format='svg')
186
+
187
+ # Example usage
188
+ #sequence = "A1->C2, C2->B3, B3->A4, A4->F5, (F5->B6, F5->D6), (B6->E7, D6->E7), E7->C8, C8->B9, B9->A10, A10->E11, E11->C12, C12->F13"
189
+ #plot_sequence(sequence, (5.4,2))
190
+
191
+ #################plot for transcripts############################
192
+
193
+
194
+
195
+ #################plot for 2D trajectories############################
196
+
197
+
198
+
199
+
200
+
201
+ def plot_routes_animation(G, routes, colors, output_file, fps=20, duration_sec=10):
202
+ """
203
+ Create an animation showing routes appearing dynamically.
204
+
205
+ Args:
206
+ G (networkx.MultiDiGraph): Street network graph
207
+ routes (list): List of routes (each route is a list of nodes)
208
+ colors (list): List of colors for each route
209
+ output_file (str): Output filename (should end with .gif or .mp4)
210
+ fps (int): Frames per second
211
+ duration_sec (int): Total animation duration in seconds
212
+ """
213
+
214
+ # Create figure and axis
215
+ fig, ax = plt.subplots(figsize=(10, 8))
216
+ plt.rcParams['font.family'] = 'Times New Roman'
217
+
218
+ # Plot the base map
219
+ # ox.plot_graph(G, ax=ax, show=False, close=False,
220
+ # edge_color='gray', edge_alpha=0.2, node_size=0)
221
+
222
+ # Create empty route lines
223
+ route_lines = []
224
+ route_points = []
225
+
226
+ # Initialize all routes as empty
227
+ for color in colors:
228
+ line, = ax.plot([], [], marker='D', color=color, linewidth=2, alpha=0.8, zorder=2)
229
+ route_lines.append(line)
230
+ route_points.append([])
231
+
232
+ # Extract coordinates for all routes
233
+ all_route_coords = []
234
+ for route in routes:
235
+ coords = []
236
+ for node in route:
237
+ x = G.nodes[node]['x']
238
+ y = G.nodes[node]['y']
239
+ coords.append((x, y))
240
+ all_route_coords.append(coords)
241
+
242
+ # Calculate total number of frames
243
+ total_frames = fps * duration_sec
244
+
245
+ # Animation update function
246
+ def update(frame):
247
+ # Calculate progress (0 to 1)
248
+ progress = frame / total_frames
249
+
250
+ # Update each route
251
+ for i, coords in enumerate(all_route_coords):
252
+ # Determine how many points to show for this route
253
+ route_progress = min(1.0, progress * len(routes) - i)
254
+
255
+ if route_progress <= 0:
256
+ # Route hasn't started yet
257
+ route_lines[i].set_data([], [])
258
+ continue
259
+
260
+ # Calculate number of points to show
261
+ num_points = max(2, int(route_progress * len(coords)))
262
+
263
+ # Get coordinates to display
264
+ visible_coords = coords[:num_points]
265
+ xs, ys = zip(*visible_coords) if visible_coords else ([], [])
266
+
267
+ # Update line data
268
+ route_lines[i].set_data(xs, ys)
269
+
270
+ # Update legend based on which routes are visible
271
+ visible_routes = [i for i, line in enumerate(route_lines)
272
+ if len(line.get_xdata()) > 0]
273
+
274
+ if visible_routes:
275
+ # Update legend with only visible routes
276
+ ax.legend([route_lines[i] for i in visible_routes],
277
+ [f'Period {i+1}' for i in visible_routes],
278
+ loc='upper right', prop={'size': 14},
279
+ bbox_to_anchor=(1, 1))
280
+
281
+ return route_lines
282
+
283
+ # Create animation
284
+ ani = animation.FuncAnimation(
285
+ fig, update, frames=total_frames,
286
+ interval=1000/fps, blit=True
287
+ )
288
+
289
+ # Tight layout
290
+ plt.tight_layout()
291
+
292
+ # Save animation
293
+ if output_file.endswith('.gif'):
294
+ ani.save(output_file, writer='pillow', fps=fps, dpi=150)
295
+ else:
296
+ # For MP4, use ffmpeg
297
+ writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
298
+ ani.save(output_file, writer=writer, dpi=150)
299
+
300
+ plt.close()
301
+
302
+ print(f"Animation saved to {output_file}")
303
+
304
+
305
+ # Example usage:
306
+ # plot_routes_animation(G, routes, colors, "route_animation.gif")
307
+ # For MP4: plot_routes_animation(G, routes, colors, "route_animation.mp4")
308
+
309
+
310
+ def plot_routes(G, routes, colors, output_file):
311
+ """
312
+ Plot multiple routes on the same map.
313
+
314
+ Args:
315
+ G (networkx.MultiDiGraph): Street network graph
316
+ routes (list): List of routes (each route is a list of nodes)
317
+ colors (list): List of colors for each route
318
+ output_file (str): Output filename
319
+ """
320
+ # Create figure and axis
321
+ fig, ax = plt.subplots(figsize=(8, 6))
322
+ plt.rcParams['font.family'] = 'Times New Roman'
323
+
324
+ # Plot the base map
325
+ # ox.plot_graph(G, ax=ax, show=False, close=False,
326
+ # edge_color='gray', edge_alpha=0.2, node_size=0)
327
+
328
+ # Create empty list to store route lines for legend
329
+ route_lines = []
330
+
331
+ # Plot each route
332
+ for route, color in zip(routes, colors):
333
+ # Extract the coordinates for each node in the route
334
+ xs = []
335
+ ys = []
336
+ for node in route:
337
+ # Get node coordinates
338
+ x = G.nodes[node]['x']
339
+ y = G.nodes[node]['y']
340
+ xs.append(x)
341
+ ys.append(y)
342
+
343
+ # Plot the route
344
+ line = ax.plot(xs, ys, marker='D', color=color, linewidth=2, alpha=0.2, zorder=2)[0]
345
+ route_lines.append(line)
346
+
347
+ # Add legend
348
+ ax.legend(route_lines,
349
+ [f'Period {i+1}' for i in range(len(routes))],
350
+ loc='upper right', prop={'size': 14},
351
+ bbox_to_anchor=(1.25, 0.85))
352
+
353
+ # Adjust layout and save
354
+ plt.tight_layout()
355
+
356
+
357
+ plt.savefig(output_file, dpi=100, bbox_inches='tight')
358
+ plt.show()
359
+ plt.close()
360
+
361
+
362
+
363
+
364
+ #################plot for 2D trajectories############################
365
+
366
+
367
+ def plot_task_2(obs_len, gt_seq_len, pred_seq_len, figsize_w=10, title=None):
368
+ """
369
+ Plot both GT and Pred timelines in the same figure with aligned scales.
370
+
371
+ Args:
372
+ obs_len: Length of observation period
373
+ gt_seq_len: Total sequence length for ground truth
374
+ pred_seq_len: Total sequence length for prediction
375
+ figsize_w: Width of the figure
376
+ title: Optional title for the figure
377
+ """
378
+ # Use the maximum sequence length to determine the x-axis limits
379
+ max_seq_len = max(gt_seq_len, pred_seq_len)
380
+
381
+ # Create figure with two subplots, one for GT and one for Pred
382
+ fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})
383
+ plt.rcParams['font.family'] = 'Times New Roman'
384
+
385
+ # Add title if provided
386
+ if title:
387
+ fig.suptitle(title, fontsize=14, fontweight='bold')
388
+
389
+ # Create consistent bar heights and label offset
390
+ bar_height = 0.5
391
+ label_offset = max_seq_len * 0.15 # Proportional offset based on sequence length
392
+
393
+ # GT plot (top)
394
+ y_position = 0
395
+ axes[0].barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
396
+ axes[0].barh(y_position, gt_seq_len+1-obs_len, height=bar_height, left=obs_len, color='lightgreen')
397
+ axes[0].text(-label_offset, y_position, "GT:", fontsize=12, fontweight='bold', verticalalignment='center')
398
+
399
+ # Pred plot (bottom)
400
+ axes[1].barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
401
+ axes[1].barh(y_position, pred_seq_len+1-obs_len, height=bar_height, left=obs_len, color='lightblue')
402
+ axes[1].text(-label_offset, y_position, "Pred:", fontsize=12, fontweight='bold', verticalalignment='center')
403
+
404
+ # Configure both axes consistently
405
+ for ax in axes:
406
+ # Set consistent x-limits for alignment
407
+ ax.set_xlim(-label_offset, max_seq_len+1)
408
+ ax.set_ylim(-0.5, 0.5)
409
+ ax.set_yticks([])
410
+
411
+ # Remove the box/frame
412
+ for spine in ax.spines.values():
413
+ spine.set_visible(False)
414
+
415
+ # Add a thin line below the bar for better visibility
416
+ ax.axhline(y_position - bar_height/2, color='black', linewidth=0.5)
417
+
418
+ # Set tick marks for each plot
419
+ axes[0].set_xticks([0, obs_len, gt_seq_len])
420
+ axes[1].set_xticks([0, obs_len, pred_seq_len])
421
+
422
+ plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])
423
+ return fig
424
+
425
+
426
+ def plot_task_3(gt_seq_len, GT_start, GT_end, pred_start, pred_end, figsize_w=10, title=None):
427
+ """
428
+ Plot both GT and Pred timelines in the same figure with aligned scales.
429
+
430
+ Args:
431
+ gt_seq_len: Total sequence length for ground truth
432
+ GT_start: Start position of GT highlight bar
433
+ GT_end: End position of GT highlight bar
434
+ pred_start: Start position of prediction highlight bar
435
+ pred_end: End position of prediction highlight bar
436
+ figsize_w: Width of the figure
437
+ title: Optional title for the figure
438
+ """
439
+ # Use the maximum sequence length to determine the x-axis limits
440
+ max_seq_len = gt_seq_len
441
+
442
+ # Create figure with two subplots, one for GT and one for Pred
443
+ fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})
444
+ plt.rcParams['font.family'] = 'Times New Roman'
445
+
446
+ # Add title if provided
447
+ if title:
448
+ fig.suptitle(title, fontsize=14, fontweight='bold')
449
+
450
+ # Create consistent bar heights and label offset
451
+ bar_height = 0.5
452
+ label_offset = max_seq_len * 0.15 # Proportional offset based on sequence length
453
+
454
+ # GT plot (top)
455
+ y_position = 0
456
+ # Plot full lightgray bar for GT
457
+ axes[0].barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
458
+ # Plot lightgreen bar within GT from GT_start to GT_end
459
+ axes[0].barh(y_position, GT_end - GT_start, height=bar_height, left=GT_start, color='lightgreen')
460
+ axes[0].text(-label_offset, y_position, "GT:", fontsize=12, fontweight='bold', verticalalignment='center')
461
+
462
+ # Pred plot (bottom)
463
+ # Plot full lightgray bar for Pred
464
+ axes[1].barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
465
+ # Plot lightblue bar within Pred from pred_start to pred_end
466
+ axes[1].barh(y_position, pred_end - pred_start, height=bar_height, left=pred_start, color='lightblue')
467
+ axes[1].text(-label_offset, y_position, "Pred:", fontsize=12, fontweight='bold', verticalalignment='center')
468
+
469
+ # Configure both axes consistently
470
+ for ax in axes:
471
+ # Set consistent x-limits for alignment
472
+ ax.set_xlim(-label_offset, max_seq_len+1)
473
+ ax.set_ylim(-0.5, 0.5)
474
+ ax.set_yticks([])
475
+
476
+ # Remove the box/frame
477
+ for spine in ax.spines.values():
478
+ spine.set_visible(False)
479
+
480
+ # Add a thin line below the bar for better visibility
481
+ ax.axhline(y_position - bar_height/2, color='black', linewidth=0.5)
482
+
483
+ # Set tick marks for each plot
484
+ axes[0].set_xticks([0, GT_start, GT_end, gt_seq_len])
485
+ axes[1].set_xticks([0, pred_start, pred_end, gt_seq_len])
486
+
487
+ plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])
488
+ return fig
489
+
490
+
491
+ import string
492
+ def plot_images_with_token(images, tokens, n_rows = 2):
493
+ assert len(images) == len(tokens), "Each image must have a corresponding token"
494
+
495
+ n_images = len(images)
496
+ # Calculate rows and columns for grid layout
497
+
498
+ n_cols = (n_images + 1) // n_rows # Ceiling division to handle odd number of images
499
+
500
+ # Create a figure to display the images
501
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
502
+ plt.rcParams['font.family'] = 'Times New Roman'
503
+ unique_chars = (sorted(set(string.ascii_lowercase))[:10])
504
+ hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
505
+ color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
506
+ for char, hue in zip(unique_chars, hues)}
507
+
508
+ # Make axes a 2D array even if there's just one column
509
+ if n_cols == 1:
510
+ axes = axes.reshape(-1, 1)
511
+
512
+ # Flatten axes for easy iteration if there are multiple columns
513
+ axes_flat = axes.flatten()
514
+
515
+ for i, (image, token) in enumerate(zip(images, tokens)):
516
+ if i < len(axes_flat):
517
+ color = color_map[token.lower()]
518
+ axes_flat[i].imshow(image)
519
+ axes_flat[i].set_title(token, color=color, size=50)
520
+ axes_flat[i].axis('off') # Hide axes
521
+
522
+ # Hide any unused subplots
523
+ for j in range(i + 1, n_rows * n_cols):
524
+ if j < len(axes_flat):
525
+ axes_flat[j].axis('off')
526
+ fig.delaxes(axes_flat[j])
527
+
528
+ plt.tight_layout()
529
+ plt.savefig('anchors.jpg', bbox_inches='tight', pad_inches=0)
530
+ #plt.show()
531
+
532
+
utils/render.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import matplotlib
3
+ matplotlib.use('Agg') # Non-interactive backend
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.animation as animation
6
+ from mpl_toolkits.mplot3d import Axes3D
7
+ import numpy as np
8
+ from sklearn.decomposition import PCA
9
+ from scipy.spatial.transform import Rotation as R
10
+
11
+ def render_smpl(pose_data, output_path, fps=30):
12
+ """
13
+ Render SMPL 3D pose data to a video file.
14
+
15
+ Args:
16
+ pose_data (np.ndarray): Shape (Frames, 24, 3)
17
+ output_path (str): Path to save the MP4 video.
18
+ fps (int): Frames per second.
19
+ """
20
+
21
+ # SMPL kinematic tree (approximate for visualization)
22
+ # 0: Pelvis
23
+ # 1: L_Hip, 2: R_Hip, 3: Spine1
24
+ # 4: L_Knee, 5: R_Knee, 6: Spine2
25
+ # 7: L_Ankle, 8: R_Ankle, 9: Spine3
26
+ # 10: L_Foot, 11: R_Foot, 12: Neck
27
+ # 13: L_Collar, 14: R_Collar, 15: Head
28
+ # 16: L_Shoulder, 17: R_Shoulder
29
+ # 18: L_Elbow, 19: R_Elbow
30
+ # 20: L_Wrist, 21: R_Wrist
31
+ # 22: L_Hand, 23: R_Hand
32
+
33
+ # Connectivity for drawing bones
34
+ connections = [
35
+ (0, 1), (0, 2), (0, 3),
36
+ (1, 4), (2, 5), (3, 6),
37
+ (4, 7), (5, 8), (6, 9),
38
+ (7, 10), (8, 11), (9, 12),
39
+ (9, 13), (9, 14), (12, 15),
40
+ (13, 16), (14, 17),
41
+ (16, 18), (17, 19),
42
+ (18, 20), (19, 21),
43
+ (20, 22), (21, 23)
44
+ ]
45
+
46
+ fig = plt.figure(figsize=(10, 10))
47
+ ax = fig.add_subplot(111, projection='3d')
48
+
49
+ # --- Alignment & Centering ---
50
+ # 1. Fit plane to feet to find ground orientation
51
+ feet_indices = [10, 11] # L_Foot, R_Foot
52
+ feet_points = pose_data[:, feet_indices, :].reshape(-1, 3)
53
+
54
+ pca = PCA(n_components=3)
55
+ pca.fit(feet_points)
56
+ normal = pca.components_[2] # Component with least variance is the normal
57
+
58
+ # Calculate Body Up vector (Pelvis to Head) to determine correct up direction
59
+ # Pelvis is 0, Head is 15
60
+ pelvis_head_vector = pose_data[:, 15, :] - pose_data[:, 0, :]
61
+ avg_body_up = np.mean(pelvis_head_vector, axis=0)
62
+
63
+ # Ensure normal points in same direction as body up
64
+ if np.dot(normal, avg_body_up) < 0:
65
+ normal = -normal
66
+
67
+ # 2. Compute rotation to align normal to Z-axis [0, 0, 1]
68
+ target_normal = np.array([0, 0, 1])
69
+
70
+ # Use scipy to find rotation
71
+ # We want R such that R * normal = target_normal
72
+ # align_vectors finds rotation that maps vectors_b to vectors_a.
73
+ # So we map normal (b) to target (a).
74
+ rot, rssd = R.align_vectors([target_normal], [normal])
75
+ rot_matrix = rot.as_matrix()
76
+
77
+ # Apply rotation to all points
78
+ # Points are (Frames, Joints, 3). Flatten for transform
79
+ original_shape = pose_data.shape
80
+ flat_data = pose_data.reshape(-1, 3)
81
+ # Apply rotation: (R @ v.T).T = v @ R.T
82
+ # Scipy apply: rot.apply(vectors) handles the broadcasting
83
+ pose_data_rotated = rot.apply(flat_data)
84
+ pose_data = pose_data_rotated.reshape(original_shape)
85
+
86
+ # 3. Center trajectory
87
+ # Center X/Y at 0
88
+ all_x = pose_data[:, :, 0]
89
+ all_y = pose_data[:, :, 1]
90
+ all_z = pose_data[:, :, 2]
91
+
92
+ # Mean of all points as center (or could use root joint mean)
93
+ center_x = np.mean(all_x)
94
+ center_y = np.mean(all_y)
95
+
96
+ pose_data[:, :, 0] -= center_x
97
+ pose_data[:, :, 1] -= center_y
98
+
99
+ # Shift Z so min is 0 (Ground level)
100
+ min_z = np.min(all_z)
101
+ pose_data[:, :, 2] -= min_z
102
+
103
+ # Update bounds variables for plotting
104
+ all_x = pose_data[:, :, 0]
105
+ all_y = pose_data[:, :, 1]
106
+ all_z = pose_data[:, :, 2]
107
+
108
+ mid_x = (np.min(all_x) + np.max(all_x)) / 2
109
+ mid_y = (np.min(all_y) + np.max(all_y)) / 2
110
+ mid_z = (np.min(all_z) + np.max(all_z)) / 2
111
+
112
+ max_range = np.array([np.ptp(all_x), np.ptp(all_y), np.ptp(all_z)]).max() / 2.0
113
+
114
+ # Recalculate bounds after shift
115
+ all_x = pose_data[:, :, 0]
116
+ all_y = pose_data[:, :, 1]
117
+ all_z = pose_data[:, :, 2]
118
+
119
+ # Use (min+max)/2 for center to ensure bounding box is centered
120
+ mid_x = (np.min(all_x) + np.max(all_x)) / 2
121
+ mid_y = (np.min(all_y) + np.max(all_y)) / 2
122
+ mid_z = (np.min(all_z) + np.max(all_z)) / 2
123
+
124
+ # Dynamic ground plane bounds covering all trajectory
125
+ padding = 1.0 # Increase padding
126
+ gp_min_x = np.min(all_x) - padding
127
+ gp_max_x = np.max(all_x) + padding
128
+ gp_min_y = np.min(all_y) - padding
129
+ gp_max_y = np.max(all_y) + padding
130
+
131
+ def update(frame):
132
+ ax.clear()
133
+ ax.set_axis_off()
134
+
135
+ # Transparent gray ground plane at z=0
136
+ x = np.linspace(gp_min_x, gp_max_x, 2)
137
+ y = np.linspace(gp_min_y, gp_max_y, 2)
138
+ X, Y = np.meshgrid(x, y)
139
+ Z = np.zeros_like(X) # Ground at z=0
140
+
141
+ ax.plot_surface(X, Y, Z, color='gray', alpha=0.2, shade=False)
142
+
143
+
144
+
145
+ current_pose = pose_data[frame]
146
+
147
+ # Scatter points for joints
148
+ ax.scatter(current_pose[:, 0], current_pose[:, 1], current_pose[:, 2], c='blue', s=20)
149
+
150
+ # Draw bones
151
+ for start, end in connections:
152
+ xs = [current_pose[start, 0], current_pose[end, 0]]
153
+ ys = [current_pose[start, 1], current_pose[end, 1]]
154
+ zs = [current_pose[start, 2], current_pose[end, 2]]
155
+ ax.plot(xs, ys, zs, c='red')
156
+
157
+ # Set limits
158
+ ax.set_xlim(mid_x - max_range, mid_x + max_range)
159
+ ax.set_ylim(mid_y - max_range, mid_y + max_range)
160
+ ax.set_zlim(mid_z - max_range, mid_z + max_range)
161
+
162
+ # ax.set_xlabel('X')
163
+ # ax.set_ylabel('Y')
164
+ # ax.set_zlabel('Z')
165
+ ax.set_title(f"Frame {frame}")
166
+
167
+ ani = animation.FuncAnimation(fig, update, frames=len(pose_data), interval=1000/fps)
168
+
169
+ # Save using ffmpeg writer
170
+ print(f"Saving video to {output_path}...")
171
+ try:
172
+ if animation.writers.is_available('ffmpeg'):
173
+ writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
174
+ ani.save(output_path, writer=writer)
175
+ else:
176
+ raise RuntimeError("ffmpeg not available")
177
+ except Exception as e:
178
+ print(f"ffmpeg failed or not found ({e}). Using OpenCV fallback...")
179
+ try:
180
+ import cv2
181
+ plt.close(fig) # Close the animation fig
182
+
183
+ # Re-setup figure for opencv loop
184
+ fig = plt.figure(figsize=(10, 10))
185
+ ax = fig.add_subplot(111, projection='3d')
186
+
187
+ # Figure size in pixels approx (10*100 = 1000x1000 usually dpi=100)
188
+ fig.canvas.draw()
189
+ width, height = fig.canvas.get_width_height()
190
+
191
+ # Setup video writer - Try H.264 (avc1) first
192
+ fourcc = cv2.VideoWriter_fourcc(*'avc1')
193
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
194
+
195
+ if not out.isOpened():
196
+ print("avc1 failed. Trying h264...")
197
+ fourcc = cv2.VideoWriter_fourcc(*'h264')
198
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
199
+
200
+ if not out.isOpened():
201
+ print("h264 failed. Trying vp80...")
202
+ fourcc = cv2.VideoWriter_fourcc(*'vp80')
203
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
204
+
205
+ if not out.isOpened():
206
+ print("vp80 failed. Trying mp4v (less compatible)...")
207
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
208
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
209
+
210
+ if not out.isOpened():
211
+ raise RuntimeError("Failed to open VideoWriter with any compatible codec.")
212
+
213
+ print("Rendering frames directly to OpenCV VideoWriter...")
214
+ for frame in range(len(pose_data)):
215
+ update(frame)
216
+ fig.canvas.draw()
217
+
218
+ # Convert canvas to image
219
+ # Check for buffer_rgba support (matplotlib 3.x)
220
+ try:
221
+ img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
222
+ img = img.reshape(height, width, 4)[:, :, :3] # RGBA -> RGB
223
+ except AttributeError:
224
+ # Fallback for older matplotlib or different backend
225
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
226
+ img = img.reshape(height, width, 3)
227
+
228
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
229
+
230
+ out.write(img)
231
+
232
+ out.release()
233
+ plt.close(fig)
234
+ print("OpenCV fallback rendering complete.")
235
+
236
+ except Exception as cv_e:
237
+ print(f"OpenCV fallback also failed: {cv_e}")
238
+ raise cv_e
239
+
240
+ return output_path
241
+
242
+