# depthmap / app.py
# Source: Hugging Face Space "depthmap" by niye4 — "Update app.py" (commit 7b3e1ce, verified)
# app.py
import os
import shutil
import subprocess
import cv2
import numpy as np
import torch
from PIL import Image
import gradio as gr
from depth_anything_v2.dpt import DepthAnythingV2
# -------------------
# Configuration
# -------------------
# Run on GPU when one is visible; otherwise fall back to CPU inference.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT = "checkpoints/depth_anything_v2_vitb.pth"  # vitb only
WORKDIR = "workspace"
FRAMES_DIR = os.path.join(WORKDIR, "frames")
OUT_FRAMES_DIR = os.path.join(WORKDIR, "depth_frames")
RAW_FRAMES_DIR = os.path.join(WORKDIR, "raw16")
OUTPUT_DIR = "output"

# Make sure every working directory exists before any request is served.
for _dir in (FRAMES_DIR, OUT_FRAMES_DIR, RAW_FRAMES_DIR, OUTPUT_DIR):
    os.makedirs(_dir, exist_ok=True)
# -------------------
# Load model (vitb)
# -------------------
# Depth Anything V2, ViT-B variant. These hyper-parameters must match the
# published vitb configuration or load_state_dict will reject the weights.
model = DepthAnythingV2(
    encoder='vitb',
    features=128,
    out_channels=[96, 192, 384, 768]
)
# weights_only=True restricts unpickling to tensors and primitive types,
# so a tampered checkpoint file cannot execute arbitrary code on load
# (a state dict contains only tensors, so this is lossless here).
state_dict = torch.load(CHECKPOINT, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
# -------------------
# Depth functions
# -------------------
def predict_depth(frame_rgb):
    """Run depth inference on one RGB frame; return a float32 depth map."""
    depth = model.infer_image(frame_rgb)
    return depth.astype(np.float32)
def depth_to_gray8(depth):
    """Stretch a float depth map to an 8-bit grayscale image.

    A (near-)constant map has no contrast to stretch, so it is mapped to
    all zeros instead of dividing by ~0.
    """
    lo = float(depth.min())
    hi = float(depth.max())
    span = hi - lo
    if span < 1e-8:
        return np.zeros_like(depth, dtype=np.uint8)
    return ((depth - lo) / span * 255.0).astype(np.uint8)
def clear_workspace():
    """Wipe the scratch workspace and recreate its empty subdirectories."""
    shutil.rmtree(WORKDIR, ignore_errors=True)
    for path in (FRAMES_DIR, OUT_FRAMES_DIR, RAW_FRAMES_DIR):
        os.makedirs(path, exist_ok=True)
# -------------------
# Main Processing
# -------------------
def process_video(video_file):
    """Extract frames, run per-frame depth inference, and merge the results.

    Pipeline: copy upload -> ffmpeg PNG frame extraction -> per-frame depth
    (raw 16-bit PNG + normalized 8-bit grayscale preview) -> ffmpeg merge
    of the grayscale frames into an H.264 MP4.

    Args:
        video_file: uploaded file object from gr.File; only .name (path) is used.

    Returns:
        (preview_images, out_video): ~20 sampled PIL grayscale previews and
        the path of the rendered depth MP4 under OUTPUT_DIR.

    Raises:
        RuntimeError: video cannot be opened, no frames extracted, or an
            extracted frame cannot be read back.
        subprocess.CalledProcessError: either ffmpeg invocation fails.
    """
    clear_workspace()

    # Copy the upload into the workspace so ffmpeg reads a stable local path.
    in_path = os.path.join(WORKDIR, "input.mp4")
    shutil.copy(video_file.name, in_path)

    # Read FPS; fall back to 30 when the container reports 0/NaN-ish values.
    cap = cv2.VideoCapture(in_path)
    if not cap.isOpened():
        raise RuntimeError("Cannot open uploaded video.")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    cap.release()

    # Extract lossless PNG frames.
    subprocess.run([
        "ffmpeg", "-y",
        "-i", in_path,
        os.path.join(FRAMES_DIR, "frame_%06d.png")
    ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    frame_files = sorted(os.listdir(FRAMES_DIR))
    if not frame_files:
        raise RuntimeError("No frames extracted.")

    preview_images = []
    total = len(frame_files)
    sample_step = max(1, total // 20)  # keep roughly 20 preview thumbnails

    # Process frames
    for i, fname in enumerate(frame_files):
        fp = os.path.join(FRAMES_DIR, fname)
        bgr = cv2.imread(fp, cv2.IMREAD_COLOR)
        if bgr is None:
            # Fail with a clear message instead of an opaque cvtColor error.
            raise RuntimeError(f"Failed to read extracted frame: {fname}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        depth = predict_depth(rgb)

        # Save raw depth as 16-bit PNG. The model emits relative float
        # depth, so stretch each frame to the full uint16 range; a plain
        # astype(np.uint16) would truncate almost all precision down to a
        # handful of integer levels.
        dmin, dmax = float(depth.min()), float(depth.max())
        if dmax - dmin < 1e-8:
            raw16 = np.zeros_like(depth, dtype=np.uint16)
        else:
            raw16 = ((depth - dmin) / (dmax - dmin) * 65535.0).astype(np.uint16)
        Image.fromarray(raw16).save(os.path.join(RAW_FRAMES_DIR, fname))

        # Save normalized grayscale frame used for the merged preview video.
        gray8 = depth_to_gray8(depth)
        Image.fromarray(gray8).save(os.path.join(OUT_FRAMES_DIR, fname))

        if i % sample_step == 0:
            preview_images.append(Image.fromarray(gray8))

    # Build the output name robustly (handles any extension/case, unlike
    # a literal str.replace(".mp4", ...)).
    stem = os.path.splitext(os.path.basename(video_file.name))[0]
    out_video = os.path.join(OUTPUT_DIR, f"{stem}_depth.mp4")

    # Merge grayscale frames back into an H.264 MP4.
    subprocess.run([
        "ffmpeg", "-y",
        "-framerate", str(fps),
        "-i", os.path.join(OUT_FRAMES_DIR, "frame_%06d.png"),
        # yuv420p requires even dimensions; pad by one pixel if odd.
        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        out_video
    ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return preview_images, out_video
# -------------------
# UI
# -------------------
# -------------------
# UI
# -------------------
# Gradio interface. Component creation order inside the Blocks context
# determines the on-page layout, so these statements must stay in order.
with gr.Blocks() as demo:
    gr.Markdown("# Depth Anything V2 ")
    gr.Markdown(
        "https://github.com/DepthAnything/Depth-Anything-V2 "
    )
    # Upload restricted to .mp4; process_video also copies it as "input.mp4".
    video_in = gr.File(label="Upload a video (mp4)", file_types=[".mp4"])
    # Sampled grayscale depth frames returned by process_video.
    gallery = gr.Gallery(
        label="Preview Depth Frames",
        columns=5,
        height="auto"
    )
    out_video = gr.Video(label="Depthmap Video Output")
    btn = gr.Button("Render High-Quality Depth Video")
    # process_video returns (preview_images, out_video_path) matching outputs.
    btn.click(process_video, inputs=[video_in], outputs=[gallery, out_video])

# queue() serializes requests so concurrent renders don't clobber the
# shared workspace directories that process_video wipes per call.
if __name__ == "__main__":
    demo.queue().launch()