Spaces:
Configuration error
Configuration error
Commit ·
4c075ec
1
Parent(s): 6805b8e
add saving and reloading of session
Browse files- app.py +229 -50
- configs/stream_session.json +3 -0
- stream3r/stream_session.py +33 -4
- tests/test_stream_session_cache.py +17 -0
app.py
CHANGED
|
@@ -64,7 +64,13 @@ def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]:
|
|
| 64 |
if ext not in ALLOWED_IMG_EXT:
|
| 65 |
continue
|
| 66 |
# Construct final path safely
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# Zip-slip guard (in case filename has ../ etc.)
|
| 69 |
if not _is_within_dir(outdir, dest_path):
|
| 70 |
continue
|
|
@@ -74,19 +80,82 @@ def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]:
|
|
| 74 |
return extracted
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# -------------------------------------------------------------------------
|
| 78 |
# 1) Core model inference
|
| 79 |
# -------------------------------------------------------------------------
|
| 80 |
@spaces.GPU(duration=180) # triggers ZeroGPU allocation for this call
|
| 81 |
-
def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: bool=False) -> dict:
|
| 82 |
"""
|
| 83 |
-
Run the STream3R model on images in the 'target_dir/images' folder
|
| 84 |
|
| 85 |
Args:
|
| 86 |
target_dir: Directory containing the images subfolder
|
| 87 |
model: STream3R model instance
|
| 88 |
mode: Processing mode ("causal", "window", or "full")
|
| 89 |
streaming: If True, use StreamSession for sequential processing; if False, use batch processing
|
|
|
|
|
|
|
|
|
|
| 90 |
"""
|
| 91 |
print(f"Processing images from {target_dir}")
|
| 92 |
|
|
@@ -113,6 +182,8 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
|
|
| 113 |
print(f"Running inference in {'streaming' if streaming else 'batch'} mode...")
|
| 114 |
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
| 115 |
|
|
|
|
|
|
|
| 116 |
with torch.no_grad():
|
| 117 |
with torch.amp.autocast(dtype=dtype, device_type=device):
|
| 118 |
if streaming:
|
|
@@ -123,12 +194,34 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
|
|
| 123 |
|
| 124 |
session = StreamSession(model, mode=mode)
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
session.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
else:
|
| 133 |
# Use batch processing (original behavior)
|
| 134 |
predictions = model(images, mode=mode)
|
|
@@ -153,19 +246,20 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
|
|
| 153 |
|
| 154 |
# Clean up
|
| 155 |
torch.cuda.empty_cache()
|
| 156 |
-
return predictions
|
| 157 |
|
| 158 |
|
| 159 |
# -------------------------------------------------------------------------
|
| 160 |
# 2) Handle uploaded video/images --> produce target_dir + images
|
| 161 |
# -------------------------------------------------------------------------
|
| 162 |
-
def handle_uploads(input_video, input_images, input_zip=None):
|
| 163 |
"""
|
| 164 |
Create a new 'target_dir' + 'images' subfolder.
|
| 165 |
- Copies uploaded images
|
| 166 |
- Optionally extracts images from a ZIP
|
| 167 |
- Optionally extracts frames from a video (1 fps)
|
| 168 |
-
|
|
|
|
| 169 |
"""
|
| 170 |
start_time = time.time()
|
| 171 |
gc.collect()
|
|
@@ -173,11 +267,23 @@ def handle_uploads(input_video, input_images, input_zip=None):
|
|
| 173 |
|
| 174 |
# Create a unique folder name
|
| 175 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
-
|
| 180 |
-
shutil.rmtree(target_dir)
|
| 181 |
os.makedirs(target_dir_images, exist_ok=True)
|
| 182 |
|
| 183 |
image_paths: list[str] = []
|
|
@@ -186,9 +292,8 @@ def handle_uploads(input_video, input_images, input_zip=None):
|
|
| 186 |
if input_images:
|
| 187 |
for file_data in input_images:
|
| 188 |
file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else file_data
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
image_paths.append(dst_path)
|
| 192 |
|
| 193 |
# --- Handle ZIP (extract images) ---
|
| 194 |
if input_zip:
|
|
@@ -203,7 +308,7 @@ def handle_uploads(input_video, input_images, input_zip=None):
|
|
| 203 |
fps = vs.get(cv2.CAP_PROP_FPS) or 30.0
|
| 204 |
frame_interval = max(1, int(fps * 1)) # 1 frame/sec
|
| 205 |
count = 0
|
| 206 |
-
video_frame_num =
|
| 207 |
while True:
|
| 208 |
gotit, frame = vs.read()
|
| 209 |
if not gotit:
|
|
@@ -218,23 +323,44 @@ def handle_uploads(input_video, input_images, input_zip=None):
|
|
| 218 |
|
| 219 |
image_paths = sorted(set(image_paths)) # de-dupe + sort
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
end_time = time.time()
|
| 222 |
print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {end_time - start_time:.3f}s")
|
| 223 |
-
return target_dir, image_paths
|
| 224 |
|
| 225 |
|
| 226 |
|
| 227 |
# -------------------------------------------------------------------------
|
| 228 |
# 3) Update gallery on upload
|
| 229 |
# -------------------------------------------------------------------------
|
| 230 |
-
def update_gallery_on_upload(input_video, input_images, input_zip):
|
| 231 |
"""
|
| 232 |
Handle any new uploads (video, images, or zip) and render preview.
|
| 233 |
"""
|
| 234 |
-
if not input_video and not input_images and not input_zip:
|
| 235 |
-
return None, None, None, None
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
|
| 240 |
|
|
@@ -271,12 +397,19 @@ def gradio_demo(
|
|
| 271 |
|
| 272 |
print("Running run_model...")
|
| 273 |
with torch.no_grad():
|
| 274 |
-
predictions = run_model(target_dir, model, mode=mode, streaming=streaming)
|
| 275 |
|
| 276 |
# Save predictions
|
| 277 |
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
| 278 |
np.savez(prediction_save_path, **predictions)
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
# Handle None frame_filter
|
| 281 |
if frame_filter is None:
|
| 282 |
frame_filter = "All"
|
|
@@ -310,7 +443,12 @@ def gradio_demo(
|
|
| 310 |
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
| 311 |
log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
| 312 |
|
| 313 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
|
| 316 |
# -------------------------------------------------------------------------
|
|
@@ -331,7 +469,16 @@ def update_log():
|
|
| 331 |
|
| 332 |
|
| 333 |
def update_visualization(
|
| 334 |
-
target_dir,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
):
|
| 336 |
"""
|
| 337 |
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
|
@@ -364,9 +511,10 @@ def update_visualization(
|
|
| 364 |
loaded = np.load(predictions_path)
|
| 365 |
predictions = {key: np.array(loaded[key]) for key in key_list}
|
| 366 |
|
|
|
|
| 367 |
glbfile = os.path.join(
|
| 368 |
target_dir,
|
| 369 |
-
f"glbscene_{conf_thres}_{
|
| 370 |
)
|
| 371 |
|
| 372 |
if not os.path.exists(glbfile):
|
|
@@ -504,6 +652,7 @@ with gr.Blocks(
|
|
| 504 |
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 505 |
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
|
| 506 |
input_zip = gr.File(file_types=[".zip"], label="Upload ZIP of Images", interactive=True)
|
|
|
|
| 507 |
|
| 508 |
image_gallery = gr.Gallery(
|
| 509 |
label="Preview",
|
|
@@ -521,11 +670,22 @@ with gr.Blocks(
|
|
| 521 |
"Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
|
| 522 |
)
|
| 523 |
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
|
|
|
|
| 524 |
|
| 525 |
with gr.Row():
|
| 526 |
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 527 |
clear_btn = gr.ClearButton(
|
| 528 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
scale=1,
|
| 530 |
)
|
| 531 |
|
|
@@ -626,13 +786,22 @@ with gr.Blocks(
|
|
| 626 |
3) Return model3D + logs + new_dir + updated dropdown + gallery
|
| 627 |
We do NOT return is_example. It's just an input.
|
| 628 |
"""
|
| 629 |
-
target_dir, image_paths = handle_uploads(input_video, input_images)
|
| 630 |
# Always use "All" for frame_filter in examples
|
| 631 |
frame_filter = "All"
|
| 632 |
-
glbfile, log_msg, dropdown = gradio_demo(
|
| 633 |
-
target_dir,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
)
|
| 635 |
-
return glbfile, log_msg, target_dir, dropdown, image_paths
|
| 636 |
|
| 637 |
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
|
| 638 |
|
|
@@ -652,7 +821,14 @@ with gr.Blocks(
|
|
| 652 |
is_example,
|
| 653 |
mode,
|
| 654 |
],
|
| 655 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
fn=example_pipeline,
|
| 657 |
cache_examples=False,
|
| 658 |
examples_per_page=50,
|
|
@@ -681,7 +857,7 @@ with gr.Blocks(
|
|
| 681 |
mode,
|
| 682 |
streaming,
|
| 683 |
],
|
| 684 |
-
outputs=[reconstruction_output, log_output, frame_filter],
|
| 685 |
).then(
|
| 686 |
fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
|
| 687 |
)
|
|
@@ -700,6 +876,7 @@ with gr.Blocks(
|
|
| 700 |
show_cam,
|
| 701 |
mask_sky,
|
| 702 |
prediction_mode,
|
|
|
|
| 703 |
is_example,
|
| 704 |
],
|
| 705 |
[reconstruction_output, log_output],
|
|
@@ -715,6 +892,7 @@ with gr.Blocks(
|
|
| 715 |
show_cam,
|
| 716 |
mask_sky,
|
| 717 |
prediction_mode,
|
|
|
|
| 718 |
is_example,
|
| 719 |
],
|
| 720 |
[reconstruction_output, log_output],
|
|
@@ -730,6 +908,7 @@ with gr.Blocks(
|
|
| 730 |
show_cam,
|
| 731 |
mask_sky,
|
| 732 |
prediction_mode,
|
|
|
|
| 733 |
is_example,
|
| 734 |
],
|
| 735 |
[reconstruction_output, log_output],
|
|
@@ -745,6 +924,7 @@ with gr.Blocks(
|
|
| 745 |
show_cam,
|
| 746 |
mask_sky,
|
| 747 |
prediction_mode,
|
|
|
|
| 748 |
is_example,
|
| 749 |
],
|
| 750 |
[reconstruction_output, log_output],
|
|
@@ -760,6 +940,7 @@ with gr.Blocks(
|
|
| 760 |
show_cam,
|
| 761 |
mask_sky,
|
| 762 |
prediction_mode,
|
|
|
|
| 763 |
is_example,
|
| 764 |
],
|
| 765 |
[reconstruction_output, log_output],
|
|
@@ -775,6 +956,7 @@ with gr.Blocks(
|
|
| 775 |
show_cam,
|
| 776 |
mask_sky,
|
| 777 |
prediction_mode,
|
|
|
|
| 778 |
is_example,
|
| 779 |
],
|
| 780 |
[reconstruction_output, log_output],
|
|
@@ -790,6 +972,7 @@ with gr.Blocks(
|
|
| 790 |
show_cam,
|
| 791 |
mask_sky,
|
| 792 |
prediction_mode,
|
|
|
|
| 793 |
is_example,
|
| 794 |
],
|
| 795 |
[reconstruction_output, log_output],
|
|
@@ -798,20 +981,16 @@ with gr.Blocks(
|
|
| 798 |
# -------------------------------------------------------------------------
|
| 799 |
# Auto-update gallery whenever user uploads or changes their files
|
| 800 |
# -------------------------------------------------------------------------
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
)
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
inputs=[input_video, input_images, input_zip],
|
| 809 |
-
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 810 |
-
)
|
| 811 |
-
input_zip.change(
|
| 812 |
fn=update_gallery_on_upload,
|
| 813 |
-
inputs=[input_video, input_images, input_zip],
|
| 814 |
-
outputs=
|
| 815 |
)
|
| 816 |
|
| 817 |
demo.queue(max_size=20).launch(show_error=True, share=False)
|
|
|
|
| 64 |
if ext not in ALLOWED_IMG_EXT:
|
| 65 |
continue
|
| 66 |
# Construct final path safely
|
| 67 |
+
base_name = os.path.basename(name)
|
| 68 |
+
name_root, name_ext = os.path.splitext(base_name)
|
| 69 |
+
dest_path = os.path.join(outdir, base_name)
|
| 70 |
+
counter = 1
|
| 71 |
+
while os.path.exists(dest_path):
|
| 72 |
+
dest_path = os.path.join(outdir, f"{name_root}_{counter}{name_ext}")
|
| 73 |
+
counter += 1
|
| 74 |
# Zip-slip guard (in case filename has ../ etc.)
|
| 75 |
if not _is_within_dir(outdir, dest_path):
|
| 76 |
continue
|
|
|
|
| 80 |
return extracted
|
| 81 |
|
| 82 |
|
| 83 |
+
def extract_session_state(zip_path: str, extract_root: str) -> str:
|
| 84 |
+
"""Extract a previously saved session archive into *extract_root*.
|
| 85 |
+
|
| 86 |
+
Returns the directory that contains the restored session data.
|
| 87 |
+
"""
|
| 88 |
+
if os.path.exists(extract_root):
|
| 89 |
+
shutil.rmtree(extract_root)
|
| 90 |
+
os.makedirs(extract_root, exist_ok=True)
|
| 91 |
+
|
| 92 |
+
with zipfile.ZipFile(zip_path, "r") as zf:
|
| 93 |
+
zf.extractall(extract_root)
|
| 94 |
+
|
| 95 |
+
entries = [os.path.join(extract_root, entry) for entry in os.listdir(extract_root)]
|
| 96 |
+
dirs = [entry for entry in entries if os.path.isdir(entry)]
|
| 97 |
+
files = [entry for entry in entries if os.path.isfile(entry)]
|
| 98 |
+
|
| 99 |
+
if len(dirs) == 1 and not files:
|
| 100 |
+
return dirs[0]
|
| 101 |
+
return extract_root
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def package_session_state(target_dir: str) -> str:
|
| 105 |
+
"""Create a zip archive containing the entire session directory."""
|
| 106 |
+
if not os.path.isdir(target_dir):
|
| 107 |
+
raise ValueError(f"Target directory does not exist: {target_dir}")
|
| 108 |
+
|
| 109 |
+
os.makedirs("demo_cache", exist_ok=True)
|
| 110 |
+
archive_name = f"{os.path.basename(os.path.normpath(target_dir))}_session.zip"
|
| 111 |
+
archive_path = os.path.join("demo_cache", archive_name)
|
| 112 |
+
|
| 113 |
+
if os.path.exists(archive_path):
|
| 114 |
+
os.remove(archive_path)
|
| 115 |
+
|
| 116 |
+
with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
| 117 |
+
for root, _, files in os.walk(target_dir):
|
| 118 |
+
for fname in files:
|
| 119 |
+
file_path = os.path.join(root, fname)
|
| 120 |
+
if os.path.abspath(file_path) == os.path.abspath(archive_path):
|
| 121 |
+
continue
|
| 122 |
+
arcname = os.path.join(os.path.basename(target_dir), os.path.relpath(file_path, target_dir))
|
| 123 |
+
zf.write(file_path, arcname)
|
| 124 |
+
|
| 125 |
+
return archive_path
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _copy_with_unique_name(src_path: str, dst_dir: str) -> str:
|
| 129 |
+
"""Copy *src_path* into *dst_dir*, avoiding filename collisions."""
|
| 130 |
+
base_name = os.path.basename(src_path)
|
| 131 |
+
name, ext = os.path.splitext(base_name)
|
| 132 |
+
candidate = base_name
|
| 133 |
+
counter = 1
|
| 134 |
+
dest_path = os.path.join(dst_dir, candidate)
|
| 135 |
+
while os.path.exists(dest_path):
|
| 136 |
+
candidate = f"{name}_{counter}{ext}"
|
| 137 |
+
dest_path = os.path.join(dst_dir, candidate)
|
| 138 |
+
counter += 1
|
| 139 |
+
shutil.copy(src_path, dest_path)
|
| 140 |
+
return dest_path
|
| 141 |
+
|
| 142 |
+
|
| 143 |
# -------------------------------------------------------------------------
|
| 144 |
# 1) Core model inference
|
| 145 |
# -------------------------------------------------------------------------
|
| 146 |
@spaces.GPU(duration=180) # triggers ZeroGPU allocation for this call
|
| 147 |
+
def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: bool=False) -> tuple[dict, str | None]:
|
| 148 |
"""
|
| 149 |
+
Run the STream3R model on images in the 'target_dir/images' folder.
|
| 150 |
|
| 151 |
Args:
|
| 152 |
target_dir: Directory containing the images subfolder
|
| 153 |
model: STream3R model instance
|
| 154 |
mode: Processing mode ("causal", "window", or "full")
|
| 155 |
streaming: If True, use StreamSession for sequential processing; if False, use batch processing
|
| 156 |
+
Returns:
|
| 157 |
+
tuple[dict, str | None]: Predictions dictionary and optional path to the saved session cache when
|
| 158 |
+
streaming mode is used.
|
| 159 |
"""
|
| 160 |
print(f"Processing images from {target_dir}")
|
| 161 |
|
|
|
|
| 182 |
print(f"Running inference in {'streaming' if streaming else 'batch'} mode...")
|
| 183 |
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
| 184 |
|
| 185 |
+
session_cache_path: str | None = None
|
| 186 |
+
|
| 187 |
with torch.no_grad():
|
| 188 |
with torch.amp.autocast(dtype=dtype, device_type=device):
|
| 189 |
if streaming:
|
|
|
|
| 194 |
|
| 195 |
session = StreamSession(model, mode=mode)
|
| 196 |
|
| 197 |
+
kv_cache_path = os.path.join(target_dir, "kv_cache.pt")
|
| 198 |
+
if os.path.exists(kv_cache_path):
|
| 199 |
+
print(f"Loading existing session cache from {kv_cache_path}")
|
| 200 |
+
session.load_cache(kv_cache_path, device=images.device)
|
| 201 |
+
|
| 202 |
+
existing_predictions = session.get_all_predictions()
|
| 203 |
+
existing_frames = 0
|
| 204 |
+
for value in existing_predictions.values():
|
| 205 |
+
if isinstance(value, torch.Tensor) and value.dim() >= 2:
|
| 206 |
+
existing_frames = max(existing_frames, value.shape[1])
|
| 207 |
+
|
| 208 |
+
total_frames = images.shape[0]
|
| 209 |
+
if existing_frames > total_frames:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
"Session cache contains more frames than available images. Please ensure the images folder "
|
| 212 |
+
"matches the saved session state."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if existing_frames == total_frames:
|
| 216 |
+
print("No new frames detected; reusing cached predictions.")
|
| 217 |
+
else:
|
| 218 |
+
for i in range(existing_frames, total_frames):
|
| 219 |
+
image = images[i : i + 1]
|
| 220 |
+
session.forward_stream(image)
|
| 221 |
+
|
| 222 |
+
predictions = session.get_all_predictions()
|
| 223 |
+
session.save_cache(kv_cache_path)
|
| 224 |
+
session_cache_path = kv_cache_path
|
| 225 |
else:
|
| 226 |
# Use batch processing (original behavior)
|
| 227 |
predictions = model(images, mode=mode)
|
|
|
|
| 246 |
|
| 247 |
# Clean up
|
| 248 |
torch.cuda.empty_cache()
|
| 249 |
+
return predictions, session_cache_path
|
| 250 |
|
| 251 |
|
| 252 |
# -------------------------------------------------------------------------
|
| 253 |
# 2) Handle uploaded video/images --> produce target_dir + images
|
| 254 |
# -------------------------------------------------------------------------
|
| 255 |
+
def handle_uploads(input_video, input_images, input_zip=None, session_state=None, current_target_dir: str | None = None):
|
| 256 |
"""
|
| 257 |
Create a new 'target_dir' + 'images' subfolder.
|
| 258 |
- Copies uploaded images
|
| 259 |
- Optionally extracts images from a ZIP
|
| 260 |
- Optionally extracts frames from a video (1 fps)
|
| 261 |
+
- Optionally loads a previously saved session archive
|
| 262 |
+
Returns (target_dir, image_paths, session_loaded).
|
| 263 |
"""
|
| 264 |
start_time = time.time()
|
| 265 |
gc.collect()
|
|
|
|
| 267 |
|
| 268 |
# Create a unique folder name
|
| 269 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 270 |
+
session_loaded = False
|
| 271 |
+
|
| 272 |
+
if session_state:
|
| 273 |
+
session_path = session_state.get("name") if isinstance(session_state, dict) and "name" in session_state else getattr(session_state, "name", None)
|
| 274 |
+
session_path = session_path or session_state
|
| 275 |
+
extract_root = os.path.join("demo_cache", f"session_{timestamp}")
|
| 276 |
+
target_dir = extract_session_state(session_path, extract_root)
|
| 277 |
+
session_loaded = True
|
| 278 |
+
elif current_target_dir and os.path.isdir(current_target_dir):
|
| 279 |
+
target_dir = current_target_dir
|
| 280 |
+
else:
|
| 281 |
+
target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
|
| 282 |
+
if os.path.exists(target_dir):
|
| 283 |
+
shutil.rmtree(target_dir)
|
| 284 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 285 |
|
| 286 |
+
target_dir_images = os.path.join(target_dir, "images")
|
|
|
|
| 287 |
os.makedirs(target_dir_images, exist_ok=True)
|
| 288 |
|
| 289 |
image_paths: list[str] = []
|
|
|
|
| 292 |
if input_images:
|
| 293 |
for file_data in input_images:
|
| 294 |
file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else file_data
|
| 295 |
+
copied_path = _copy_with_unique_name(file_path, target_dir_images)
|
| 296 |
+
image_paths.append(copied_path)
|
|
|
|
| 297 |
|
| 298 |
# --- Handle ZIP (extract images) ---
|
| 299 |
if input_zip:
|
|
|
|
| 308 |
fps = vs.get(cv2.CAP_PROP_FPS) or 30.0
|
| 309 |
frame_interval = max(1, int(fps * 1)) # 1 frame/sec
|
| 310 |
count = 0
|
| 311 |
+
video_frame_num = len(os.listdir(target_dir_images))
|
| 312 |
while True:
|
| 313 |
gotit, frame = vs.read()
|
| 314 |
if not gotit:
|
|
|
|
| 323 |
|
| 324 |
image_paths = sorted(set(image_paths)) # de-dupe + sort
|
| 325 |
|
| 326 |
+
# Ensure gallery reflects existing files in the images directory
|
| 327 |
+
existing_images = sorted(glob.glob(os.path.join(target_dir_images, "*")))
|
| 328 |
+
image_paths = existing_images
|
| 329 |
+
|
| 330 |
end_time = time.time()
|
| 331 |
print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {end_time - start_time:.3f}s")
|
| 332 |
+
return target_dir, image_paths, session_loaded
|
| 333 |
|
| 334 |
|
| 335 |
|
| 336 |
# -------------------------------------------------------------------------
|
| 337 |
# 3) Update gallery on upload
|
| 338 |
# -------------------------------------------------------------------------
|
| 339 |
+
def update_gallery_on_upload(input_video, input_images, input_zip, session_state, current_target_dir):
|
| 340 |
"""
|
| 341 |
Handle any new uploads (video, images, or zip) and render preview.
|
| 342 |
"""
|
| 343 |
+
if not input_video and not input_images and not input_zip and not session_state:
|
| 344 |
+
return None, current_target_dir, None, None, None
|
| 345 |
+
|
| 346 |
+
target_dir, image_paths, session_loaded = handle_uploads(
|
| 347 |
+
input_video,
|
| 348 |
+
input_images,
|
| 349 |
+
input_zip,
|
| 350 |
+
session_state=session_state,
|
| 351 |
+
current_target_dir=current_target_dir,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
if session_loaded:
|
| 355 |
+
message = "Session state loaded. Add new frames and click 'Reconstruct' to continue."
|
| 356 |
+
else:
|
| 357 |
+
message = "Upload complete. Click 'Reconstruct' to begin 3D processing."
|
| 358 |
+
|
| 359 |
+
return None, target_dir, image_paths, message, None
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def update_gallery_without_session(input_video, input_images, input_zip, current_target_dir):
|
| 363 |
+
return update_gallery_on_upload(input_video, input_images, input_zip, None, current_target_dir)
|
| 364 |
|
| 365 |
|
| 366 |
|
|
|
|
| 397 |
|
| 398 |
print("Running run_model...")
|
| 399 |
with torch.no_grad():
|
| 400 |
+
predictions, session_cache_path = run_model(target_dir, model, mode=mode, streaming=streaming)
|
| 401 |
|
| 402 |
# Save predictions
|
| 403 |
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
| 404 |
np.savez(prediction_save_path, **predictions)
|
| 405 |
|
| 406 |
+
session_state_file = None
|
| 407 |
+
if streaming:
|
| 408 |
+
if session_cache_path is None:
|
| 409 |
+
session_cache_path = os.path.join(target_dir, "kv_cache.pt")
|
| 410 |
+
if os.path.exists(session_cache_path):
|
| 411 |
+
session_state_file = package_session_state(target_dir)
|
| 412 |
+
|
| 413 |
# Handle None frame_filter
|
| 414 |
if frame_filter is None:
|
| 415 |
frame_filter = "All"
|
|
|
|
| 443 |
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
| 444 |
log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
| 445 |
|
| 446 |
+
return (
|
| 447 |
+
glbfile,
|
| 448 |
+
log_msg,
|
| 449 |
+
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
|
| 450 |
+
session_state_file,
|
| 451 |
+
)
|
| 452 |
|
| 453 |
|
| 454 |
# -------------------------------------------------------------------------
|
|
|
|
| 469 |
|
| 470 |
|
| 471 |
def update_visualization(
|
| 472 |
+
target_dir,
|
| 473 |
+
conf_thres,
|
| 474 |
+
frame_filter,
|
| 475 |
+
mask_black_bg,
|
| 476 |
+
mask_white_bg,
|
| 477 |
+
show_cam,
|
| 478 |
+
mask_sky,
|
| 479 |
+
prediction_mode,
|
| 480 |
+
mode_value,
|
| 481 |
+
is_example,
|
| 482 |
):
|
| 483 |
"""
|
| 484 |
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
|
|
|
| 511 |
loaded = np.load(predictions_path)
|
| 512 |
predictions = {key: np.array(loaded[key]) for key in key_list}
|
| 513 |
|
| 514 |
+
sanitized_frame = frame_filter.replace('.', '_').replace(':', '').replace(' ', '_') if frame_filter else "All"
|
| 515 |
glbfile = os.path.join(
|
| 516 |
target_dir,
|
| 517 |
+
f"glbscene_{conf_thres}_{sanitized_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode_value}.glb",
|
| 518 |
)
|
| 519 |
|
| 520 |
if not os.path.exists(glbfile):
|
|
|
|
| 652 |
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 653 |
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
|
| 654 |
input_zip = gr.File(file_types=[".zip"], label="Upload ZIP of Images", interactive=True)
|
| 655 |
+
session_state_input = gr.File(file_types=[".zip"], label="Load Session State", interactive=True)
|
| 656 |
|
| 657 |
image_gallery = gr.Gallery(
|
| 658 |
label="Preview",
|
|
|
|
| 670 |
"Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
|
| 671 |
)
|
| 672 |
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
|
| 673 |
+
session_state_output = gr.File(label="Download Session State", interactive=False)
|
| 674 |
|
| 675 |
with gr.Row():
|
| 676 |
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 677 |
clear_btn = gr.ClearButton(
|
| 678 |
+
[
|
| 679 |
+
input_video,
|
| 680 |
+
input_images,
|
| 681 |
+
input_zip,
|
| 682 |
+
session_state_input,
|
| 683 |
+
reconstruction_output,
|
| 684 |
+
log_output,
|
| 685 |
+
target_dir_output,
|
| 686 |
+
image_gallery,
|
| 687 |
+
session_state_output,
|
| 688 |
+
],
|
| 689 |
scale=1,
|
| 690 |
)
|
| 691 |
|
|
|
|
| 786 |
3) Return model3D + logs + new_dir + updated dropdown + gallery
|
| 787 |
We do NOT return is_example. It's just an input.
|
| 788 |
"""
|
| 789 |
+
target_dir, image_paths, _ = handle_uploads(input_video, input_images)
|
| 790 |
# Always use "All" for frame_filter in examples
|
| 791 |
frame_filter = "All"
|
| 792 |
+
glbfile, log_msg, dropdown, session_file = gradio_demo(
|
| 793 |
+
target_dir,
|
| 794 |
+
conf_thres,
|
| 795 |
+
frame_filter,
|
| 796 |
+
mask_black_bg,
|
| 797 |
+
mask_white_bg,
|
| 798 |
+
show_cam,
|
| 799 |
+
mask_sky,
|
| 800 |
+
prediction_mode,
|
| 801 |
+
mode,
|
| 802 |
+
False,
|
| 803 |
)
|
| 804 |
+
return glbfile, log_msg, target_dir, dropdown, image_paths, session_file
|
| 805 |
|
| 806 |
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
|
| 807 |
|
|
|
|
| 821 |
is_example,
|
| 822 |
mode,
|
| 823 |
],
|
| 824 |
+
outputs=[
|
| 825 |
+
reconstruction_output,
|
| 826 |
+
log_output,
|
| 827 |
+
target_dir_output,
|
| 828 |
+
frame_filter,
|
| 829 |
+
image_gallery,
|
| 830 |
+
session_state_output,
|
| 831 |
+
],
|
| 832 |
fn=example_pipeline,
|
| 833 |
cache_examples=False,
|
| 834 |
examples_per_page=50,
|
|
|
|
| 857 |
mode,
|
| 858 |
streaming,
|
| 859 |
],
|
| 860 |
+
outputs=[reconstruction_output, log_output, frame_filter, session_state_output],
|
| 861 |
).then(
|
| 862 |
fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
|
| 863 |
)
|
|
|
|
| 876 |
show_cam,
|
| 877 |
mask_sky,
|
| 878 |
prediction_mode,
|
| 879 |
+
mode,
|
| 880 |
is_example,
|
| 881 |
],
|
| 882 |
[reconstruction_output, log_output],
|
|
|
|
| 892 |
show_cam,
|
| 893 |
mask_sky,
|
| 894 |
prediction_mode,
|
| 895 |
+
mode,
|
| 896 |
is_example,
|
| 897 |
],
|
| 898 |
[reconstruction_output, log_output],
|
|
|
|
| 908 |
show_cam,
|
| 909 |
mask_sky,
|
| 910 |
prediction_mode,
|
| 911 |
+
mode,
|
| 912 |
is_example,
|
| 913 |
],
|
| 914 |
[reconstruction_output, log_output],
|
|
|
|
| 924 |
show_cam,
|
| 925 |
mask_sky,
|
| 926 |
prediction_mode,
|
| 927 |
+
mode,
|
| 928 |
is_example,
|
| 929 |
],
|
| 930 |
[reconstruction_output, log_output],
|
|
|
|
| 940 |
show_cam,
|
| 941 |
mask_sky,
|
| 942 |
prediction_mode,
|
| 943 |
+
mode,
|
| 944 |
is_example,
|
| 945 |
],
|
| 946 |
[reconstruction_output, log_output],
|
|
|
|
| 956 |
show_cam,
|
| 957 |
mask_sky,
|
| 958 |
prediction_mode,
|
| 959 |
+
mode,
|
| 960 |
is_example,
|
| 961 |
],
|
| 962 |
[reconstruction_output, log_output],
|
|
|
|
| 972 |
show_cam,
|
| 973 |
mask_sky,
|
| 974 |
prediction_mode,
|
| 975 |
+
mode,
|
| 976 |
is_example,
|
| 977 |
],
|
| 978 |
[reconstruction_output, log_output],
|
|
|
|
| 981 |
# -------------------------------------------------------------------------
|
| 982 |
# Auto-update gallery whenever user uploads or changes their files
|
| 983 |
# -------------------------------------------------------------------------
|
| 984 |
+
upload_outputs = [reconstruction_output, target_dir_output, image_gallery, log_output, session_state_output]
|
| 985 |
+
no_session_inputs = [input_video, input_images, input_zip, target_dir_output]
|
| 986 |
+
|
| 987 |
+
input_video.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
|
| 988 |
+
input_images.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
|
| 989 |
+
input_zip.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
|
| 990 |
+
session_state_input.change(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 991 |
fn=update_gallery_on_upload,
|
| 992 |
+
inputs=[input_video, input_images, input_zip, session_state_input, target_dir_output],
|
| 993 |
+
outputs=upload_outputs,
|
| 994 |
)
|
| 995 |
|
| 996 |
demo.queue(max_size=20).launch(show_error=True, share=False)
|
configs/stream_session.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"window_size": 25
|
| 3 |
+
}
|
stream3r/stream_session.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import Any, Dict, Optional
|
| 3 |
|
|
@@ -9,12 +10,13 @@ class StreamSession:
|
|
| 9 |
"""
|
| 10 |
A causal streaming inference session with KV cache management for STream3R.
|
| 11 |
"""
|
| 12 |
-
def __init__(self, model: STream3R, mode: str):
|
| 13 |
self.model = model
|
| 14 |
self.mode = mode
|
| 15 |
self.aggregator_kv_cache_depth = model.aggregator.depth
|
| 16 |
self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
|
| 17 |
self.camera_head_iterations = 4
|
|
|
|
| 18 |
|
| 19 |
if self.mode not in ["causal", "window"]:
|
| 20 |
raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")
|
|
@@ -41,13 +43,12 @@ class StreamSession:
|
|
| 41 |
self.aggregator_kv_cache_list = aggregator_kv_cache_list
|
| 42 |
self.camera_head_kv_cache_list = camera_head_kv_cache_list
|
| 43 |
elif self.mode == "window":
|
| 44 |
-
window_size = 25
|
| 45 |
for k in range(2):
|
| 46 |
for i in range(self.aggregator_kv_cache_depth):
|
| 47 |
h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3]
|
| 48 |
P = h * w // self.model.aggregator.patch_size // self.model.aggregator.patch_size + self.model.aggregator.patch_start_idx
|
| 49 |
anchor_token = aggregator_kv_cache_list[i][k][:, :, :P]
|
| 50 |
-
window_tokens = aggregator_kv_cache_list[i][k][:, :, max(P, aggregator_kv_cache_list[i][k].size(2)-window_size*P):]
|
| 51 |
self.aggregator_kv_cache_list[i][k] = torch.cat(
|
| 52 |
[
|
| 53 |
anchor_token,
|
|
@@ -58,7 +59,7 @@ class StreamSession:
|
|
| 58 |
for i in range(self.camera_head_iterations):
|
| 59 |
for j in range(self.camera_head_kv_cache_depth):
|
| 60 |
anchor_token = camera_head_kv_cache_list[i][j][k][:, :, :1]
|
| 61 |
-
window_tokens = camera_head_kv_cache_list[i][j][k][:, :, max(1, camera_head_kv_cache_list[i][j][k].size(2)-window_size):]
|
| 62 |
self.camera_head_kv_cache_list[i][j][k] = torch.cat(
|
| 63 |
[
|
| 64 |
anchor_token,
|
|
@@ -112,6 +113,32 @@ class StreamSession:
|
|
| 112 |
except StopIteration:
|
| 113 |
return torch.device("cpu")
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def save_cache(self, file_path: str) -> None:
|
| 116 |
aggregator_cache, camera_cache = self._get_cache()
|
| 117 |
|
|
@@ -121,6 +148,7 @@ class StreamSession:
|
|
| 121 |
"aggregator_depth": self.aggregator_kv_cache_depth,
|
| 122 |
"camera_head_depth": self.camera_head_kv_cache_depth,
|
| 123 |
"camera_head_iterations": self.camera_head_iterations,
|
|
|
|
| 124 |
"patch_size": getattr(self.model.aggregator, "patch_size", None),
|
| 125 |
"patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
|
| 126 |
},
|
|
@@ -148,6 +176,7 @@ class StreamSession:
|
|
| 148 |
"aggregator_depth": self.aggregator_kv_cache_depth,
|
| 149 |
"camera_head_depth": self.camera_head_kv_cache_depth,
|
| 150 |
"camera_head_iterations": self.camera_head_iterations,
|
|
|
|
| 151 |
}
|
| 152 |
|
| 153 |
for key, expected_value in expected_metadata.items():
|
|
|
|
| 1 |
+
import json
|
| 2 |
import os
|
| 3 |
from typing import Any, Dict, Optional
|
| 4 |
|
|
|
|
| 10 |
"""
|
| 11 |
A causal streaming inference session with KV cache management for STream3R.
|
| 12 |
"""
|
| 13 |
+
def __init__(self, model: STream3R, mode: str, *, window_size: Optional[int] = None, config_path: Optional[str] = None):
|
| 14 |
self.model = model
|
| 15 |
self.mode = mode
|
| 16 |
self.aggregator_kv_cache_depth = model.aggregator.depth
|
| 17 |
self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
|
| 18 |
self.camera_head_iterations = 4
|
| 19 |
+
self.window_size = self._resolve_window_size(window_size, config_path)
|
| 20 |
|
| 21 |
if self.mode not in ["causal", "window"]:
|
| 22 |
raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")
|
|
|
|
| 43 |
self.aggregator_kv_cache_list = aggregator_kv_cache_list
|
| 44 |
self.camera_head_kv_cache_list = camera_head_kv_cache_list
|
| 45 |
elif self.mode == "window":
|
|
|
|
| 46 |
for k in range(2):
|
| 47 |
for i in range(self.aggregator_kv_cache_depth):
|
| 48 |
h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3]
|
| 49 |
P = h * w // self.model.aggregator.patch_size // self.model.aggregator.patch_size + self.model.aggregator.patch_start_idx
|
| 50 |
anchor_token = aggregator_kv_cache_list[i][k][:, :, :P]
|
| 51 |
+
window_tokens = aggregator_kv_cache_list[i][k][:, :, max(P, aggregator_kv_cache_list[i][k].size(2)-self.window_size*P):]
|
| 52 |
self.aggregator_kv_cache_list[i][k] = torch.cat(
|
| 53 |
[
|
| 54 |
anchor_token,
|
|
|
|
| 59 |
for i in range(self.camera_head_iterations):
|
| 60 |
for j in range(self.camera_head_kv_cache_depth):
|
| 61 |
anchor_token = camera_head_kv_cache_list[i][j][k][:, :, :1]
|
| 62 |
+
window_tokens = camera_head_kv_cache_list[i][j][k][:, :, max(1, camera_head_kv_cache_list[i][j][k].size(2)-self.window_size):]
|
| 63 |
self.camera_head_kv_cache_list[i][j][k] = torch.cat(
|
| 64 |
[
|
| 65 |
anchor_token,
|
|
|
|
| 113 |
except StopIteration:
|
| 114 |
return torch.device("cpu")
|
| 115 |
|
| 116 |
+
def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
|
| 117 |
+
if override is not None:
|
| 118 |
+
return override
|
| 119 |
+
|
| 120 |
+
config_path = config_path or os.path.abspath(
|
| 121 |
+
os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
default_window_size = 25
|
| 125 |
+
|
| 126 |
+
if not os.path.exists(config_path):
|
| 127 |
+
return default_window_size
|
| 128 |
+
|
| 129 |
+
try:
|
| 130 |
+
with open(config_path, "r", encoding="utf-8") as handle:
|
| 131 |
+
data = json.load(handle)
|
| 132 |
+
except (json.JSONDecodeError, OSError):
|
| 133 |
+
return default_window_size
|
| 134 |
+
|
| 135 |
+
window_size = data.get("window_size")
|
| 136 |
+
|
| 137 |
+
if isinstance(window_size, int) and window_size > 0:
|
| 138 |
+
return window_size
|
| 139 |
+
|
| 140 |
+
return default_window_size
|
| 141 |
+
|
| 142 |
def save_cache(self, file_path: str) -> None:
|
| 143 |
aggregator_cache, camera_cache = self._get_cache()
|
| 144 |
|
|
|
|
| 148 |
"aggregator_depth": self.aggregator_kv_cache_depth,
|
| 149 |
"camera_head_depth": self.camera_head_kv_cache_depth,
|
| 150 |
"camera_head_iterations": self.camera_head_iterations,
|
| 151 |
+
"window_size": self.window_size,
|
| 152 |
"patch_size": getattr(self.model.aggregator, "patch_size", None),
|
| 153 |
"patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
|
| 154 |
},
|
|
|
|
| 176 |
"aggregator_depth": self.aggregator_kv_cache_depth,
|
| 177 |
"camera_head_depth": self.camera_head_kv_cache_depth,
|
| 178 |
"camera_head_iterations": self.camera_head_iterations,
|
| 179 |
+
"window_size": self.window_size,
|
| 180 |
}
|
| 181 |
|
| 182 |
for key, expected_value in expected_metadata.items():
|
tests/test_stream_session_cache.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import tempfile
|
| 3 |
import unittest
|
|
@@ -101,6 +102,22 @@ else:
|
|
| 101 |
restored_tensor = restored_session.predictions[key]
|
| 102 |
self.assertTrue(torch.equal(original_tensor, restored_tensor))
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
if __name__ == "__main__": # pragma: no cover - manual execution
|
| 106 |
unittest.main()
|
|
|
|
| 1 |
+
import json
|
| 2 |
import os
|
| 3 |
import tempfile
|
| 4 |
import unittest
|
|
|
|
| 102 |
restored_tensor = restored_session.predictions[key]
|
| 103 |
self.assertTrue(torch.equal(original_tensor, restored_tensor))
|
| 104 |
|
| 105 |
+
def test_window_size_from_config(self):
|
| 106 |
+
model = _DummyModel()
|
| 107 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 108 |
+
config_path = os.path.join(tmpdir, "stream_session.json")
|
| 109 |
+
with open(config_path, "w", encoding="utf-8") as handle:
|
| 110 |
+
json.dump({"window_size": 7}, handle)
|
| 111 |
+
|
| 112 |
+
session = StreamSession(model, mode="window", config_path=config_path)
|
| 113 |
+
|
| 114 |
+
self.assertEqual(session.window_size, 7)
|
| 115 |
+
|
| 116 |
+
def test_window_size_override(self):
|
| 117 |
+
model = _DummyModel()
|
| 118 |
+
session = StreamSession(model, mode="window", window_size=11)
|
| 119 |
+
self.assertEqual(session.window_size, 11)
|
| 120 |
+
|
| 121 |
|
| 122 |
if __name__ == "__main__": # pragma: no cover - manual execution
|
| 123 |
unittest.main()
|