brian4dwell commited on
Commit
87dee1a
·
1 Parent(s): 133f02d

support zip

Browse files
app.py CHANGED
@@ -14,7 +14,7 @@ from datetime import datetime
14
  import glob
15
  import gc
16
  import time
17
-
18
  from stream3r.models.stream3r import STream3R
19
  from stream3r.stream_session import StreamSession
20
  from stream3r.models.components.utils.load_fn import load_and_preprocess_images
@@ -31,6 +31,49 @@ device = "cuda"
31
 
32
  model = STream3R.from_pretrained("yslan/STream3R")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # -------------------------------------------------------------------------
35
  # 1) Core model inference
36
  # -------------------------------------------------------------------------
@@ -116,10 +159,13 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
116
  # -------------------------------------------------------------------------
117
  # 2) Handle uploaded video/images --> produce target_dir + images
118
  # -------------------------------------------------------------------------
119
- def handle_uploads(input_video, input_images):
120
  """
121
- Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
122
- images or extracted frames from video into it. Return (target_dir, image_paths).
 
 
 
123
  """
124
  start_time = time.time()
125
  gc.collect()
@@ -130,36 +176,32 @@ def handle_uploads(input_video, input_images):
130
  target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
131
  target_dir_images = os.path.join(target_dir, "images")
132
 
133
- # Clean up if somehow that folder already exists
134
  if os.path.exists(target_dir):
135
  shutil.rmtree(target_dir)
136
- os.makedirs(target_dir)
137
- os.makedirs(target_dir_images)
138
 
139
- image_paths = []
140
 
141
- # --- Handle images ---
142
- if input_images is not None:
143
  for file_data in input_images:
144
- if isinstance(file_data, dict) and "name" in file_data:
145
- file_path = file_data["name"]
146
- else:
147
- file_path = file_data
148
  dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
149
  shutil.copy(file_path, dst_path)
150
  image_paths.append(dst_path)
151
 
152
- # --- Handle video ---
153
- if input_video is not None:
154
- if isinstance(input_video, dict) and "name" in input_video:
155
- video_path = input_video["name"]
156
- else:
157
- video_path = input_video
158
 
 
 
 
159
  vs = cv2.VideoCapture(video_path)
160
- fps = vs.get(cv2.CAP_PROP_FPS)
161
- frame_interval = int(fps * 1) # 1 frame/sec
162
-
163
  count = 0
164
  video_frame_num = 0
165
  while True:
@@ -172,30 +214,30 @@ def handle_uploads(input_video, input_images):
172
  cv2.imwrite(image_path, frame)
173
  image_paths.append(image_path)
174
  video_frame_num += 1
 
175
 
176
- # Sort final images for gallery
177
- image_paths = sorted(image_paths)
178
 
179
  end_time = time.time()
180
- print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
181
  return target_dir, image_paths
182
 
183
 
 
184
  # -------------------------------------------------------------------------
185
  # 3) Update gallery on upload
186
  # -------------------------------------------------------------------------
187
- def update_gallery_on_upload(input_video, input_images):
188
  """
189
- Whenever user uploads or changes files, immediately handle them
190
- and show in the gallery. Return (target_dir, image_paths).
191
- If nothing is uploaded, returns "None" and empty list.
192
  """
193
- if not input_video and not input_images:
194
  return None, None, None, None
195
- target_dir, image_paths = handle_uploads(input_video, input_images)
196
  return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
197
 
198
 
 
199
  # -------------------------------------------------------------------------
200
  # 4) Reconstruction: uses the target_dir plus any viz parameters
201
  # -------------------------------------------------------------------------
@@ -460,7 +502,8 @@ with gr.Blocks(
460
  with gr.Row():
461
  with gr.Column(scale=2):
462
  input_video = gr.Video(label="Upload Video", interactive=True)
463
- input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True, max_files=100)
 
464
 
465
  image_gallery = gr.Gallery(
466
  label="Preview",
@@ -757,13 +800,18 @@ with gr.Blocks(
757
  # -------------------------------------------------------------------------
758
  input_video.change(
759
  fn=update_gallery_on_upload,
760
- inputs=[input_video, input_images],
761
  outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
762
  )
763
  input_images.change(
764
  fn=update_gallery_on_upload,
765
- inputs=[input_video, input_images],
 
 
 
 
 
766
  outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
767
  )
768
 
769
- demo.queue(max_size=20).launch(show_error=True, share=True)
 
14
  import glob
15
  import gc
16
  import time
17
+ import zipfile
18
  from stream3r.models.stream3r import STream3R
19
  from stream3r.stream_session import StreamSession
20
  from stream3r.models.components.utils.load_fn import load_and_preprocess_images
 
31
 
32
  model = STream3R.from_pretrained("yslan/STream3R")
33
 
34
+ def handle_zip(zip_file):
35
+ outdir = "uploads"
36
+ os.makedirs(outdir, exist_ok=True)
37
+ with zipfile.ZipFile(zip_file.name, "r") as zf:
38
+ zf.extractall(outdir)
39
+ return f"Extracted {len(os.listdir(outdir))} files"
40
+
41
+ # --- add near your imports/helpers ---
42
+ ALLOWED_IMG_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
43
+
44
+ def _is_within_dir(base_dir: str, path: str) -> bool:
45
+ # Prevent zip-slip: ensure extracted path stays inside base_dir
46
+ base_dir = os.path.abspath(base_dir)
47
+ path = os.path.abspath(path)
48
+ return os.path.commonpath([base_dir]) == os.path.commonpath([base_dir, path])
49
+
50
+ def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]:
51
+ """
52
+ Extracts only image files from a zip into outdir.
53
+ Returns list of extracted file paths.
54
+ """
55
+ os.makedirs(outdir, exist_ok=True)
56
+ extracted = []
57
+ with zipfile.ZipFile(zip_path, "r") as zf:
58
+ for member in zf.infolist():
59
+ # Skip directories and non-image files
60
+ name = member.filename
61
+ if name.endswith("/"):
62
+ continue
63
+ ext = os.path.splitext(name)[1].lower()
64
+ if ext not in ALLOWED_IMG_EXT:
65
+ continue
66
+ # Construct final path safely
67
+ dest_path = os.path.join(outdir, os.path.basename(name))
68
+ # Zip-slip guard (in case filename has ../ etc.)
69
+ if not _is_within_dir(outdir, dest_path):
70
+ continue
71
+ with zf.open(member) as src, open(dest_path, "wb") as dst:
72
+ shutil.copyfileobj(src, dst)
73
+ extracted.append(dest_path)
74
+ return extracted
75
+
76
+
77
  # -------------------------------------------------------------------------
78
  # 1) Core model inference
79
  # -------------------------------------------------------------------------
 
159
  # -------------------------------------------------------------------------
160
  # 2) Handle uploaded video/images --> produce target_dir + images
161
  # -------------------------------------------------------------------------
162
+ def handle_uploads(input_video, input_images, input_zip=None):
163
  """
164
+ Create a new 'target_dir' + 'images' subfolder.
165
+ - Copies uploaded images
166
+ - Optionally extracts images from a ZIP
167
+ - Optionally extracts frames from a video (1 fps)
168
+ Returns (target_dir, image_paths).
169
  """
170
  start_time = time.time()
171
  gc.collect()
 
176
  target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
177
  target_dir_images = os.path.join(target_dir, "images")
178
 
 
179
  if os.path.exists(target_dir):
180
  shutil.rmtree(target_dir)
181
+ os.makedirs(target_dir_images, exist_ok=True)
 
182
 
183
+ image_paths: list[str] = []
184
 
185
+ # --- Handle images (list) ---
186
+ if input_images:
187
  for file_data in input_images:
188
+ file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else file_data
 
 
 
189
  dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
190
  shutil.copy(file_path, dst_path)
191
  image_paths.append(dst_path)
192
 
193
+ # --- Handle ZIP (extract images) ---
194
+ if input_zip:
195
+ zip_path = input_zip["name"] if isinstance(input_zip, dict) and "name" in input_zip else input_zip
196
+ extracted = extract_images_from_zip(zip_path, target_dir_images)
197
+ image_paths.extend(extracted)
 
198
 
199
+ # --- Handle video (extract frames at 1 fps) ---
200
+ if input_video:
201
+ video_path = input_video["name"] if isinstance(input_video, dict) and "name" in input_video else input_video
202
  vs = cv2.VideoCapture(video_path)
203
+ fps = vs.get(cv2.CAP_PROP_FPS) or 30.0
204
+ frame_interval = max(1, int(fps * 1)) # 1 frame/sec
 
205
  count = 0
206
  video_frame_num = 0
207
  while True:
 
214
  cv2.imwrite(image_path, frame)
215
  image_paths.append(image_path)
216
  video_frame_num += 1
217
+ vs.release()
218
 
219
+ image_paths = sorted(set(image_paths)) # de-dupe + sort
 
220
 
221
  end_time = time.time()
222
+ print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {end_time - start_time:.3f}s")
223
  return target_dir, image_paths
224
 
225
 
226
+
227
  # -------------------------------------------------------------------------
228
  # 3) Update gallery on upload
229
  # -------------------------------------------------------------------------
230
+ def update_gallery_on_upload(input_video, input_images, input_zip):
231
  """
232
+ Handle any new uploads (video, images, or zip) and render preview.
 
 
233
  """
234
+ if not input_video and not input_images and not input_zip:
235
  return None, None, None, None
236
+ target_dir, image_paths = handle_uploads(input_video, input_images, input_zip)
237
  return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
238
 
239
 
240
+
241
  # -------------------------------------------------------------------------
242
  # 4) Reconstruction: uses the target_dir plus any viz parameters
243
  # -------------------------------------------------------------------------
 
502
  with gr.Row():
503
  with gr.Column(scale=2):
504
  input_video = gr.Video(label="Upload Video", interactive=True)
505
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
506
+ input_zip = gr.File(file_types=[".zip"], label="Upload ZIP of Images", interactive=True)
507
 
508
  image_gallery = gr.Gallery(
509
  label="Preview",
 
800
  # -------------------------------------------------------------------------
801
  input_video.change(
802
  fn=update_gallery_on_upload,
803
+ inputs=[input_video, input_images, input_zip],
804
  outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
805
  )
806
  input_images.change(
807
  fn=update_gallery_on_upload,
808
+ inputs=[input_video, input_images, input_zip],
809
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
810
+ )
811
+ input_zip.change(
812
+ fn=update_gallery_on_upload,
813
+ inputs=[input_video, input_images, input_zip],
814
  outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
815
  )
816
 
817
+ demo.queue(max_size=20).launch(show_error=True, share=False)
stream3r/__pycache__/stream_session.cpython-310.pyc CHANGED
Binary files a/stream3r/__pycache__/stream_session.cpython-310.pyc and b/stream3r/__pycache__/stream_session.cpython-310.pyc differ