Tsmith2024 committed on
Commit
f7e223f
·
verified ·
1 Parent(s): 6ab3c53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -87
app.py CHANGED
@@ -12,7 +12,6 @@ from torch import Tensor
12
  from genstereo import GenStereo, AdaptiveFusionLayer
13
  import ssl
14
  from huggingface_hub import hf_hub_download
15
- import spaces
16
 
17
  from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
18
  ssl._create_default_https_context = ssl._create_unverified_context
@@ -146,39 +145,31 @@ with tempfile.TemporaryDirectory() as tmpdir:
146
  IMAGE_SIZE = 768
147
  CHECKPOINT_NAME = 'genstereo-v2.1'
148
  print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}")
149
- return None, None, None, None, None, None
150
 
151
  @spaces.GPU()
152
- def cb_generate(image, depth, scale_factor, sd_version):
153
- depth_tensor = torch.tensor(depth).unsqueeze(0).unsqueeze(0).float()
154
- norm_disp = normalize_disp(depth_tensor.cuda())
155
- disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
156
 
157
- genstereo = get_genstereo_model(sd_version)
158
- fusion_model = get_fusion_model()
 
 
 
 
 
 
 
159
 
160
- renders = genstereo(
161
- src_image=image,
162
- src_disparity=disp,
163
- ratio=None,
164
- )
165
- warped = (renders['warped'] + 1) / 2
166
-
167
- synthesized = renders['synthesized']
168
- mask = renders['mask']
169
- fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
170
 
171
- warped_pil = to_pil_image(warped[0])
172
- fusion_pil = to_pil_image(fusion_image[0])
173
 
174
- # Create full SBS for Quest 2
175
- left_resized = image.resize((1832, 1920))
176
- right_resized = fusion_pil.resize((1832, 1920))
177
- sbs = Image.new('RGB', (3664, 1920))
178
- sbs.paste(left_resized, (0, 0))
179
- sbs.paste(right_resized, (1832, 0))
180
 
181
- return warped_pil, fusion_pil, sbs
182
 
183
  @spaces.GPU()
184
  def cb_generate(image, depth, scale_factor, sd_version):
@@ -195,7 +186,7 @@ with tempfile.TemporaryDirectory() as tmpdir:
195
  ratio=None,
196
  )
197
  warped = (renders['warped'] + 1) / 2
198
-
199
  synthesized = renders['synthesized']
200
  mask = renders['mask']
201
  fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
@@ -209,21 +200,20 @@ with tempfile.TemporaryDirectory() as tmpdir:
209
  sbs = Image.new('RGB', (3664, 1920))
210
  sbs.paste(left_resized, (0, 0))
211
  sbs.paste(right_resized, (1832, 0))
212
- sbs.save('/home/user/app/sbs_quest2.jpg', quality=95)
213
 
214
- return warped_pil, fusion_pil
215
 
216
  # Blocks.
217
  gr.Markdown(
218
  """
219
  # [ICCV 2025] Towards Open-World Generation of Stereo Images and Unsupervised Matching
220
  [Project Site](https://qjizhi.github.io/genstereo) | [Spaces](https://huggingface.co/spaces/FQiao/GenStereo) | [Github](https://github.com/Qjizhi/GenStereo) | [Models](https://huggingface.co/FQiao/GenStereo-sd2.1/tree/main) | [arXiv](https://arxiv.org/abs/2503.12720)
221
-
222
- ## Introduction
223
  This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
224
-
225
  ## How to Use
226
-
227
  1. Select the GenStereo version
228
  - v1.5: 512px, faster.
229
  - v2.1: 768px, better performance, high resolution, takes more time.
@@ -241,60 +231,8 @@ with tempfile.TemporaryDirectory() as tmpdir:
241
  )
242
 
243
  with gr.Row():
244
-
245
  file = gr.File(label='Left', file_types=['image'])
246
  examples = gr.Examples(
247
  examples=['./assets/COCO_val2017_000000070229.jpg',
248
  './assets/COCO_val2017_000000092839.jpg',
249
- './assets/KITTI2015_000003_10.png',
250
- './assets/KITTI2015_000147_10.png'],
251
- inputs=file
252
- )
253
- with gr.Row():
254
- image_widget = gr.Image(
255
- label='Left Image', type='filepath',
256
- interactive=False
257
- )
258
- depth_widget = gr.Image(label='Estimated Depth', type='pil')
259
-
260
- # Add scale factor slider
261
- scale_slider = gr.Slider(
262
- label='Scale Factor',
263
- minimum=1.0,
264
- maximum=30.0,
265
- value=15.0,
266
- step=0.1,
267
- )
268
-
269
- button = gr.Button('Generate a right image', size='lg', variant='primary')
270
- with gr.Row():
271
- warped_widget = gr.Image(
272
- label='Warped Image', type='pil', interactive=False
273
- )
274
- gen_widget = gr.Image(
275
- label='Generated Right', type='pil', interactive=False
276
- )
277
-
278
- # Events
279
- sd_version_radio.change(
280
- fn=cb_update_sd_version,
281
- inputs=sd_version_radio,
282
- outputs=[
283
- image_widget, depth_widget, # Clear image displays
284
- src_image, src_depth, # Clear internal states
285
- warped_widget, gen_widget # Clear generation outputs
286
- ]
287
- )
288
- file.change(
289
- fn=cb_mde,
290
- inputs=[file, sd_version_radio],
291
- outputs=[image_widget, depth_widget, src_image, src_depth]
292
- )
293
- button.click(
294
- fn=cb_generate,
295
- inputs=[src_image, src_depth, scale_slider, sd_version_radio],
296
- outputs=[warped_widget, gen_widget]
297
- )
298
-
299
- if __name__ == '__main__':
300
- demo.launch()
 
12
  from genstereo import GenStereo, AdaptiveFusionLayer
13
  import ssl
14
  from huggingface_hub import hf_hub_download
 
15
 
16
  from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
17
  ssl._create_default_https_context = ssl._create_unverified_context
 
145
  IMAGE_SIZE = 768
146
  CHECKPOINT_NAME = 'genstereo-v2.1'
147
  print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}")
148
+ return None, None, None, None, None, None, None
149
 
150
  @spaces.GPU()
151
+ def cb_mde(image_file: str, sd_version):
152
+ if not image_file:
153
+ return None, None, None, None
 
154
 
155
+ image = crop(Image.open(image_file).convert('RGB'))
156
+ if sd_version == "v1.5":
157
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
158
+ elif sd_version == "v2.1":
159
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
160
+ else:
161
+ gr.Warning(f"Unknown SD version: {sd_version}. Defaulting to {IMAGE_SIZE}.")
162
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
163
+ gr.Info(f"Generating with GenStereo {sd_version} at {IMAGE_SIZE}px resolution.")
164
 
165
+ image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
 
166
 
167
+ dam2 = get_dam2_model()
168
+ depth_dam2 = dam2.infer_image(image_bgr)
169
 
170
+ depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
 
 
 
 
 
171
 
172
+ return image, depth_image, image, depth_dam2
173
 
174
  @spaces.GPU()
175
  def cb_generate(image, depth, scale_factor, sd_version):
 
186
  ratio=None,
187
  )
188
  warped = (renders['warped'] + 1) / 2
189
+
190
  synthesized = renders['synthesized']
191
  mask = renders['mask']
192
  fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
 
200
  sbs = Image.new('RGB', (3664, 1920))
201
  sbs.paste(left_resized, (0, 0))
202
  sbs.paste(right_resized, (1832, 0))
 
203
 
204
+ return warped_pil, fusion_pil, sbs
205
 
206
  # Blocks.
207
  gr.Markdown(
208
  """
209
  # [ICCV 2025] Towards Open-World Generation of Stereo Images and Unsupervised Matching
210
  [Project Site](https://qjizhi.github.io/genstereo) | [Spaces](https://huggingface.co/spaces/FQiao/GenStereo) | [Github](https://github.com/Qjizhi/GenStereo) | [Models](https://huggingface.co/FQiao/GenStereo-sd2.1/tree/main) | [arXiv](https://arxiv.org/abs/2503.12720)
211
+
212
+ ## Introduction
213
  This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
214
+
215
  ## How to Use
216
+
217
  1. Select the GenStereo version
218
  - v1.5: 512px, faster.
219
  - v2.1: 768px, better performance, high resolution, takes more time.
 
231
  )
232
 
233
  with gr.Row():
 
234
  file = gr.File(label='Left', file_types=['image'])
235
  examples = gr.Examples(
236
  examples=['./assets/COCO_val2017_000000070229.jpg',
237
  './assets/COCO_val2017_000000092839.jpg',
238
+ './asset