toshas committed on
Commit
ac3faee
·
1 Parent(s): 1585c37

initial commit

Browse files
Files changed (6) hide show
  1. .gitignore +4 -0
  2. README.md +13 -0
  3. app.py +361 -0
  4. gradio_patches/__init__.py +0 -0
  5. gradio_patches/radio.py +31 -0
  6. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ stereospace/
2
+ .gradio/
3
+ *.pyi
4
+ *.pyc
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: StereoSpace
3
+ emoji: 📈
4
+ colorFrom: indigo
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 6.1.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Startup diagnostics: dump the environment and the installed packages to the
# log. Values of variables whose names look like credentials are masked so
# that tokens and passwords never end up in the logs.
_SENSITIVE_ENV_MARKERS = ("TOKEN", "SECRET", "PASSWORD", "PASSWD", "KEY", "CREDENTIAL", "AUTH")


def redact_env_value(name, value):
    """Return *value*, or ``"***"`` when *name* looks like it holds a secret."""
    upper_name = name.upper()
    if any(marker in upper_name for marker in _SENSITIVE_ENV_MARKERS):
        return "***"
    return value


print("\n".join(f"{k}={redact_env_value(k, v)}" for k, v in os.environ.items()))
os.system("pip freeze")
4
+
5
+ import glob
6
+ import gradio as gr
7
+ import shutil
8
+ import spaces
9
+ import subprocess
10
+ import sys
11
+ import tempfile
12
+ from PIL import Image
13
+ from gradio_patches.radio import Radio
14
+
15
+ REPO_URL = "https://github.com/prs-eth/stereospace.git"
16
+ COMMIT_SHA = "d7bbae6"
17
+ REPO_DIR = "stereospace"
18
+ DEVICE = "cuda"
19
+
20
+
21
+
22
def clone_repository():
    """Make sure REPO_DIR holds a checkout of REPO_URL at COMMIT_SHA.

    If REPO_DIR already contains a ``.git`` directory, the repository is
    fetched; otherwise it is cloned from REPO_URL. In both cases COMMIT_SHA
    is checked out afterwards.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    # isdir() is already False for a missing path, so a separate exists()
    # check is not needed.
    if os.path.isdir(os.path.join(REPO_DIR, ".git")):
        print(f"Repository {REPO_DIR} already exists, checking out commit...")
        subprocess.run(["git", "fetch"], cwd=REPO_DIR, check=True, capture_output=True)
    else:
        print(f"Cloning repository {REPO_URL} at commit {COMMIT_SHA}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # Both paths end by pinning the working tree to the requested commit.
    subprocess.run(["git", "checkout", COMMIT_SHA], cwd=REPO_DIR, check=True)
    print(f"Repository ready at {REPO_DIR}")
32
+
33
+
34
+ clone_repository()
35
+
36
+ sys.path.insert(0, REPO_DIR)
37
+ from inference import generate_novel_view
38
+ from src import StereoSpace
39
+
40
+
41
+
42
def create_placeholder_image():
    """Write a 1x1 black PNG to a new temporary file and return its path.

    The file is created with ``delete=False``, so it is not removed when the
    handle is closed; nothing in this function deletes it.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        # Write through the handle we already hold instead of opening the
        # same file a second time by name.
        Image.new("RGB", (1, 1), color="black").save(tmp, format="PNG")
    return tmp.name
47
+
48
+
49
def find_output_file(output_dir, base_name, output_mode):
    """Return the path of the output file for one display mode.

    Args:
        output_dir: Directory the output files were written to.
        base_name: File-name stem shared by all outputs of one run.
        output_mode: One of "Anaglyph", "Side-by-side", "Generated view",
            "Input view".

    Returns:
        Path of the existing output file. For "Generated view" several files
        may match ``<base_name>_generated_*.png``; the alphabetically first
        one is returned so the choice is deterministic.

    Raises:
        ValueError: If *output_mode* is not one of the known modes.
        FileNotFoundError: If the expected file does not exist.
    """
    fixed_suffixes = {
        "Anaglyph": "_anaglyph.png",
        "Side-by-side": "_sbs.png",
        "Input view": "_source.png",
    }

    if output_mode in fixed_suffixes:
        path = os.path.join(output_dir, f"{base_name}{fixed_suffixes[output_mode]}")
    elif output_mode == "Generated view":
        # Escape the fixed parts so that characters such as "[" or "*" in the
        # directory or base name are matched literally; only the trailing
        # "*" acts as a wildcard.
        pattern = os.path.join(
            glob.escape(output_dir),
            f"{glob.escape(base_name)}_generated_*.png",
        )
        # glob() returns matches in arbitrary order; sort for a stable pick.
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No generated file found matching {pattern}")
        path = matches[0]
    else:
        raise ValueError(f"Unknown output mode: {output_mode}")

    if not os.path.exists(path):
        raise FileNotFoundError(f"Output file not found: {path}")

    return path
70
+
71
+
72
def find_all_output_files(output_dir, base_name):
    """Collect the output file of every display mode that exists on disk.

    Returns:
        Dict mapping mode name to file path. Modes for which
        ``find_output_file`` raises ``FileNotFoundError`` are left out.
    """
    found = {}
    for mode in ("Anaglyph", "Side-by-side", "Generated view", "Input view"):
        try:
            path = find_output_file(output_dir, base_name, mode)
        except FileNotFoundError:
            # A missing mode is simply skipped.
            continue
        found[mode] = path
    return found
82
+
83
+
84
@spaces.GPU
def process_all_modes(input_image):
    """Run the repository's ``inference.py`` on one image and load all outputs.

    Args:
        input_image: Path of the input image (``str``), or any other object
            accepted by ``PIL.Image.open``.

    Returns:
        Dict mapping output mode name to an RGB ``PIL.Image`` that is fully
        loaded into memory (the files themselves live in a temporary
        directory that is removed before returning).

    Raises:
        gr.Error: If no image is given, if the inference script exits with a
            non-zero status, or if it produced none of the expected files.
    """
    if input_image is None:
        raise gr.Error("Please upload an image or select an example.")

    # Reserve a unique path; the handle is closed right away and the file is
    # (re)written by name below.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input:
        input_path = tmp_input.name

    try:
        # Preparing the input happens inside the try so the temp file is
        # removed even when copying or converting fails.
        if isinstance(input_image, str):
            shutil.copy(input_image, input_path)
        else:
            with Image.open(input_image) as source_image:
                source_image.convert("RGB").save(input_path)

        with tempfile.TemporaryDirectory() as tmp_output:
            base_name = os.path.splitext(os.path.basename(input_path))[0]

            # sys.executable runs the script with the same interpreter (and
            # therefore the same installed packages) as this app.
            cmd = [
                sys.executable, "inference.py",
                "--input", input_path,
                "--output", tmp_output,
            ]

            print(f"Running command: {' '.join(cmd)}")
            print(f"Working directory: {REPO_DIR}")
            result = subprocess.run(
                cmd,
                cwd=REPO_DIR,
                capture_output=True,
                text=True,
            )

            if result.returncode != 0:
                error_msg = f"Inference failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
                print(error_msg)
                raise gr.Error(error_msg)

            # find_all_output_files() skips missing files instead of raising,
            # so an empty result is the only signal that nothing was produced.
            output_files = find_all_output_files(tmp_output, base_name)
            if not output_files:
                all_files = os.listdir(tmp_output)
                error_msg = f"Output files not found. Available files: {all_files}"
                print(error_msg)
                raise gr.Error(error_msg)

            outputs = {}
            for mode, output_file in output_files.items():
                # convert() returns an in-memory image, so the file handle can
                # be closed before the temporary directory is deleted.
                with Image.open(output_file) as output_image:
                    outputs[mode] = output_image.convert("RGB")

            return outputs
    finally:
        if os.path.exists(input_path):
            os.unlink(input_path)
139
+
140
def get_example_images():
    """Return the sorted paths of the example images in the cloned repository.

    Looks at the direct children of ``<REPO_DIR>/example_images`` and keeps
    those with a ``.png``, ``.jpg`` or ``.jpeg`` extension, compared
    case-insensitively. Each file is listed once, also on case-insensitive
    filesystems. Returns an empty list when the directory does not exist.
    """
    example_dir = os.path.join(REPO_DIR, "example_images")
    if not os.path.isdir(example_dir):
        return []

    image_extensions = {".png", ".jpg", ".jpeg"}
    examples = [
        os.path.join(example_dir, name)
        for name in os.listdir(example_dir)
        # Hidden files are skipped, as a "*.ext" glob pattern would do.
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in image_extensions
    ]
    return sorted(examples)
151
+
152
+
153
with gr.Blocks(
    title="StereoSpace Demo",
) as demo:
    gr.Markdown(
        """
        <div align="center">
        <h2>StereoSpace: Depth-Free Synthesis of Stereo Geometry via End-to-End Diffusion in a Canonical Space</h2>
        </div>
        """
    )
    with gr.Row(elem_classes="remove-elements"):
        # Plain string: the text contains no placeholders, so no f-prefix.
        gr.Markdown(
            """
            <p align="center">
            <a title="Website" href="https://hf.co/spaces/prs-eth/stereospace_web" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
            </a>
            <a title="arXiv" href="https://arxiv.org/abs/2512.10959" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv%20-Paper-AF3436">
            </a>
            <a title="Github" href="https://github.com/prs-eth/stereospace" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/github/stars/prs-eth/stereospace?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Model weights" href="https://hf.co/prs-eth/stereospace-v1-0" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model%20-Weights-yellow" alt="imagedepth">
            </a>
            <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://shields.io/twitter/follow/:?label=Subscribe%20for%20updates!" alt="social">
            </a>
            </p>
            <p align="center" style="margin-top: 0px;">
            Upload a photo or pick an example below to create stereo space, wait for the result, then watch it in anaglyph, side-by-side, or generated view.
            If a quota limit appears, duplicate the space to continue.
            </p>
            """
        )

    with gr.Row():
        output_mode = Radio(
            choices=["Anaglyph", "Side-by-side", "Input view", "Generated view"],
            value="Anaglyph",
            label=None,
            container=False,
            scale=1,
            elem_classes="horizontal-radio",
        )

    with gr.Row():
        image = gr.Image(
            type="filepath",
            label="Input/Output Image",
            elem_classes="result-image",
            height=480,
        )

    # Hidden gallery used as storage for the outputs of all modes; the visible
    # `image` component shows whichever entry matches the selected mode.
    outputs_gallery = gr.Gallery(
        visible=False,
        label="Computed Outputs",
        show_label=False,
        value=[],
    )

    def update_image_from_gallery(gallery_data, current_mode, current_image=None):
        """Return the gallery image whose label equals *current_mode*.

        Falls back to the first gallery entry when no label matches, and to
        *current_image* when the gallery is empty or yields nothing. A
        ``PIL.Image`` result is saved to a temporary PNG and its path is
        returned; any other result is returned unchanged.
        """
        if not gallery_data:
            return current_image

        # First (image, label) pair whose label is the selected mode.
        result = next(
            (
                item[0]
                for item in gallery_data
                if isinstance(item, (list, tuple))
                and len(item) >= 2
                and item[1] == current_mode
            ),
            None,
        )

        if result is None:
            first_item = gallery_data[0]
            if isinstance(first_item, (list, tuple)) and len(first_item) >= 1:
                result = first_item[0]
            elif isinstance(first_item, str):
                result = first_item

        if result is None:
            return current_image

        if isinstance(result, Image.Image):
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                result.save(tmp, format="PNG")
            return tmp.name

        return result

    def _collect_mode_outputs(img, preferred_mode=None):
        """Run inference on *img* and build the (gallery, image, radio) updates.

        The radio is set to *preferred_mode* when that mode was produced,
        otherwise to the first produced mode ("Anaglyph" if there is none).
        """
        if img is None:
            return [], create_placeholder_image(), gr.update()

        all_outputs = process_all_modes(img)
        gallery_data = [[mode_image, mode_name] for mode_name, mode_image in all_outputs.items()]
        available_modes = list(all_outputs)

        if preferred_mode in available_modes:
            radio_value = preferred_mode
        else:
            radio_value = available_modes[0] if available_modes else "Anaglyph"

        # NOTE(review): the image is set to a 1x1 placeholder here and then
        # replaced by the gallery change handler - presumably to force a
        # refresh of the image component; confirm before changing.
        placeholder = create_placeholder_image()
        return gallery_data, placeholder, gr.update(choices=available_modes, value=radio_value)

    def process_new_image(img, current_mode):
        """Process an uploaded image, keeping the selected mode when available."""
        return _collect_mode_outputs(img, preferred_mode=current_mode)

    def process_example_simple(img):
        """Process an example image; the first produced mode gets selected."""
        return _collect_mode_outputs(img)

    def clear_image():
        """Reset the image component and empty the outputs gallery."""
        return None, []

    examples_list = get_example_images()
    if examples_list:
        def process_example_wrapper(img):
            """Forward an example click to ``process_example_simple``."""
            return process_example_simple(img)

        examples_component = gr.Examples(
            examples=examples_list,
            inputs=[image],
            outputs=[outputs_gallery, image, output_mode],
            fn=process_example_wrapper,
            cache_examples=True,
            cache_mode="lazy",
            label="Example Images",
            elem_id="example-images-gallery",
        )

    def process_upload_wrapper(img, current_mode):
        """Forward an upload event to ``process_new_image``."""
        return process_new_image(img, current_mode)

    upload_event = image.upload(
        fn=process_upload_wrapper,
        inputs=[image, output_mode],
        outputs=[outputs_gallery, image, output_mode],
    )

    def on_gallery_change(gallery_data, current_mode, current_image):
        """Show the image for the selected mode once new outputs arrive."""
        if not gallery_data:
            return current_image, gr.update(interactive=True)
        updated_image = update_image_from_gallery(gallery_data, current_mode, current_image)
        return updated_image, gr.update(interactive=True)

    # NOTE(review): `image` is listed twice on purpose in the original - the
    # first slot receives the new value, the second the interactive=True
    # update. Kept as is; confirm whether a single gr.update would do.
    gallery_change_event = outputs_gallery.change(
        fn=on_gallery_change,
        inputs=[outputs_gallery, output_mode, image],
        outputs=[image, image],
    )

    def switch_mode_handler(current_mode, gallery_data, current_image):
        """Swap the displayed image when another output mode is selected."""
        return update_image_from_gallery(gallery_data, current_mode, current_image)

    output_mode.change(
        fn=switch_mode_handler,
        inputs=[output_mode, outputs_gallery, image],
        outputs=image,
    )

    image.clear(
        fn=clear_image,
        outputs=[image, outputs_gallery],
    )
335
+
336
if __name__ == "__main__":
    # Force square, viewport-relative thumbnails in the examples gallery:
    # the same size rule is applied to the button, its container and the image.
    examples_css = """
    #example-images-gallery button[class*="gallery-item"][class*="svelte-"] {
    min-width: max(96px, calc(100vw / 8));
    min-height: max(96px, calc(100vw / 8));
    width: max(96px, calc(100vw / 8));
    height: max(96px, calc(100vw / 8));
    }
    #example-images-gallery button[class*="gallery-item"] div[class*="container"] {
    min-width: max(96px, calc(100vw / 8));
    min-height: max(96px, calc(100vw / 8));
    width: max(96px, calc(100vw / 8));
    height: max(96px, calc(100vw / 8));
    }
    #example-images-gallery button[class*="gallery-item"] img {
    min-width: max(96px, calc(100vw / 8));
    min-height: max(96px, calc(100vw / 8));
    width: max(96px, calc(100vw / 8));
    height: max(96px, calc(100vw / 8));
    object-fit: cover;
    }
    """
    queued_demo = demo.queue()
    queued_demo.launch(
        server_name="0.0.0.0",
        ssr_mode=False,
        css=examples_css,
    )
gradio_patches/__init__.py ADDED
File without changes
gradio_patches/radio.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Workaround for https://github.com/gradio-app/gradio/issues/12564
3
+ """
4
+ import gradio as gr
5
+
6
+
7
class Radio(gr.Radio):
    """``gr.Radio`` that lazily supplies defaults for attributes never set.

    Workaround for https://github.com/gradio-app/gradio/issues/12564:
    when one of the attributes listed in ``_default_attributes`` is missing
    on an instance, it is created on first access with its default value
    instead of raising ``AttributeError``.
    """

    # Default values for attributes that Block.get_config() and Component.get_config() expect
    _default_attributes = {
        # Block attributes (from Block.__init__)
        'proxy_url': None,
        'rendered_in': None,
        'key': None,
        'visible': True,
        'elem_id': None,
        'elem_classes': [],
        'parent': None,
        'is_rendered': False,
        # Component attributes (from Component.__init__)
        'info': None,
        'server_fns': [],
        '_selectable': False,
        'label': None,
    }

    def __getattr__(self, name):
        """Return, and store on the instance, the default for a missing attribute.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If *name* has no entry in ``_default_attributes``.
        """
        if name not in self._default_attributes:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
        value = self._default_attributes[name]
        # Give every instance its own list; handing out the class-level list
        # would let a mutation on one component leak into all others.
        if isinstance(value, list):
            value = list(value)
        setattr(self, name, value)
        return value
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers
2
+ einops
3
+ gradio==6.1.0
4
+ jaxtyping
5
+ numpy
6
+ omegaconf
7
+ peft
8
+ Pillow
9
+ torch
10
+ torchvision
11
+ transformers
12
+ spaces