Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,15 +13,15 @@ from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
|
|
| 13 |
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
|
| 14 |
|
| 15 |
# Device setup
|
| 16 |
-
|
| 17 |
-
|
| 18 |
|
| 19 |
# Load pretrained components
|
| 20 |
-
vae = AutoencoderKL.from_pretrained(
|
| 21 |
-
scheduler = DDIMScheduler.from_pretrained(
|
| 22 |
-
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 23 |
-
feature_extractor = CLIPImageProcessor.from_pretrained(
|
| 24 |
-
unet = UNet2DConditionModel.from_pretrained(
|
| 25 |
|
| 26 |
# Instantiate pipeline
|
| 27 |
pipe = DepthNormalEstimationPipeline(
|
|
@@ -30,14 +30,17 @@ pipe = DepthNormalEstimationPipeline(
|
|
| 30 |
feature_extractor=feature_extractor,
|
| 31 |
unet=unet,
|
| 32 |
scheduler=scheduler
|
| 33 |
-
).to(
|
| 34 |
pipe.unet.eval()
|
| 35 |
|
| 36 |
# UI texts
|
| 37 |
title = "# End-to-End Fine-Tuned GeoWizard Video"
|
| 38 |
-
description =
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
@spaces.GPU
|
| 43 |
def predict(image: Image.Image, processing_res_choice: int):
|
|
@@ -70,24 +73,24 @@ def on_submit_video(video_path: str, processing_res_choice: int):
|
|
| 70 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 71 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 72 |
|
| 73 |
-
#
|
| 74 |
tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
| 75 |
tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
| 76 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 77 |
out_depth = cv2.VideoWriter(tmp_depth.name, fourcc, fps, (width, height))
|
| 78 |
out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
|
| 79 |
|
| 80 |
-
# Process
|
| 81 |
for _ in tqdm(range(frame_count), desc="Processing frames"):
|
| 82 |
ret, frame = cap.read()
|
| 83 |
if not ret:
|
| 84 |
break
|
| 85 |
|
| 86 |
-
# Convert
|
| 87 |
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 88 |
pil_image = Image.fromarray(rgb)
|
| 89 |
|
| 90 |
-
#
|
| 91 |
result = predict(pil_image, processing_res_choice)
|
| 92 |
depth_colored = result.depth_colored
|
| 93 |
normal_colored = result.normal_colored
|
|
@@ -107,9 +110,10 @@ def on_submit_video(video_path: str, processing_res_choice: int):
|
|
| 107 |
out_depth.release()
|
| 108 |
out_normal.release()
|
| 109 |
|
| 110 |
-
# Return paths for download
|
| 111 |
return tmp_depth.name, tmp_normal.name
|
| 112 |
|
|
|
|
| 113 |
# Build Gradio interface
|
| 114 |
with gr.Blocks() as demo:
|
| 115 |
gr.Markdown(title)
|
|
@@ -117,7 +121,7 @@ with gr.Blocks() as demo:
|
|
| 117 |
gr.Markdown("### Depth and Normals Prediction on Video")
|
| 118 |
|
| 119 |
with gr.Row():
|
| 120 |
-
|
| 121 |
label="Input Video",
|
| 122 |
elem_id='video-display-input'
|
| 123 |
)
|
|
@@ -133,8 +137,14 @@ with gr.Blocks() as demo:
|
|
| 133 |
submit = gr.Button(value="Compute Depth and Normals")
|
| 134 |
|
| 135 |
with gr.Row():
|
| 136 |
-
output_depth_video = gr.Video(
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
submit.click(
|
| 140 |
fn=on_submit_video,
|
|
|
|
| 13 |
from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
|
| 14 |
|
| 15 |
# Device setup
|
| 16 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
+
CHECKPOINT_PATH = "GonzaloMG/geowizard-e2e-ft"
|
| 18 |
|
| 19 |
# Load pretrained components
|
| 20 |
+
vae = AutoencoderKL.from_pretrained(CHECKPOINT_PATH, subfolder='vae')
|
| 21 |
+
scheduler = DDIMScheduler.from_pretrained(CHECKPOINT_PATH, timestep_spacing="trailing", subfolder='scheduler')
|
| 22 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CHECKPOINT_PATH, subfolder="image_encoder")
|
| 23 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(CHECKPOINT_PATH, subfolder="feature_extractor")
|
| 24 |
+
unet = UNet2DConditionModel.from_pretrained(CHECKPOINT_PATH, subfolder="unet")
|
| 25 |
|
| 26 |
# Instantiate pipeline
|
| 27 |
pipe = DepthNormalEstimationPipeline(
|
|
|
|
| 30 |
feature_extractor=feature_extractor,
|
| 31 |
unet=unet,
|
| 32 |
scheduler=scheduler
|
| 33 |
+
).to(DEVICE)
|
| 34 |
pipe.unet.eval()
|
| 35 |
|
| 36 |
# UI texts
|
| 37 |
title = "# End-to-End Fine-Tuned GeoWizard Video"
|
| 38 |
+
description = (
|
| 39 |
+
"""
|
| 40 |
+
Please refer to our [paper](https://arxiv.org/abs/2409.11355) and
|
| 41 |
+
[GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
|
| 42 |
+
"""
|
| 43 |
+
)
|
| 44 |
|
| 45 |
@spaces.GPU
|
| 46 |
def predict(image: Image.Image, processing_res_choice: int):
|
|
|
|
| 73 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 74 |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 75 |
|
| 76 |
+
# Create temporary output files
|
| 77 |
tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
| 78 |
tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
| 79 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 80 |
out_depth = cv2.VideoWriter(tmp_depth.name, fourcc, fps, (width, height))
|
| 81 |
out_normal = cv2.VideoWriter(tmp_normal.name, fourcc, fps, (width, height))
|
| 82 |
|
| 83 |
+
# Process each frame
|
| 84 |
for _ in tqdm(range(frame_count), desc="Processing frames"):
|
| 85 |
ret, frame = cap.read()
|
| 86 |
if not ret:
|
| 87 |
break
|
| 88 |
|
| 89 |
+
# Convert frame to PIL image
|
| 90 |
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 91 |
pil_image = Image.fromarray(rgb)
|
| 92 |
|
| 93 |
+
# Predict depth and normals
|
| 94 |
result = predict(pil_image, processing_res_choice)
|
| 95 |
depth_colored = result.depth_colored
|
| 96 |
normal_colored = result.normal_colored
|
|
|
|
| 110 |
out_depth.release()
|
| 111 |
out_normal.release()
|
| 112 |
|
| 113 |
+
# Return video paths for download
|
| 114 |
return tmp_depth.name, tmp_normal.name
|
| 115 |
|
| 116 |
+
|
| 117 |
# Build Gradio interface
|
| 118 |
with gr.Blocks() as demo:
|
| 119 |
gr.Markdown(title)
|
|
|
|
| 121 |
gr.Markdown("### Depth and Normals Prediction on Video")
|
| 122 |
|
| 123 |
with gr.Row():
|
| 124 |
+
input_video = gr.Video(
|
| 125 |
label="Input Video",
|
| 126 |
elem_id='video-display-input'
|
| 127 |
)
|
|
|
|
| 137 |
submit = gr.Button(value="Compute Depth and Normals")
|
| 138 |
|
| 139 |
with gr.Row():
|
| 140 |
+
output_depth_video = gr.Video(
|
| 141 |
+
label="Depth Video",
|
| 142 |
+
elem_id='download'
|
| 143 |
+
)
|
| 144 |
+
output_normal_video = gr.Video(
|
| 145 |
+
label="Normal Video",
|
| 146 |
+
elem_id='download'
|
| 147 |
+
)
|
| 148 |
|
| 149 |
submit.click(
|
| 150 |
fn=on_submit_video,
|