# marigold2 / app.py — uploaded by ciucinciu ("Update app.py", commit 4be156f, verified)
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://marigolddepthcompletion.github.io/
# https://rollingdepth.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# https://github.com/prs-eth/Marigold-DC#-citation
# https://github.com/prs-eth/rollingdepth#-citation
# --------------------------------------------------------------------------
import os
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import MarigoldDepthPipeline, DDIMScheduler
from huggingface_hub import login
from PIL import Image
# Marigold depth checkpoint on the Hugging Face Hub.
CHECKPOINT = "prs-eth/marigold-depth-v1-1"
# Optional Hub login (e.g. for gated or private checkpoints).
if "HF_TOKEN_LOGIN" in os.environ:
    login(token=os.environ["HF_TOKEN_LOGIN"])
# Prefer GPU with bfloat16; fall back to CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
# Swap in DDIM with "trailing" timestep spacing — used for few-step inference
# (see the diffusers Marigold usage guide).
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
pipe = pipe.to(device=device, dtype=dtype)
# xFormers memory-efficient attention is a best-effort optimization;
# silently skip it when the extension is not installed.
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception:
    pass
@spaces.GPU
def run_depth(image_in, ensemble_size, denoise_steps, processing_res):
    """Run Marigold depth estimation on a single input image.

    Args:
        image_in: Input PIL image from the Gradio ``Image`` component.
        ensemble_size: Number of ensembled predictions; >= 3 enables the
            uncertainty output.
        denoise_steps: Number of DDIM denoising steps.
        processing_res: Internal processing resolution; 0 means native.

    Returns:
        Tuple of (colorized depth image, grayscale preview image,
        path to a 16-bit depth PNG on disk, uncertainty image or ``None``).

    Raises:
        gr.Error: If no input image was provided.
    """
    if image_in is None:
        raise gr.Error("Încarcă o imagine mai întâi.")
    # Normalize slider/radio values once instead of re-casting everywhere.
    ensemble_size = int(ensemble_size)
    denoise_steps = int(denoise_steps)
    processing_res = int(processing_res)
    pipe_out = pipe(
        image_in,
        ensemble_size=ensemble_size,
        num_inference_steps=denoise_steps,
        processing_resolution=processing_res,
        match_input_resolution=True,
        # Native resolution can be large; keep batch size at 1 to limit VRAM.
        batch_size=1 if processing_res == 0 else 2,
        # Uncertainty is only meaningful with enough ensemble members.
        output_uncertainty=ensemble_size >= 3,
    )
    pred = pipe_out.prediction
    # Colorized depth visualization for the UI.
    depth_vis = pipe.image_processor.visualize_depth(pred)[0]
    # BUG FIX: export_depth_to_16bit_png returns a PIL image (mode "I;16"),
    # not a filesystem path, but gr.File needs a real path — save it to a
    # temp file and hand that path to the UI.
    depth_16bit_img = pipe.image_processor.export_depth_to_16bit_png(pred)[0]
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        depth_16bit_path = tmp.name
    depth_16bit_img.save(depth_16bit_path)
    # Grayscale preview normalized to the prediction's own min/max range.
    # The pipeline may return either a torch tensor or a numpy array
    # depending on its output_type; handle both.
    depth_arr = pred[0]
    if torch.is_tensor(depth_arr):
        depth_arr = depth_arr.detach().float().cpu().numpy()
    depth_np = np.squeeze(np.asarray(depth_arr))
    dmin = float(depth_np.min())
    dmax = float(depth_np.max())
    if dmax > dmin:
        depth_preview = ((depth_np - dmin) / (dmax - dmin) * 255.0).clip(0, 255).astype(np.uint8)
    else:
        # Constant-depth prediction: avoid division by zero.
        depth_preview = np.zeros_like(depth_np, dtype=np.uint8)
    depth_preview = Image.fromarray(depth_preview, mode="L")
    uncertainty = None
    if ensemble_size >= 3 and getattr(pipe_out, "uncertainty", None) is not None:
        uncertainty = pipe.image_processor.visualize_uncertainty(pipe_out.uncertainty)[0]
    return depth_vis, depth_preview, depth_16bit_path, uncertainty
# Build the Gradio UI: inputs on the left, outputs on the right.
with gr.Blocks(title="Marigold Depth") as demo:
    gr.Markdown("# Marigold Depth")
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input image")
            ensemble_size = gr.Slider(1, 10, value=1, step=1, label="Ensemble size")
            denoise_steps = gr.Slider(1, 20, value=4, step=1, label="Denoising steps")
            processing_res = gr.Radio(
                choices=[("Native", 0), ("Recommended", 768)],
                value=768,
                label="Processing resolution",
            )
            run_btn = gr.Button("Run")
        with gr.Column():
            depth_vis_out = gr.Image(type="pil", label="Depth visualization")
            depth_preview_out = gr.Image(type="pil", label="Depth grayscale preview")
            depth_16bit_out = gr.File(label="Depth 16-bit PNG")
            uncertainty_out = gr.Image(type="pil", label="Uncertainty")
    # BUG FIX: the original passed the literal `[...]` (Ellipsis) as `inputs`,
    # which is not a valid component list and breaks the click handler.
    # Wire the actual input components in the order run_depth expects them.
    run_btn.click(
        fn=run_depth,
        inputs=[image_in, ensemble_size, denoise_steps, processing_res],
        outputs=[depth_vis_out, depth_preview_out, depth_16bit_out, uncertainty_out],
    )
# Disable the public API surface and launch the queued app.
demo.queue(api_open=False).launch()