niye4 committed on
Commit
9712ea3
·
verified ·
1 Parent(s): e6f85b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -53
app.py CHANGED
@@ -13,12 +13,13 @@ from depth_anything_v2.dpt import DepthAnythingV2
13
  # Configuration
14
  # -------------------
15
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
16
- CHECKPOINT = "checkpoints/depth_anything_v2_vitb.pth" # expect this to exist
17
  WORKDIR = "workspace"
18
  FRAMES_DIR = os.path.join(WORKDIR, "frames")
19
  OUT_FRAMES_DIR = os.path.join(WORKDIR, "depth_frames")
20
  RAW_FRAMES_DIR = os.path.join(WORKDIR, "raw16")
21
  OUTPUT_DIR = "output"
 
22
  os.makedirs(FRAMES_DIR, exist_ok=True)
23
  os.makedirs(OUT_FRAMES_DIR, exist_ok=True)
24
  os.makedirs(RAW_FRAMES_DIR, exist_ok=True)
@@ -27,22 +28,27 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
27
  # -------------------
28
  # Load model (vitb)
29
  # -------------------
30
- model = DepthAnythingV2(encoder='vitb', features=128, out_channels=[96,192,384,768])
 
 
 
 
31
  state_dict = torch.load(CHECKPOINT, map_location="cpu")
32
  model.load_state_dict(state_dict)
33
  model = model.to(DEVICE).eval()
34
 
 
 
 
35
  def predict_depth(frame_rgb):
36
- """Return depth map as float32 numpy array (same semantics as original app.py)."""
37
  return model.infer_image(frame_rgb).astype(np.float32)
38
 
39
  def depth_to_gray8(depth):
40
- """Normalize depth to 0-255 uint8 for preview (same formula as original app.py)."""
41
  dmin, dmax = float(depth.min()), float(depth.max())
42
  if dmax - dmin < 1e-8:
43
- norm = np.zeros_like(depth, dtype=np.uint8)
44
- else:
45
- norm = ((depth - dmin) / (dmax - dmin) * 255.0).astype(np.uint8)
46
  return norm
47
 
48
  def clear_workspace():
@@ -52,104 +58,93 @@ def clear_workspace():
52
  os.makedirs(RAW_FRAMES_DIR, exist_ok=True)
53
 
54
  # -------------------
55
- # Main pipeline
56
  # -------------------
57
  def process_video(video_file):
58
- """
59
- 1) ffmpeg extract frames -> workspace/frames/frame_000001.png ...
60
- 2) for each frame:
61
- - run model.infer_image on RGB frame
62
- - save raw 16-bit PNG to workspace/raw16/frame_XXXXXX.png
63
- - save normalized 8-bit PNG to workspace/depth_frames/frame_XXXXXX.png
64
- 3) ffmpeg merge workspace/depth_frames/frame_%06d.png -> output MP4 (same fps)
65
- Returns: list of preview PIL images (sampled) and output video path
66
- """
67
  clear_workspace()
68
 
69
- # copy input to workspace (avoids /tmp path issues)
70
  in_path = os.path.join(WORKDIR, "input.mp4")
71
  shutil.copy(video_file.name, in_path)
72
 
73
- # read fps and ensure video can be opened
74
  cap = cv2.VideoCapture(in_path)
75
  if not cap.isOpened():
76
  raise RuntimeError("Cannot open uploaded video.")
77
  fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
78
  cap.release()
79
 
80
- # 1) extract frames with ffmpeg as PNG (lossless)
81
- extract_cmd = [
82
  "ffmpeg", "-y",
83
  "-i", in_path,
84
  os.path.join(FRAMES_DIR, "frame_%06d.png")
85
- ]
86
- subprocess.run(extract_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
87
 
88
  frame_files = sorted(os.listdir(FRAMES_DIR))
89
  if len(frame_files) == 0:
90
- raise RuntimeError("No frames extracted from video (ffmpeg step failed).")
91
 
92
  preview_images = []
93
  total = len(frame_files)
94
- sample_step = max(1, total // 20) # max ~20 preview images
95
 
96
- # 2) run per-frame inference and save raw16 + preview8
97
  for i, fname in enumerate(frame_files):
98
  fp = os.path.join(FRAMES_DIR, fname)
99
- # read frame (BGR) -> convert to RGB for model
100
  bgr = cv2.imread(fp, cv2.IMREAD_COLOR)
101
- if bgr is None:
102
- raise RuntimeError(f"Failed to read extracted frame {fp}")
103
  rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
104
 
105
- # predict
106
- depth = predict_depth(rgb) # float32 numpy
107
 
108
- # save raw 16-bit (like app.py original: uint16 PNG)
109
  raw16 = depth.astype(np.uint16)
110
- raw_out_path = os.path.join(RAW_FRAMES_DIR, fname)
111
- Image.fromarray(raw16).save(raw_out_path) # PIL will write 16-bit PNG
112
 
113
- # normalized 8-bit preview (exact same normalization as original app)
114
  gray8 = depth_to_gray8(depth)
115
- preview_out_path = os.path.join(OUT_FRAMES_DIR, fname)
116
- Image.fromarray(gray8).save(preview_out_path)
117
 
118
- # sample for gallery preview
119
  if i % sample_step == 0:
120
  preview_images.append(Image.fromarray(gray8))
121
 
122
- # 3) merge preview frames into MP4 using ffmpeg (use libx264 for compatibility)
123
- out_video = os.path.join(OUTPUT_DIR, os.path.basename(video_file.name).replace(".mp4","_depth.mp4"))
124
- merge_cmd = [
 
 
 
 
125
  "ffmpeg", "-y",
126
  "-framerate", str(fps),
127
  "-i", os.path.join(OUT_FRAMES_DIR, "frame_%06d.png"),
128
  "-c:v", "libx264",
129
  "-pix_fmt", "yuv420p",
130
  out_video
131
- ]
132
- subprocess.run(merge_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
133
 
134
- # return gallery (list of PIL images) and output path
135
  return preview_images, out_video
136
 
137
  # -------------------
138
- # Gradio UI
139
  # -------------------
140
  with gr.Blocks() as demo:
141
- gr.Markdown("# Depth Anything V2 Video (framewise, app.py-style, high-quality)")
142
  gr.Markdown(
143
- "This pipeline extracts frames with ffmpeg, runs the original DepthAnythingV2 per-frame inference, "
144
- "saves raw 16-bit PNGs (in workspace/raw16) and normalized 8-bit PNG previews (in workspace/depth_frames), "
145
- "then merges previews into an MP4 at the original FPS. This reproduces the app.py image quality for video."
146
  )
147
 
148
- video_in = gr.File(label="Upload MP4", file_types=[".mp4"])
149
- gallery = gr.Gallery(label="Preview frames (sampled)").style(grid=5)
150
- out_video = gr.Video(label="Depthmap Video (downloadable)")
 
 
 
 
151
 
152
- btn = gr.Button("Render (framewise, high-quality)")
153
  btn.click(process_video, inputs=[video_in], outputs=[gallery, out_video])
154
 
155
  if __name__ == "__main__":
 
13
# Configuration
# -------------------
# Prefer GPU when present; inference falls back to CPU otherwise.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# ViT-B checkpoint path — expected to exist before the app starts
# (torch.load below will raise if it is missing).
CHECKPOINT = "checkpoints/depth_anything_v2_vitb.pth" # vitb only
# Workspace layout: extracted input frames, 8-bit preview frames,
# raw 16-bit depth frames, and the final rendered output.
WORKDIR = "workspace"
FRAMES_DIR = os.path.join(WORKDIR, "frames")
OUT_FRAMES_DIR = os.path.join(WORKDIR, "depth_frames")
RAW_FRAMES_DIR = os.path.join(WORKDIR, "raw16")
OUTPUT_DIR = "output"

os.makedirs(FRAMES_DIR, exist_ok=True)
os.makedirs(OUT_FRAMES_DIR, exist_ok=True)
os.makedirs(RAW_FRAMES_DIR, exist_ok=True)
 
28
# -------------------
# Load model (vitb)
# -------------------
# ViT-B configuration — these hyperparameters must match the checkpoint.
model = DepthAnythingV2(
    encoder='vitb',
    features=128,
    out_channels=[96, 192, 384, 768]
)
# NOTE(review): torch.load unpickles arbitrary objects; since only a
# state_dict is expected here, consider weights_only=True (PyTorch >= 2.0).
state_dict = torch.load(CHECKPOINT, map_location="cpu")
model.load_state_dict(state_dict)
# Move to the chosen device and freeze in inference mode.
model = model.to(DEVICE).eval()
39
 
40
+ # -------------------
41
+ # Depth functions
42
+ # -------------------
43
def predict_depth(frame_rgb):
    """Infer a depth map for one RGB frame; returned as a float32 array."""
    raw_depth = model.infer_image(frame_rgb)
    return raw_depth.astype(np.float32)
46
 
47
def depth_to_gray8(depth):
    """Scale a float depth map to uint8 [0, 255] for preview frames.

    A (near-)constant map is returned as all zeros to avoid division
    by a vanishing range.
    """
    lo = float(depth.min())
    hi = float(depth.max())
    span = hi - lo
    if span < 1e-8:
        # Degenerate (flat) depth map: emit an all-black preview.
        return np.zeros_like(depth, dtype=np.uint8)
    scaled = (depth - lo) / span * 255.0
    return scaled.astype(np.uint8)
53
 
54
  def clear_workspace():
 
58
  os.makedirs(RAW_FRAMES_DIR, exist_ok=True)
59
 
60
  # -------------------
61
+ # Main Processing
62
  # -------------------
63
def process_video(video_file):
    """Extract frames, run per-frame depth inference, merge into an MP4.

    Pipeline:
      1) ffmpeg extracts lossless PNG frames into FRAMES_DIR.
      2) Each frame is run through the model; raw 16-bit depth PNGs are
         written to RAW_FRAMES_DIR, normalized 8-bit previews to
         OUT_FRAMES_DIR.
      3) ffmpeg merges the previews into an H.264 MP4 at the source FPS.

    Args:
        video_file: uploaded file object exposing `.name` (older Gradio)
            or a plain path string (newer Gradio).

    Returns:
        tuple: (preview_images, out_video) — up to ~20 sampled PIL preview
        frames and the path of the rendered depth video.

    Raises:
        RuntimeError: if the video cannot be opened, no frames are
            extracted, or an extracted frame cannot be read.
        subprocess.CalledProcessError: if an ffmpeg invocation fails.
    """
    clear_workspace()

    # Copy upload into the workspace; accept file-like or path string.
    src_path = video_file.name if hasattr(video_file, "name") else video_file
    in_path = os.path.join(WORKDIR, "input.mp4")
    shutil.copy(src_path, in_path)

    # Probe FPS; `or 30.0` covers containers that report 0.
    cap = cv2.VideoCapture(in_path)
    if not cap.isOpened():
        raise RuntimeError("Cannot open uploaded video.")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    cap.release()

    # 1) Extract PNG frames (lossless).
    subprocess.run([
        "ffmpeg", "-y",
        "-i", in_path,
        os.path.join(FRAMES_DIR, "frame_%06d.png")
    ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    frame_files = sorted(os.listdir(FRAMES_DIR))
    if len(frame_files) == 0:
        raise RuntimeError("No frames extracted.")

    preview_images = []
    total = len(frame_files)
    sample_step = max(1, total // 20)  # cap the gallery at ~20 images

    # 2) Per-frame inference.
    for i, fname in enumerate(frame_files):
        fp = os.path.join(FRAMES_DIR, fname)
        bgr = cv2.imread(fp, cv2.IMREAD_COLOR)
        if bgr is None:
            # Guard against corrupt/unreadable frames — without this,
            # cvtColor fails with an opaque OpenCV error.
            raise RuntimeError(f"Failed to read extracted frame {fp}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        depth = predict_depth(rgb)

        # Save raw 16-bit PNG.
        # NOTE(review): astype(np.uint16) truncates fractional depth —
        # confirm the model's depth scale suits integer uint16 storage.
        raw16 = depth.astype(np.uint16)
        Image.fromarray(raw16).save(os.path.join(RAW_FRAMES_DIR, fname))

        # Save normalized grayscale preview.
        gray8 = depth_to_gray8(depth)
        Image.fromarray(gray8).save(os.path.join(OUT_FRAMES_DIR, fname))

        if i % sample_step == 0:
            preview_images.append(Image.fromarray(gray8))

    # 3) Merge previews into an MP4 (libx264 + yuv420p for broad playback).
    # splitext is robust to any extension casing (e.g. ".MP4").
    base = os.path.splitext(os.path.basename(src_path))[0]
    out_video = os.path.join(OUTPUT_DIR, base + "_depth.mp4")

    subprocess.run([
        "ffmpeg", "-y",
        "-framerate", str(fps),
        "-i", os.path.join(OUT_FRAMES_DIR, "frame_%06d.png"),
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        out_video
    ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    return preview_images, out_video
128
 
129
  # -------------------
130
+ # UI
131
  # -------------------
132
with gr.Blocks() as demo:
    # Page header and a short description of the frame-wise pipeline.
    gr.Markdown("# Depth Anything V2 High-Quality Video Depth (Frame-wise)")
    gr.Markdown(
        "This reproduces the **exact image quality** of the official Depth Anything V2 app.py, "
        "but applied **frame-by-frame** to video using ffmpeg for perfect sharpness."
    )

    # Inputs/outputs: one MP4 upload in; a sampled preview gallery and the
    # rendered depth video out (matching process_video's return tuple).
    video_in = gr.File(label="Upload a video (mp4)", file_types=[".mp4"])
    gallery = gr.Gallery(
        label="Preview Depth Frames",
        columns=5,
        height="auto"
    )
    out_video = gr.Video(label="Depthmap Video Output")

    # Single action: run the full extract → infer → merge pipeline.
    btn = gr.Button("Render High-Quality Depth Video")
    btn.click(process_video, inputs=[video_in], outputs=[gallery, out_video])
149
 
150
  if __name__ == "__main__":