x10z committed (verified)
Commit 2ab4d40 · Parent(s): 91ef6d2

Update app.py

Files changed (1)
  1. app.py +29 -19
app.py CHANGED
@@ -13,15 +13,15 @@ from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
 from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
 
 # Device setup
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-checkpoint_path = "GonzaloMG/geowizard-e2e-ft"
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"
 
 # Load pretrained components
-vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder='vae')
-scheduler = DDIMScheduler.from_pretrained(checkpoint_path, timestep_spacing="trailing", subfolder='scheduler')
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(checkpoint_path, subfolder="image_encoder")
-feature_extractor = CLIPImageProcessor.from_pretrained(checkpoint_path, subfolder="feature_extractor")
-unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
+vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
+scheduler = DDIMScheduler.from_pretrained(CHECKPOINT_PATH, timestep_spacing="trailing", subfolder='scheduler')
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(CHECKPOINT_PATH, subfolder="image_encoder")
+feature_extractor = CLIPImageProcessor.from_pretrained(CHECKPOINT_PATH, subfolder="feature_extractor")
+unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")
 
 # Instantiate pipeline
 pipe = DepthNormalEstimationPipeline(
@@ -30,14 +30,17 @@ pipe = DepthNormalEstimationPipeline(
     feature_extractor=feature_extractor,
     unet=unet,
     scheduler=scheduler
-).to(device)
+).to(DEVICE)
 pipe.unet.eval()
 
 # UI texts
 title = "# End-to-End Fine-Tuned GeoWizard Video"
-description = """
-Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
-"""
+description = (
+    """
+    Please refer to our [paper](https://arxiv.org/abs/2409.11355) and
+    [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
+    """
+)
 
 @spaces.GPU
 def predict(image: Image.Image, processing_res_choice: int):
@@ -70,24 +73,24 @@ def on_submit_video(video_path: str, processing_res_choice: int):
     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
-    # Temporary output files
+    # Create temporary output files
     tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
     tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     out_depth = cv2.VideoWriter(tmp_depth.name, fourcc, fps, (width, height))
     out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
 
-    # Process frames
+    # Process each frame
     for _ in tqdm(range(frame_count), desc="Processing frames"):
         ret, frame = cap.read()
         if not ret:
             break
 
-        # Convert BGR to RGB and to PIL
+        # Convert frame to PIL image
         rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         pil_image = Image.fromarray(rgb)
 
-        # Run prediction
+        # Predict depth and normals
         result = predict(pil_image, processing_res_choice)
         depth_colored = result.depth_colored
         normal_colored = result.normal_colored
@@ -107,9 +110,10 @@ def on_submit_video(video_path: str, processing_res_choice: int):
     out_depth.release()
     out_normal.release()
 
-    # Return paths for download
+    # Return video paths for download
     return tmp_depth.name, tmp_normal.name
 
+
 # Build Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown(title)
@@ -117,7 +121,7 @@ with gr.Blocks() as demo:
     gr.Markdown("### Depth and Normals Prediction on Video")
 
     with gr.Row():
-        input_video = gr.Video(
+        input_video = gr.Video(
             label="Input Video",
             elem_id='video-display-input'
         )
@@ -133,8 +137,14 @@ with gr.Blocks() as demo:
     submit = gr.Button(value="Compute Depth and Normals")
 
     with gr.Row():
-        output_depth_video = gr.Video(label="Depth Video", elem_id='download')
-        output_normal_video = gr.Video(label="Normal Video", elem_id='download')
+        output_depth_video = gr.Video(
+            label="Depth Video",
+            elem_id='download'
+        )
+        output_normal_video = gr.Video(
+            label="Normal Video",
+            elem_id='download'
+        )
 
     submit.click(
         fn=on_submit_video,
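
For orientation, a minimal sketch of how the reformatted output components are presumably wired to the handler: the last hunk ends at fn=on_submit_video, so the remaining arguments below are assumptions inferred from the component names and the on_submit_video signature visible in the diff, not the file's actual contents.

    # Hypothetical wiring sketch, not the committed code. Only fn=on_submit_video
    # is visible in the diff; the processing_res_choice component, the inputs/outputs
    # lists, and the launch call are assumptions.
    submit.click(
        fn=on_submit_video,
        inputs=[input_video, processing_res_choice],
        outputs=[output_depth_video, output_normal_video],
    )

# Queue and launch are the usual pattern for a GPU-backed Gradio Space (assumed here).
demo.queue().launch()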