Junyi42 commited on
Commit
7ab7557
·
1 Parent(s): 2ddc0fc

refactor the code

Browse files
Files changed (2) hide show
  1. app.py +5 -132
  2. vis_st4rtrack.py +153 -102
app.py CHANGED
@@ -1,132 +1,15 @@
1
  import random
2
  import threading
3
  import psutil
4
- import numpy as onp
5
  import fastapi
6
  import gradio as gr
7
  import uvicorn
8
- import os
9
- from pathlib import Path
10
- from glob import glob
11
- import cv2
12
- import numpy as np
13
- import imageio.v3 as iio
14
 
15
  from viser_proxy_manager import ViserProxyManager
16
- from vis_st4rtrack import visualize_st4rtrack, log_memory_usage
17
 
18
  # Global cache for loaded data
19
- global_data_cache = {
20
- 'traj_3d_head1': None,
21
- 'traj_3d_head2': None,
22
- 'conf_mask_head1': None,
23
- 'conf_mask_head2': None,
24
- 'masks': None,
25
- 'loaded': False
26
- }
27
-
28
- def load_data_once(traj_path="results", use_float16=True):
29
- """Load data once and store in global cache."""
30
- if global_data_cache['loaded']:
31
- return
32
-
33
- log_memory_usage("before loading data")
34
-
35
- # Load masks
36
- mask_folder = './train'
37
- masks_paths = sorted(glob(mask_folder + '/*.jpg'))
38
- masks = None
39
-
40
- if masks_paths:
41
-
42
-
43
- masks = [iio.imread(p) for p in masks_paths]
44
- masks = np.stack(masks, axis=0)
45
- # Convert masks to binary (0 or 1)
46
- masks = (masks < 1).astype(np.float32)
47
- masks = masks.sum(axis=-1) > 2 # Combine all channels, True where any channel was 1
48
- print(f"Original masks shape: {masks.shape}")
49
- else:
50
- print("No masks found. Will create default masks when needed.")
51
-
52
- global_data_cache['masks'] = masks
53
-
54
- if Path(traj_path).is_dir():
55
- # Load head1 data
56
- traj_3d_paths_head1 = sorted(glob(traj_path + '/pts3d1_p*.npy'),
57
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
58
- conf_paths_head1 = sorted(glob(traj_path + '/conf1_p*.npy'),
59
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
60
-
61
- # Load head2 data
62
- traj_3d_paths_head2 = sorted(glob(traj_path + '/pts3d2_p*.npy'),
63
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
64
- conf_paths_head2 = sorted(glob(traj_path + '/conf2_p*.npy'),
65
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
66
-
67
- # Process head1
68
- if traj_3d_paths_head1:
69
- if use_float16:
70
- traj_3d_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head1], axis=0)
71
- else:
72
- traj_3d_head1 = onp.stack([onp.load(p) for p in traj_3d_paths_head1], axis=0)
73
-
74
- log_memory_usage("after loading head1 data")
75
-
76
- h, w, _ = traj_3d_head1.shape[1:]
77
-
78
- # If masks is None, create default masks (all ones)
79
- if masks is None:
80
- num_frames = traj_3d_head1.shape[0]
81
- masks = np.ones((num_frames, h, w), dtype=bool)
82
- print(f"Created default masks with shape: {masks.shape}")
83
- global_data_cache['masks'] = masks
84
- else:
85
- # Resize masks to match trajectory dimensions using nearest neighbor interpolation
86
- masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
87
- for i in range(masks.shape[0]):
88
- masks_resized[i] = cv2.resize(
89
- masks[i].astype(np.uint8),
90
- (w, h),
91
- interpolation=cv2.INTER_NEAREST
92
- ).astype(bool)
93
-
94
- print(f"Resized masks shape: {masks_resized.shape}")
95
- global_data_cache['masks'] = masks_resized
96
-
97
- # Reshape trajectory data
98
- traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
99
- global_data_cache['traj_3d_head1'] = traj_3d_head1
100
-
101
- if conf_paths_head1:
102
- conf_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head1], axis=0)
103
- conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
104
- conf_mask_head1 = conf_head1 > 1.0 # Default threshold
105
- global_data_cache['conf_mask_head1'] = conf_mask_head1
106
-
107
- # Process head2
108
- if traj_3d_paths_head2:
109
- if use_float16:
110
- traj_3d_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head2], axis=0)
111
- else:
112
- traj_3d_head2 = onp.stack([onp.load(p) for p in traj_3d_paths_head2], axis=0)
113
-
114
- log_memory_usage("after loading head2 data")
115
- raw_video = traj_3d_head2[:, :, :, 3:6] # [num_frames, h, w, 3]
116
-
117
- traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
118
- global_data_cache['traj_3d_head2'] = traj_3d_head2
119
-
120
- if conf_paths_head2:
121
- conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
122
- conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
123
- conf_mask_head2 = conf_head2 > 1.0 # Default threshold
124
- global_data_cache['conf_mask_head2'] = conf_mask_head2
125
-
126
- global_data_cache['loaded'] = True
127
- global_data_cache['raw_video'] = raw_video
128
- log_memory_usage("after loading all data")
129
-
130
 
131
  def check_ram_usage(threshold_percent=90):
132
  """Check if RAM usage is above the threshold.
@@ -143,20 +26,16 @@ def check_ram_usage(threshold_percent=90):
143
 
144
 
145
  def main() -> None:
146
- # Load data once at startup
147
- load_data_once(use_float16=True)
 
148
 
149
  app = fastapi.FastAPI()
150
  viser_manager = ViserProxyManager(app)
151
 
152
-
153
  # Create a Gradio interface with title, iframe, and buttons
154
  with gr.Blocks(title="Viser Viewer") as demo:
155
- # Add a title and description
156
- # gr.Markdown("# 🌐 Viser Interactive Viewer Test")
157
-
158
  # Add the iframe with a border
159
- # add_sphere_btn = gr.Button("Add Random Sphere")
160
  iframe_html = gr.HTML("")
161
  status_text = gr.Markdown("") # Add status text component
162
 
@@ -194,13 +73,8 @@ def main() -> None:
194
  "use_float16": True,
195
  "preloaded_data": global_data_cache, # Pass the preloaded data
196
  "color_code": "jet",
197
- # "blue_rgb": (0.22, 0.82, 1.0), # #37D2FF
198
- # "red_rgb": (1.0, 0.39, 0.22), # #FF6337
199
  "blue_rgb": (0.0, 0.149, 0.463), # #002676
200
  "red_rgb": (0.769, 0.510, 0.055), # #FDB515
201
- # "color_code": "rainbow",
202
- # "blue_rgb": (0., 0., 1.),
203
- # "red_rgb": (1., 0., 0.),
204
  "blend_ratio": 0.7
205
  },
206
  daemon=True
@@ -217,7 +91,6 @@ def main() -> None:
217
  loading="lazy"
218
  ></iframe>
219
  """, "**System Status:** Visualization loaded successfully."
220
-
221
 
222
  @demo.unload
223
  def stop(request: gr.Request):
 
1
  import random
2
  import threading
3
  import psutil
 
4
  import fastapi
5
  import gradio as gr
6
  import uvicorn
 
 
 
 
 
 
7
 
8
  from viser_proxy_manager import ViserProxyManager
9
+ from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage
10
 
11
  # Global cache for loaded data
12
+ global_data_cache = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def check_ram_usage(threshold_percent=90):
15
  """Check if RAM usage is above the threshold.
 
26
 
27
 
28
  def main() -> None:
29
+ # Load data once at startup using the function from vis_st4rtrack.py
30
+ global global_data_cache
31
+ global_data_cache = load_trajectory_data(use_float16=True, max_frames=32)
32
 
33
  app = fastapi.FastAPI()
34
  viser_manager = ViserProxyManager(app)
35
 
 
36
  # Create a Gradio interface with title, iframe, and buttons
37
  with gr.Blocks(title="Viser Viewer") as demo:
 
 
 
38
  # Add the iframe with a border
 
39
  iframe_html = gr.HTML("")
40
  status_text = gr.Markdown("") # Add status text component
41
 
 
73
  "use_float16": True,
74
  "preloaded_data": global_data_cache, # Pass the preloaded data
75
  "color_code": "jet",
 
 
76
  "blue_rgb": (0.0, 0.149, 0.463), # #002676
77
  "red_rgb": (0.769, 0.510, 0.055), # #FDB515
 
 
 
78
  "blend_ratio": 0.7
79
  },
80
  daemon=True
 
91
  loading="lazy"
92
  ></iframe>
93
  """, "**System Status:** Visualization loaded successfully."
 
94
 
95
  @demo.unload
96
  def stop(request: gr.Request):
vis_st4rtrack.py CHANGED
@@ -28,6 +28,138 @@ def log_memory_usage(message=""):
28
  memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB
29
  print(f"Memory usage {message}: {memory_mb:.2f} MB")
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def visualize_st4rtrack(
32
  traj_path: str = "results",
33
  up_dir: str = "-z", # should be +z or -z
@@ -81,7 +213,26 @@ def visualize_st4rtrack(
81
  format="jpeg"
82
  )
83
 
84
- # Create a function to process video frames and resize them
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def process_video_frame(frame_idx):
86
  if raw_video is None:
87
  return np.zeros((video_height, video_width, 3), dtype=np.uint8)
@@ -121,106 +272,6 @@ def visualize_st4rtrack(
121
  server.scene.set_up_direction(up_dir)
122
  print("Setting up visualization!")
123
 
124
- # Use preloaded data if available
125
- if preloaded_data and preloaded_data['loaded']:
126
- traj_3d_head1 = preloaded_data['traj_3d_head1']
127
- traj_3d_head2 = preloaded_data['traj_3d_head2']
128
- conf_mask_head1 = preloaded_data['conf_mask_head1']
129
- conf_mask_head2 = preloaded_data['conf_mask_head2']
130
- masks = preloaded_data['masks']
131
- raw_video = preloaded_data['raw_video']
132
- print("Using preloaded data!")
133
- else:
134
- # Original data loading code (as a fallback)
135
- print("No preloaded data available, loading from files...")
136
- # Load both head1 and head2 data
137
- traj_3d_head1 = None
138
- traj_3d_head2 = None
139
- conf_mask_head1 = None
140
- conf_mask_head2 = None
141
- masks = None
142
- if mask_folder is not None:
143
- masks_paths = sorted(glob(mask_folder + '/*.jpg'))
144
- masks = [iio.imread(p) for p in masks_paths]
145
- masks = np.stack(masks, axis=0)
146
- # Convert masks to binary (0 or 1)
147
- masks = (masks < 1).astype(np.float32)
148
- masks = masks.sum(axis=-1) > 2 # Combine all channels, True where any channel was 1
149
- print(f"Original masks shape: {masks.shape}")
150
-
151
- if Path(traj_path).is_dir():
152
- # Load head1 data
153
- traj_3d_paths_head1 = sorted(glob(traj_path + '/pts3d1_p*.npy'),
154
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
155
- conf_paths_head1 = sorted(glob(traj_path + '/conf1_p*.npy'),
156
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
157
-
158
- # Load head2 data
159
- traj_3d_paths_head2 = sorted(glob(traj_path + '/pts3d2_p*.npy'),
160
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
161
- conf_paths_head2 = sorted(glob(traj_path + '/conf2_p*.npy'),
162
- key=lambda x: int(x.split('_p')[-1].split('.')[0]))
163
-
164
- # Process head1
165
- if traj_3d_paths_head1:
166
- if use_float16:
167
- traj_3d_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head1], axis=0)
168
- else:
169
- traj_3d_head1 = onp.stack([onp.load(p) for p in traj_3d_paths_head1], axis=0)
170
-
171
- log_memory_usage("after loading head1 data")
172
-
173
- h, w, _ = traj_3d_head1.shape[1:]
174
- num_frames = traj_3d_head1.shape[0]
175
-
176
- # If masks is None, create default masks (all ones)
177
- if masks is None:
178
- masks = np.ones((num_frames, h, w), dtype=bool)
179
- print(f"Created default masks with shape: {masks.shape}")
180
- else:
181
- # Resize masks to match trajectory dimensions using nearest neighbor interpolation
182
- masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
183
- for i in range(masks.shape[0]):
184
- masks_resized[i] = cv2.resize(
185
- masks[i].astype(np.uint8),
186
- (w, h),
187
- interpolation=cv2.INTER_NEAREST
188
- ).astype(bool)
189
-
190
- print(f"Resized masks shape: {masks_resized.shape}")
191
- masks = masks_resized
192
-
193
- # Reshape trajectory data
194
- traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
195
-
196
- if conf_paths_head1:
197
- conf_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head1], axis=0)
198
- conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
199
- conf_head1 = conf_head1.mean(axis=0)
200
- # repeat the conf_head1 to match the number of frames in the dimension 0
201
- conf_head1 = np.tile(conf_head1, (num_frames, 1))
202
- # Convert to float32 before calculating percentile to avoid overflow
203
- conf_thre = np.percentile(conf_head1.astype(np.float32), conf_thre_percentile)
204
- conf_mask_head1 = conf_head1 > conf_thre
205
-
206
- # Process head2
207
- if traj_3d_paths_head2:
208
- if use_float16:
209
- traj_3d_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head2], axis=0)
210
- else:
211
- traj_3d_head2 = onp.stack([onp.load(p) for p in traj_3d_paths_head2], axis=0)
212
-
213
- log_memory_usage("after loading head2 data")
214
- raw_video = traj_3d_head2[:, :, :, 3:6] # [num_frames, h, w, 3]
215
-
216
- traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
217
- if conf_paths_head2:
218
- conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
219
- conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
220
- # set conf thre to be 10 percentile of the conf_head2, for each frame
221
- conf_thre = np.percentile(conf_head2.astype(np.float32), conf_thre_percentile, axis=1)
222
- conf_mask_head2 = conf_head2 > conf_thre[:, None]
223
-
224
  # Add visualization controls
225
  with server.gui.add_folder("Visualization"):
226
  gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
@@ -286,7 +337,7 @@ def visualize_st4rtrack(
286
  min=1,
287
  max=num_frames,
288
  step=1,
289
- initial_value=1,
290
  disabled=True, # Initially disabled
291
  )
292
 
 
28
  memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB
29
  print(f"Memory usage {message}: {memory_mb:.2f} MB")
30
 
31
+ def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train'):
32
+ """Load trajectory data from files.
33
+
34
+ Args:
35
+ traj_path: Path to the directory containing trajectory data
36
+ use_float16: Whether to convert data to float16 to save memory
37
+ max_frames: Maximum number of frames to load (None for all)
38
+ mask_folder: Path to the directory containing mask images
39
+
40
+ Returns:
41
+ A dictionary containing loaded data
42
+ """
43
+ log_memory_usage("before loading data")
44
+
45
+ data_cache = {
46
+ 'traj_3d_head1': None,
47
+ 'traj_3d_head2': None,
48
+ 'conf_mask_head1': None,
49
+ 'conf_mask_head2': None,
50
+ 'masks': None,
51
+ 'raw_video': None,
52
+ 'loaded': False
53
+ }
54
+
55
+ # Load masks
56
+ masks_paths = sorted(glob(mask_folder + '/*.jpg'))
57
+ masks = None
58
+
59
+ if masks_paths:
60
+ masks = [iio.imread(p) for p in masks_paths]
61
+ masks = np.stack(masks, axis=0)
62
+ # Convert masks to binary (0 or 1)
63
+ masks = (masks < 1).astype(np.float32)
64
+ masks = masks.sum(axis=-1) > 2 # Combine all channels, True where any channel was 1
65
+ print(f"Original masks shape: {masks.shape}")
66
+ else:
67
+ print("No masks found. Will create default masks when needed.")
68
+
69
+ data_cache['masks'] = masks
70
+
71
+ if Path(traj_path).is_dir():
72
+ # Find all trajectory files
73
+ traj_3d_paths_head1 = sorted(glob(traj_path + '/pts3d1_p*.npy'),
74
+ key=lambda x: int(x.split('_p')[-1].split('.')[0]))
75
+ conf_paths_head1 = sorted(glob(traj_path + '/conf1_p*.npy'),
76
+ key=lambda x: int(x.split('_p')[-1].split('.')[0]))
77
+
78
+ traj_3d_paths_head2 = sorted(glob(traj_path + '/pts3d2_p*.npy'),
79
+ key=lambda x: int(x.split('_p')[-1].split('.')[0]))
80
+ conf_paths_head2 = sorted(glob(traj_path + '/conf2_p*.npy'),
81
+ key=lambda x: int(x.split('_p')[-1].split('.')[0]))
82
+
83
+ # Limit number of frames if specified
84
+ if max_frames is not None:
85
+ traj_3d_paths_head1 = traj_3d_paths_head1[:max_frames]
86
+ conf_paths_head1 = conf_paths_head1[:max_frames] if conf_paths_head1 else []
87
+ traj_3d_paths_head2 = traj_3d_paths_head2[:max_frames]
88
+ conf_paths_head2 = conf_paths_head2[:max_frames] if conf_paths_head2 else []
89
+
90
+ # Process head1
91
+ if traj_3d_paths_head1:
92
+ if use_float16:
93
+ traj_3d_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head1], axis=0)
94
+ else:
95
+ traj_3d_head1 = onp.stack([onp.load(p) for p in traj_3d_paths_head1], axis=0)
96
+
97
+ log_memory_usage("after loading head1 data")
98
+
99
+ h, w, _ = traj_3d_head1.shape[1:]
100
+ num_frames = traj_3d_head1.shape[0]
101
+
102
+ # If masks is None, create default masks (all ones)
103
+ if masks is None:
104
+ masks = np.ones((num_frames, h, w), dtype=bool)
105
+ print(f"Created default masks with shape: {masks.shape}")
106
+ data_cache['masks'] = masks
107
+ else:
108
+ # Resize masks to match trajectory dimensions using nearest neighbor interpolation
109
+ masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
110
+ for i in range(masks.shape[0]):
111
+ masks_resized[i] = cv2.resize(
112
+ masks[i].astype(np.uint8),
113
+ (w, h),
114
+ interpolation=cv2.INTER_NEAREST
115
+ ).astype(bool)
116
+
117
+ print(f"Resized masks shape: {masks_resized.shape}")
118
+ data_cache['masks'] = masks_resized
119
+
120
+ # Reshape trajectory data
121
+ traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
122
+ data_cache['traj_3d_head1'] = traj_3d_head1
123
+
124
+ if conf_paths_head1:
125
+ conf_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head1], axis=0)
126
+ conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
127
+ conf_head1 = conf_head1.mean(axis=0)
128
+ # repeat the conf_head1 to match the number of frames in the dimension 0
129
+ conf_head1 = np.tile(conf_head1, (num_frames, 1))
130
+ # Convert to float32 before calculating percentile to avoid overflow
131
+ conf_thre = np.percentile(conf_head1.astype(np.float32), 1) # Default percentile
132
+ conf_mask_head1 = conf_head1 > conf_thre
133
+ data_cache['conf_mask_head1'] = conf_mask_head1
134
+
135
+ # Process head2
136
+ if traj_3d_paths_head2:
137
+ if use_float16:
138
+ traj_3d_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head2], axis=0)
139
+ else:
140
+ traj_3d_head2 = onp.stack([onp.load(p) for p in traj_3d_paths_head2], axis=0)
141
+
142
+ log_memory_usage("after loading head2 data")
143
+
144
+ # Store raw video data
145
+ raw_video = traj_3d_head2[:, :, :, 3:6] # [num_frames, h, w, 3]
146
+ data_cache['raw_video'] = raw_video
147
+
148
+ traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
149
+ data_cache['traj_3d_head2'] = traj_3d_head2
150
+
151
+ if conf_paths_head2:
152
+ conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
153
+ conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
154
+ # set conf thre to be 1 percentile of the conf_head2, for each frame
155
+ conf_thre = np.percentile(conf_head2.astype(np.float32), 1, axis=1)
156
+ conf_mask_head2 = conf_head2 > conf_thre[:, None]
157
+ data_cache['conf_mask_head2'] = conf_mask_head2
158
+
159
+ data_cache['loaded'] = True
160
+ log_memory_usage("after loading all data")
161
+ return data_cache
162
+
163
  def visualize_st4rtrack(
164
  traj_path: str = "results",
165
  up_dir: str = "-z", # should be +z or -z
 
213
  format="jpeg"
214
  )
215
 
216
+ # Use preloaded data if available
217
+ if preloaded_data and preloaded_data.get('loaded', False):
218
+ traj_3d_head1 = preloaded_data.get('traj_3d_head1')
219
+ traj_3d_head2 = preloaded_data.get('traj_3d_head2')
220
+ conf_mask_head1 = preloaded_data.get('conf_mask_head1')
221
+ conf_mask_head2 = preloaded_data.get('conf_mask_head2')
222
+ masks = preloaded_data.get('masks')
223
+ raw_video = preloaded_data.get('raw_video')
224
+ print("Using preloaded data!")
225
+ else:
226
+ # Load data using the shared function
227
+ print("No preloaded data available, loading from files...")
228
+ data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder)
229
+ traj_3d_head1 = data.get('traj_3d_head1')
230
+ traj_3d_head2 = data.get('traj_3d_head2')
231
+ conf_mask_head1 = data.get('conf_mask_head1')
232
+ conf_mask_head2 = data.get('conf_mask_head2')
233
+ masks = data.get('masks')
234
+ raw_video = data.get('raw_video')
235
+
236
  def process_video_frame(frame_idx):
237
  if raw_video is None:
238
  return np.zeros((video_height, video_width, 3), dtype=np.uint8)
 
272
  server.scene.set_up_direction(up_dir)
273
  print("Setting up visualization!")
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  # Add visualization controls
276
  with server.gui.add_folder("Visualization"):
277
  gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
 
337
  min=1,
338
  max=num_frames,
339
  step=1,
340
+ initial_value=5,
341
  disabled=True, # Initially disabled
342
  )
343