niye4 committed on
Commit
9eeb5d9
·
verified ·
1 Parent(s): 3879970

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -18
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import shutil
3
  import cv2
4
  import numpy as np
5
  import torch
@@ -7,13 +6,14 @@ from PIL import Image
7
  import gradio as gr
8
  from gradio_imageslider import ImageSlider
9
  from depth_anything_v2.dpt import DepthAnythingV2
 
 
10
 
11
  # ===============================
12
  # Device & Model
13
  # ===============================
14
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
15
 
16
- # Only vitb
17
  MODEL_PATH = "checkpoints/depth_anything_v2_vitb.pth"
18
  model = DepthAnythingV2(encoder='vitb', features=128, out_channels=[96,192,384,768])
19
  state_dict = torch.load(MODEL_PATH, map_location="cpu")
@@ -24,20 +24,35 @@ model = model.to(DEVICE).eval()
24
  # Predict depth for single frame
25
  # ===============================
26
  def predict_depth(frame_rgb):
27
- return model.infer_image(frame_rgb)
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # ===============================
30
  # Process video
31
  # ===============================
32
  def process_video(video_file):
33
  """
34
- Render a grayscale DepthMap video from uploaded MP4.
35
- Only vitb model, fast, high quality for vitb.
36
  """
37
  OUTPUT_DIR = "output"
38
  os.makedirs(OUTPUT_DIR, exist_ok=True)
39
 
40
  video_path = os.path.join(OUTPUT_DIR, os.path.basename(video_file.name))
 
 
41
  shutil.copy(video_file.name, video_path)
42
 
43
  cap = cv2.VideoCapture(video_path)
@@ -48,10 +63,12 @@ def process_video(video_file):
48
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
49
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
50
 
 
51
  output_video_path = os.path.join(OUTPUT_DIR, os.path.basename(video_path).replace(".mp4","_depth.mp4"))
52
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # browser-compatible MP4
53
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width,height), isColor=True)
54
 
 
55
  slider_frames = []
56
  max_slider_frames = 30
57
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -66,17 +83,13 @@ def process_video(video_file):
66
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
67
  depth_map = predict_depth(frame_rgb)
68
 
69
- # Keep 16-bit depth for video output (preserve details)
70
- depth_16bit = depth_map.astype(np.uint16)
71
-
72
- # Scale for preview slider (8-bit)
73
- depth_8bit = ((depth_16bit / depth_16bit.max()) * 255).astype(np.uint8)
74
- depth_rgb_preview = cv2.cvtColor(depth_8bit, cv2.COLOR_GRAY2BGR)
75
- out.write(depth_rgb_preview)
76
 
77
- # Add sampled frames for slider
78
  if idx % step == 0:
79
- slider_frames.append(Image.fromarray(depth_8bit))
80
  idx += 1
81
 
82
  cap.release()
@@ -87,10 +100,11 @@ def process_video(video_file):
87
  # Gradio Interface
88
  # ===============================
89
  with gr.Blocks() as demo:
90
- gr.Markdown("# Depth Anything V2 – Grayscale Video (vitb)")
91
  gr.Markdown(
92
- "Upload an MP4 video and generate a grayscale DepthMap video.\n\n"
93
- "**Model:** vitb – fast and high quality for real-time preview."
 
94
  )
95
 
96
  video_input = gr.File(label="Upload MP4", file_types=['.mp4'])
 
1
  import os
 
2
  import cv2
3
  import numpy as np
4
  import torch
 
6
  import gradio as gr
7
  from gradio_imageslider import ImageSlider
8
  from depth_anything_v2.dpt import DepthAnythingV2
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib
11
 
12
  # ===============================
13
  # Device & Model
14
  # ===============================
15
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
16
 
 
17
  MODEL_PATH = "checkpoints/depth_anything_v2_vitb.pth"
18
  model = DepthAnythingV2(encoder='vitb', features=128, out_channels=[96,192,384,768])
19
  state_dict = torch.load(MODEL_PATH, map_location="cpu")
 
24
  # Predict depth for single frame
25
  # ===============================
26
def predict_depth(frame_rgb):
    """Infer a depth map for a single RGB frame and return it as float32."""
    # Delegates to the module-level DepthAnythingV2 model; cast keeps the
    # downstream normalization math in a consistent dtype.
    return model.infer_image(frame_rgb).astype(np.float32)
30
+
31
# ===============================
# Colormap for preview
# ===============================
# 'magma' is a perceptually uniform colormap, well suited to depth preview.
# NOTE: matplotlib.cm.get_cmap() was deprecated in matplotlib 3.7 and removed
# in 3.9; matplotlib.colormaps[...] is the supported registry lookup.
cmap = matplotlib.colormaps['magma']

def apply_colormap(depth):
    """Normalize a depth map to [0, 1], apply the colormap, return uint8 RGB.

    Parameters:
        depth: 2-D float array of per-pixel depth values.

    Returns:
        (H, W, 3) uint8 RGB array, suitable for video frames / PIL images.
    """
    # Min-max normalize; the epsilon guards against a constant (flat) frame,
    # which would otherwise divide by zero.
    norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    # cmap() yields an (H, W, 4) RGBA float array in [0, 1]; drop the alpha
    # channel and scale to 8-bit.
    return (cmap(norm)[:, :, :3] * 255).astype(np.uint8)
41
 
42
  # ===============================
43
  # Process video
44
  # ===============================
45
  def process_video(video_file):
46
  """
47
+ Render depthmap video with colormap.
48
+ Keep original resolution & FPS.
49
  """
50
  OUTPUT_DIR = "output"
51
  os.makedirs(OUTPUT_DIR, exist_ok=True)
52
 
53
  video_path = os.path.join(OUTPUT_DIR, os.path.basename(video_file.name))
54
+ # Copy input video
55
+ import shutil
56
  shutil.copy(video_file.name, video_path)
57
 
58
  cap = cv2.VideoCapture(video_path)
 
63
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
64
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
65
 
66
+ # Video output path
67
  output_video_path = os.path.join(OUTPUT_DIR, os.path.basename(video_path).replace(".mp4","_depth.mp4"))
68
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
69
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width,height), isColor=True)
70
 
71
+ # Slider preview (sample frames)
72
  slider_frames = []
73
  max_slider_frames = 30
74
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
83
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
84
  depth_map = predict_depth(frame_rgb)
85
 
86
+ # Apply colormap for video output
87
+ colored_frame = apply_colormap(depth_map)
88
+ out.write(cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR))
 
 
 
 
89
 
90
+ # Add sampled frames for slider preview
91
  if idx % step == 0:
92
+ slider_frames.append(Image.fromarray(colored_frame))
93
  idx += 1
94
 
95
  cap.release()
 
100
  # Gradio Interface
101
  # ===============================
102
  with gr.Blocks() as demo:
103
+ gr.Markdown("# Depth Anything V2 – Depth Video (vitb)")
104
  gr.Markdown(
105
+ "Upload an MP4 video to generate a **colored DepthMap video**.\n\n"
106
+ "**Model:** vitb – fast and high quality for real-time processing.\n"
107
+ "Resolution and FPS are preserved from the original video."
108
  )
109
 
110
  video_input = gr.File(label="Upload MP4", file_types=['.mp4'])