ZenosArrows's picture
Add looking glass and upscaling support
7ca1c9a verified
raw
history blame
10.9 kB
import glob
import gradio as gr
import numpy as np
import torch
import tempfile
import uuid
from PIL import Image, ImageOps, ImageEnhance
from pathlib import Path
from zipfile import ZipFile, is_zipfile
from pypdf import PdfReader
from depth_anything_v2.dpt import DepthAnythingV2
# Custom stylesheet handed to gr.Blocks(css=...) below. It caps the height of
# the image input/output panels relative to the viewport, fixes the height of
# the element with id "download", and makes gallery thumbnails keep their
# aspect ratio (object-fit: contain) instead of being cropped.
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
.thumbnail-item {
aspect-ratio: var(--ratio-wide)
}
.thumbnail-item img {
object-fit: contain
}
"""
# Script injected into the page <head> via gr.Blocks(head=...).
#
# It exposes two helpers on `window` that the Gradio event handlers call by name:
#   * castHologram()  - sends the currently selected gallery image to the
#                       Looking Glass Bridge as an RGBD hologram.
#   * updateHologram(value, parameter) - stores a setting and pushes it to the
#                       hologram that is currently being shown.
#
# Robustness notes:
#   * castHologram is triggered on gallery changes, when there may be no
#     selected thumbnail; the query result is checked before reading `.src`.
#   * updateHologram resets `window.updating` in a `finally` so that a rejected
#     bridge call cannot leave the flag stuck and silently block later updates.
#   * The current playlist is checked before its name is read, in case nothing
#     has been cast yet.
head = """
<script type="module">
import { BridgeClient, RGBDHologram } from "/gradio_api/file=assets/looking-glass-bridge.js";
window.BridgeClient = BridgeClient;
window.RGBDHologram = RGBDHologram;
window.updating = false;
window.settings = {
depthiness: 1.0,
focus: 0,
aspect: 1,
chroma_depth: 0,
depth_inversion: 0,
depth_loc: 2,
depth_cutoff: 1,
zoom: 1,
crop_pos_x: 0,
crop_pos_y: 0,
};
window.castHologram = async function() {
const selected = document.querySelector('#img-display-output .thumbnail-item.selected img');
const uri = selected ? selected.src : null;
if (!uri)
return;
const Bridge = BridgeClient.getInstance();
if (!Bridge.isConnected)
await Bridge.connect();
await Bridge.getDisplays();
if (Bridge.isCastPending)
return;
const rgbd = new RGBDHologram({ uri, settings });
await Bridge.cast(rgbd);
};
window.updateHologram = async function(value, parameter) {
settings[parameter] = value;
const Bridge = BridgeClient.getInstance();
if (!Bridge.isConnected || window.updating)
return;
const playlist = Bridge.getCurrentPlaylist();
if (!playlist)
return;
const name = playlist.name;
window.updating = true;
try {
await Bridge.updateCurrentHologram({ name, parameter, value });
} finally {
window.updating = false;
}
};
</script>
"""
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# (encoder name, feature width, per-stage output channels) for each supported
# DepthAnythingV2 variant; expanded below into constructor keyword arguments.
_ENCODER_SPECS = (
    ('vits', 64, [48, 96, 192, 384]),
    ('vitb', 128, [96, 192, 384, 768]),
    ('vitl', 256, [256, 512, 1024, 1024]),
    ('vitg', 384, [1536, 1536, 1536, 1536]),
)
model_configs = {
    encoder: {'encoder': encoder, 'features': features, 'out_channels': channels}
    for encoder, features, channels in _ENCODER_SPECS
}

# Markdown shown at the top of the page.
title = "# Depth Anything V2"
description = """Looking Glass demo for **Depth Anything V2**.
Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), or [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
def predict_depth(image, model):
    """Return a side-by-side RGBD image: the input on the left, its depth map on the right.

    Args:
        image: PIL image to analyse. It is converted to RGB for inference only;
            the left half of the result keeps the image's own mode.
        model: Object exposing ``infer_image(array)`` that returns a 2-D numeric
            array of depth values (presumably one value per input pixel; the
            right-hand paste relies on that).

    Returns:
        A new PIL image of size ``(2 * width, height)`` in ``image.mode``.
    """
    w, h = image.size

    # The channel axis is reversed (RGB -> BGR order) before inference.
    depth = model.infer_image(np.array(image.convert("RGB"))[:, :, ::-1])

    # Normalise to 0..255. A perfectly flat depth map has zero range; dividing
    # by it would yield NaNs whose uint8 cast is undefined, so emit zeros instead.
    lowest = depth.min()
    span = depth.max() - lowest
    if span > 0:
        depth = (depth - lowest) / span * 255.0
    else:
        depth = np.zeros_like(depth)
    depth = depth.astype(np.uint8)

    gray_depth = Image.fromarray(depth)
    rgbd = Image.new(image.mode, (w * 2, h))
    rgbd.paste(image, (0, 0))
    rgbd.paste(gray_depth, (w, 0))
    return rgbd
def upscale_image(image, model, background, discard_alpha):
    """Normalise the image mode, optionally flatten transparency, and run the upscaler.

    Args:
        image: Source image; only its ``mode``, ``size`` and ``convert`` are used here.
        model: Optional upscaler exposing ``infer(image)``; skipped when ``None``.
        background: RGBA image that is padded to the source size and placed
            behind an RGBA source when ``discard_alpha`` is set.
        discard_alpha: When true, RGBA input is composited onto ``background``
            and the final result is converted to RGB.

    Returns:
        The processed image (RGB when ``discard_alpha`` is true).
    """
    mode = image.mode
    if mode == "RGBA" and discard_alpha:
        backdrop = ImageOps.pad(background, image.size, color=(0, 0, 0))
        image = Image.alpha_composite(backdrop, image)
    elif mode not in ("RGBA", "RGB"):
        # Any other mode is brought to plain RGB first.
        image = image.convert("RGB")

    if model is not None:
        image = model.infer(image)

    if discard_alpha:
        return image.convert("RGB")
    return image
def on_submit(image, batch_images, book, config, upscale_model, upscale_method, denoise_level, discard_alpha, progress=gr.Progress()):
    """Compute RGBD images for every supplied input and return them as gallery entries.

    Args:
        image: Optional single PIL image.
        batch_images: Optional iterable of image file paths.
        book: Optional path to a ZIP archive of images or a PDF document.
        config: Key into ``model_configs`` selecting the depth model variant;
            also selects the checkpoint file under ``checkpoints/``.
        upscale_model: ``model_type`` forwarded to the waifu2x hub entry point.
        upscale_method: ``method`` forwarded to waifu2x; ``None`` disables upscaling.
        denoise_level: ``noise_level`` forwarded to waifu2x.
        discard_alpha: Whether transparency is flattened onto a radial gradient.
        progress: Gradio progress tracker used to wrap the batch loops.

    Returns:
        A list of ``(rgbd_image, caption)`` tuples; the caption is ``None`` for
        the single image and the source file/entry name otherwise.
    """
    # PIL is already a dependency of this file; the exception type is only needed here.
    from PIL import UnidentifiedImageError

    # The depth model is rebuilt and its weights reloaded on every submission.
    model = DepthAnythingV2(**model_configs[config])
    state_dict = torch.load(f'checkpoints/depth_anything_v2_{config}.pth', map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(DEVICE).eval()

    superresolution = None
    if upscale_method is not None:
        # NOTE(review): trust_repo=True executes code fetched from the repo's
        # master branch; consider pinning a revision.
        superresolution = torch.hub.load("nagadomi/nunif:master", "waifu2x",
            model_type=upscale_model, method=upscale_method, noise_level=denoise_level,
            keep_alpha=not discard_alpha, trust_repo=True).to(DEVICE)

    # Inverted, brightened radial gradient used as backdrop for transparent images.
    gradient = ImageEnhance.Brightness(Image.radial_gradient("L"))
    background = ImageOps.invert(gradient.enhance(1.5)).convert("RGBA")

    def to_rgbd(img):
        # Shared pipeline: optional upscaling/flattening, then depth estimation.
        prepared = upscale_image(img, superresolution, background, discard_alpha)
        return predict_depth(prepared, model)

    result = []

    if image is not None:
        result.append((to_rgbd(image), None))

    if batch_images is not None:
        for path in progress.tqdm(batch_images):
            with Image.open(path) as img:
                result.append((to_rgbd(img), Path(path).name))

    if book is not None:
        if is_zipfile(book):
            with ZipFile(book, "r") as zf:
                for entry in progress.tqdm(zf.infolist()):
                    # Archives routinely contain folders and non-image members
                    # (metadata, resource forks); skip them instead of aborting.
                    if entry.is_dir():
                        continue
                    with zf.open(entry) as file:
                        try:
                            img = Image.open(file)
                        except UnidentifiedImageError:
                            continue
                        with img:
                            result.append((to_rgbd(img), entry.filename))
        else:
            reader = PdfReader(book)
            for page in progress.tqdm(reader.pages):
                for image_file_object in page.images:
                    result.append((to_rgbd(image_file_object.image), image_file_object.name))

    return result
def zip_gallery(gallery, progress=gr.Progress()):
    """Produce the file offered by the download button for the current gallery.

    Args:
        gallery: Gallery value; each entry is indexed as ``(file path, caption)``.
        progress: Gradio progress tracker used to wrap the archive loop.

    Returns:
        ``None`` when the gallery is missing or empty, the single image's path
        when there is exactly one entry, otherwise the ``Path`` of a freshly
        written ZIP archive in the system temp directory.
    """
    # Covers both None and an emptied gallery; the latter used to yield an empty zip.
    if not gallery:
        return None
    if len(gallery) == 1:
        return gallery[0][0]

    archive = (Path(tempfile.gettempdir()) / uuid.uuid4().hex).with_suffix(".zip")
    with ZipFile(archive, "w") as zf:
        # Wrap the list itself so the progress tracker can see its length.
        for index, item in enumerate(progress.tqdm(gallery)):
            path, caption = item[0], item[1]
            # Captions may be archive entry names containing folders; keep only
            # the base name so the index prefix lands on the file, not a folder.
            base = Path(caption).name if caption else ""
            if base:
                fn = Path(base).with_suffix(".rgbd.png")
            else:
                fn = Path(path).name
            zf.write(path, "{:02d}_{}".format(index, fn))
    return archive
# Allow the front end to fetch files from ./assets (the page script imports
# looking-glass-bridge.js from there).
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])

# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened source; confirm which column each control belongs to.
with gr.Blocks(css=css, head=head) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: inputs and processing options.
        with gr.Column():
            with gr.Tab("Single Image"):
                input_image = gr.Image(
                    label="Input Image",
                    elem_id='img-display-input',
                    type='pil',
                    image_mode=None
                )
            with gr.Tab("Batch Mode"):
                batch_images = gr.File(
                    label="Images",
                    file_types=["image"],
                    file_count="multiple"
                )
            with gr.Tab("Document Mode"):
                book = gr.File(
                    label="Document",
                    file_types=[".pdf", ".zip"],
                )
            with gr.Row():
                clear = gr.ClearButton(components=[input_image, batch_images, book])
                submit = gr.Button(value="Compute Depth", variant="primary")
            model_size = gr.Radio(
                label="Model Size",
                choices=[('Small', 'vits'), ('Base', 'vitb'), ('Large', 'vitl')],
                value="vitl"
            )
            upscale_method = gr.Radio(
                label="Upscale Method",
                choices=[("No Upscaling or Denoising", None), ("Denoise Only", "noise"), ("2x Upscaling", "scale2x"), ("4x Upscaling", "scale4x")]
            )
            upscale_model = gr.Dropdown(
                choices=["art", "art_scan", "photo", "swin_unet/art", "swin_unet/art_scan", "swin_unet/photo", "cunet/art", "upconv_7/art", "upconv_7/photo"],
                label="Upscaling Model",
                value="art"
            )
            denoise_level = gr.Slider(
                label="Denoise Level (-1 = None)",
                value=0,
                step=1,
                minimum=-1,
                maximum=4
            )
            discard_alpha = gr.Checkbox(label="Add radial gradient background to transparent images", value=True)

        # Right column: results, download and live hologram controls.
        with gr.Column():
            gallery = gr.Gallery(
                label="RGBD Images",
                elem_id='img-display-output',
                format="png",
                columns=4,
                object_fit="contain",
                preview=True,
                interactive=True
            )
            download_btn = gr.DownloadButton()
            depthiness = gr.Slider(
                label="Depthiness",
                elem_id="depthiness",
                interactive=True,
                minimum=0,
                maximum=3,
                value=1
            )
            focus = gr.Slider(
                label="Focus",
                interactive=True,
                minimum=-0.03,
                maximum=0.03,
                value=0
            )
            zoom = gr.Slider(
                label="Zoom",
                interactive=True,
                minimum=0,
                maximum=10,
                value=1
            )
            pos_x = gr.Slider(
                label="Position X",
                interactive=True,
                minimum=-1,
                maximum=1,
                value=0
            )
            pos_y = gr.Slider(
                label="Position Y",
                interactive=True,
                minimum=-1,
                maximum=1,
                value=0
            )
            reset = gr.Button(value="Reset All Parameters")

    # Selecting a thumbnail casts it; any gallery change refreshes the download
    # target and then re-casts. Handlers with fn=None run only the named
    # browser-side function.
    gallery.select(fn=None, js="castHologram")
    gallery.change(fn=zip_gallery, inputs=gallery, outputs=download_btn).then(fn=None, js="castHologram")
    submit.click(
        on_submit,
        inputs=[input_image, batch_images, book, model_size, upscale_model, upscale_method, denoise_level, discard_alpha],
        outputs=[gallery]
    ).then(fn=zip_gallery, inputs=gallery, outputs=download_btn).then(fn=None, js="castHologram")

    # Each slider forwards its value to updateHologram under the matching
    # settings key.
    depthiness.change(fn=None, inputs=depthiness, js="(value) => updateHologram (value, 'depthiness')")
    focus.change(fn=None, inputs=focus, js="(value) => updateHologram (value, 'focus')")
    zoom.change(fn=None, inputs=zoom, js="(value) => updateHologram (value, 'zoom')")
    pos_x.change(fn=None, inputs=pos_x, js="(value) => updateHologram (value, 'crop_pos_x')")
    pos_y.change(fn=None, inputs=pos_y, js="(value) => updateHologram (value, 'crop_pos_y')")

    # Clicks every element matching 'button.reset-button' on the page.
    reset.click(fn=None, js="""
() => {
document.querySelectorAll('button.reset-button').forEach(b => b.click());
}
""")

    example_files = glob.glob('assets/examples/*')
    # Examples only provide the input image, while on_submit needs eight
    # positional arguments. Caching must therefore stay off: with caching
    # enabled Gradio would call on_submit(image) and fail.
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[gallery], fn=on_submit, cache_examples=False)

if __name__ == '__main__':
    demo.queue().launch()