niye4 committed on
Commit
fdc64d6
·
verified ·
1 Parent(s): 73ad184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -46
app.py CHANGED
import os
import cv2
import torch
import tempfile
import shutil
import subprocess
from PIL import Image
import gradio as gr
from gradio_imageslider import ImageSlider
from depth_anything_v2.dpt import DepthAnythingV2

# Run on GPU when one is present, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model config (vitb local)
encoder = 'vitb'
model_configs = {
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
}

# Build the network, load the local checkpoint on CPU first,
# then move the fully-initialized model to the target device.
model = DepthAnythingV2(**model_configs[encoder])
checkpoint_path = f"checkpoints/depth_anything_v2_{encoder}.pth"
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
 
 
 
 
25
def predict_depth(frame_rgb):
    """Infer a raw (unnormalized) depth map for one RGB frame.

    Thin wrapper over the module-level model's ``infer_image``.
    """
    depth_map = model.infer_image(frame_rgb)
    return depth_map
 
 
 
 
def process_video(video_file):
    """Render a grayscale depth-map MP4 for an uploaded video.

    Decodes every frame, runs depth inference, writes normalized
    grayscale PNGs to a scratch directory, then encodes them with
    ffmpeg at the source FPS.

    Args:
        video_file: gradio File payload; its ``.name`` is the source path.

    Returns:
        (slider_frames, output_video): up to ~30 evenly spaced PIL
        preview images, and the path of the encoded depth-map MP4.

    Raises:
        RuntimeError: if the video cannot be opened or has no frames.
        subprocess.CalledProcessError: if the ffmpeg encode fails.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        cap = cv2.VideoCapture(video_file.name)
        if not cap.isOpened() or int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) == 0:
            raise RuntimeError("Cannot open video or empty video file.")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps != fps:  # some containers report 0 or NaN FPS
            fps = 25.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        slider_frames = []
        max_slider = 30
        # Sample roughly max_slider evenly spaced frames for the preview.
        step = max(1, total_frames // max_slider)

        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            depth_map = predict_depth(frame_rgb)

            # Normalize to 0-255 grayscale; a constant depth map would
            # divide by zero, so render it as black instead.
            d_min = depth_map.min()
            d_range = depth_map.max() - d_min
            if d_range == 0:
                depth_gray = (depth_map * 0).astype('uint8')
            else:
                depth_gray = ((depth_map - d_min) / d_range * 255.0).astype('uint8')

            img = Image.fromarray(depth_gray)
            img.save(os.path.join(temp_dir, f"{frame_idx:05d}.png"))

            # Slider preview
            if frame_idx % step == 0:
                slider_frames.append(img)
            frame_idx += 1

        cap.release()

        # Output MP4 path
        output_dir = "output"
        os.makedirs(output_dir, exist_ok=True)
        output_video = os.path.join(
            output_dir,
            os.path.basename(video_file.name).replace(".mp4", "_depth.mp4"),
        )

        # FFmpeg encode PNG sequence -> MP4, keep FPS & resolution.
        cmd = [
            "ffmpeg",
            "-y",
            "-framerate", str(fps),
            "-i", os.path.join(temp_dir, "%05d.png"),
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            output_video,
        ]
        subprocess.run(cmd, check=True)
    finally:
        # Remove the PNG scratch dir even when decode/inference/encode fails;
        # the original only cleaned up on the success path.
        shutil.rmtree(temp_dir, ignore_errors=True)
    return slider_frames, output_video
85
 
 
 
 
86
# Gradio UI: upload an MP4, preview sampled depth frames, download the video.
with gr.Blocks() as demo:
    gr.Markdown("# Depth Anything V2 – Grayscale Video")
    gr.Markdown("Upload a video and get a grayscale DepthMap video at original resolution & FPS.")
    # Input: raw MP4 upload. Outputs: preview slider + rendered depth video.
    video_input = gr.File(label="Upload MP4", file_types=[".mp4"])
    depth_slider = ImageSlider(label="DepthMap Slider Preview")
    video_output = gr.Video(label="DepthMap Video")
    submit = gr.Button("Render DepthMap")
    # NOTE(review): ImageSlider conventionally displays a pair of images;
    # process_video returns a list of up to ~30 frames — confirm this renders.
    submit.click(fn=process_video, inputs=[video_input], outputs=[depth_slider, video_output])
94
 
95
  if __name__ == "__main__":
 
import os
import shutil

import cv2
import numpy as np
import torch
from PIL import Image

import gradio as gr
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download

from depth_anything_v2.dpt import DepthAnythingV2

# ===============================
# Auto-download checkpoint if missing
# ===============================
MODEL_PATH = "checkpoints/depth_anything_v2_vitl.pth"

if os.path.exists(MODEL_PATH):
    print("Model already exists, starting the app immediately!")
else:
    print("Downloading Depth Anything V2 model (~1.3GB), please wait 1-3 minutes...")
    hf_hub_download(
        repo_id="niye4/depthmap-checkpoints",   # repo holding the checkpoint
        filename="depth_anything_v2_vitl.pth",  # actual filename in the repo
        local_dir="checkpoints",
        # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
        # huggingface_hub releases — confirm against the pinned hub version.
        local_dir_use_symlinks=False,
    )
    print("Model download complete! Starting the app...")

# ===============================
# Device and Model Setup
# ===============================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# All known encoder variants; only 'vitl' is selected below.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}

encoder = 'vitl'

# Load weights on CPU first, then move the model to the target device.
model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
 
47
# ===============================
# Depth prediction for one frame
# ===============================
def predict_depth(frame_rgb):
    """Infer a raw (unnormalized) depth map for one RGB frame.

    Thin wrapper over the module-level model's ``infer_image``.
    """
    depth_map = model.infer_image(frame_rgb)
    return depth_map
52
 
53
# ===============================
# Process video
# ===============================
def _normalize_depth(depth_map):
    """Scale a raw depth map to an 8-bit grayscale image.

    Guards against a constant depth map (max == min), which would
    otherwise divide by zero; such frames are rendered as black.
    """
    d_min = depth_map.min()
    d_range = depth_map.max() - d_min
    if d_range == 0:
        return np.zeros_like(depth_map, dtype=np.uint8)
    return ((depth_map - d_min) / d_range * 255.0).astype(np.uint8)


def process_video(video_file):
    """Render a grayscale depth-map MP4 for an uploaded video.

    Args:
        video_file: gradio File payload — either an object exposing a
            ``.name`` path or (on newer gradio versions) a plain path
            string; both are accepted.

    Returns:
        (slider_frames, output_video): up to ~30 evenly spaced PIL
        preview images, and the path of the encoded depth-map MP4.

    Raises:
        RuntimeError: if the video cannot be opened or has no frames.
    """
    OUTPUT_DIR = "output"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Newer gradio passes the upload as a bare path string; older
    # versions pass a tempfile-like object with .name.
    src_path = getattr(video_file, "name", video_file)
    video_path = os.path.join(OUTPUT_DIR, os.path.basename(src_path))
    shutil.copy(src_path, video_path)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened() or int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) == 0:
        raise RuntimeError("Cannot open video or empty video file.")

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps != fps:  # some containers report 0 or NaN FPS
        fps = 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_video = os.path.join(
        OUTPUT_DIR, os.path.basename(video_path).replace(".mp4", "_depth.mp4")
    )
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video, fourcc, fps, (width, height), isColor=True)

    # Prepare slider preview: sample ~30 evenly spaced frames.
    slider_frames = []
    max_slider_frames = 30
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(1, total_frames // max_slider_frames)
    idx = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            depth_map = predict_depth(frame_rgb)
            depth_gray = _normalize_depth(depth_map)
            out.write(cv2.cvtColor(depth_gray, cv2.COLOR_GRAY2BGR))

            # Add frame to slider preview
            if idx % step == 0:
                slider_frames.append(Image.fromarray(depth_gray))
            idx += 1
    finally:
        # Always release codec/file handles, even if inference fails
        # mid-video; the original leaked both on any exception.
        cap.release()
        out.release()

    return slider_frames, output_video
99
 
100
# ===============================
# Gradio Interface
# ===============================
# Upload an MP4, preview sampled depth frames, download the rendered video.
with gr.Blocks() as demo:
    gr.Markdown("# Depth Anything V2 – Grayscale Video (vitl)")
    gr.Markdown("Upload a video and get a grayscale DepthMap video at original resolution and FPS.")

    # Input: raw MP4 upload. Outputs: preview slider + rendered depth video.
    video_input = gr.File(label="Upload MP4", file_types=['.mp4'])
    depth_slider = ImageSlider(label="DepthMap Slider Preview")
    video_output = gr.Video(label="DepthMap Video")
    submit = gr.Button("Render DepthMap")

    # NOTE(review): ImageSlider conventionally displays a pair of images;
    # process_video returns a list of up to ~30 frames — confirm this renders.
    # NOTE(review): cv2 'mp4v' output may not play inline in browsers that
    # require H.264 — verify gr.Video playback in the target deployment.
    submit.click(fn=process_video, inputs=[video_input], outputs=[depth_slider, video_output])
113
 
114
  if __name__ == "__main__":