mmmno commited on
Commit
8718a23
·
verified ·
1 Parent(s): 63c21cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -44
app.py CHANGED
@@ -7,7 +7,7 @@ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
7
  import tempfile
8
  import os
9
 
10
- # --- 1. SETTINGS & MODEL ---
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
  CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"
13
 
@@ -18,7 +18,11 @@ def process_to_3d(input_image):
18
  if input_image is None:
19
  return None, None
20
 
21
- # --- 2. DEPTH ESTIMATION ---
 
 
 
 
22
  inputs = processor(images=input_image, return_tensors="pt").to(DEVICE)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
@@ -28,68 +32,61 @@ def process_to_3d(input_image):
28
  mode="bicubic",
29
  ).squeeze().cpu().numpy()
30
 
31
- # --- 3. POINT CLOUD PROJECTION ---
32
  width, height = input_image.size
33
- rgb = np.array(input_image)
34
- x, y = np.meshgrid(np.arange(width), np.arange(height))
35
-
36
- # Scale depth (Z-axis) for a clean 3D range
37
- z = (depth / depth.max()) * 10.0
38
 
39
- # Projection math
40
- focal_length = width
41
- x_coords = (x - width / 2) * z / focal_length
42
- y_coords = (y - height / 2) * z / focal_length
43
 
44
- points = np.stack((x_coords, y_coords, z), axis=-1).reshape(-1, 3)
45
- colors = rgb.reshape(-1, 3) / 255.0
 
 
 
 
 
46
 
47
- # --- 4. CENTERING & VOXELIZATION ---
48
  pcd = o3d.geometry.PointCloud()
49
  pcd.points = o3d.utility.Vector3dVector(points)
50
- pcd.colors = o3d.utility.Vector3dVector(colors)
51
 
52
- # Centering: Critical for the camera to lock onto the model
53
- center = pcd.get_center()
54
- pcd.translate(-center)
55
 
56
- # Voxelization: Merges points into larger "splats" for solid visibility
57
- pcd = pcd.voxel_down_sample(voxel_size=0.04)
58
 
59
- # --- 5. EXPORT AS .PLY ---
60
  temp_dir = tempfile.gettempdir()
61
  output_path = os.path.join(temp_dir, "model.ply")
62
 
63
- # write_ascii=False saves it in Binary format (required for fast web loading)
64
  o3d.io.write_point_cloud(output_path, pcd, write_ascii=False)
65
 
66
  return output_path, output_path
67
 
68
- # --- 6. GRADIO UI ---
69
- with gr.Blocks(theme=gr.themes.Default()) as demo:
70
- gr.Markdown("# 🌊 Depth Anything Splat Creator")
71
 
72
  with gr.Row():
73
- with gr.Column(scale=1):
74
- img_input = gr.Image(type="pil", label="Input Image")
75
- run_btn = gr.Button("🔨 Generate .PLY Splat", variant="primary")
76
 
77
- with gr.Column(scale=2):
78
- view_3d = gr.Model3D(
 
79
  label="3D Viewport",
80
- display_mode="solid", # Renders PLY points as Gaussians
81
- camera_position=(0, 90, 15),
82
- clear_color=(0.0, 0.0, 0.0, 1.0)
83
  )
84
- # Explicitly set the download button
85
- dl_btn = gr.DownloadButton("💾 Download .PLY File")
86
 
87
- # Link the logic
88
- run_btn.click(
89
- fn=process_to_3d,
90
- inputs=[img_input],
91
- outputs=[view_3d, dl_btn]
92
- )
93
 
94
- if __name__ == "__main__":
95
- demo.launch()
 
7
  import tempfile
8
  import os
9
 
10
+ # --- 1. SETTINGS ---
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
  CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"
13
 
 
18
  if input_image is None:
19
  return None, None
20
 
21
+ # Resize image to a manageable size for 3D viewing if too large
22
+ if max(input_image.size) > 1024:
23
+ input_image.thumbnail((1024, 1024))
24
+
25
+ # --- 2. DEPTH INFERENCE ---
26
  inputs = processor(images=input_image, return_tensors="pt").to(DEVICE)
27
  with torch.no_grad():
28
  outputs = model(**inputs)
 
32
  mode="bicubic",
33
  ).squeeze().cpu().numpy()
34
 
35
+ # --- 3. COLOR & COORDINATE CALCULATION ---
36
  width, height = input_image.size
37
+ rgb = np.array(input_image).reshape(-1, 3) / 255.0 # Normalize to 0-1 for O3D
 
 
 
 
38
 
39
+ # Create normalized grid
40
+ x, y = np.meshgrid(np.arange(width), np.arange(height))
 
 
41
 
42
+ # Flatten and project to 3D
43
+ # Scale depth (z) significantly down so it doesn't "stretch" too far back
44
+ z = (depth.flatten() / depth.max()) * 5.0
45
+ x = (x.flatten() - width / 2) / (width / 5.0)
46
+ y = (height / 2 - y.flatten()) / (height / 5.0) # Invert Y for correct orientation
47
+
48
+ points = np.stack((x, y, z), axis=-1)
49
 
50
+ # --- 4. THE SPLAT TRICK (OPEN3D) ---
51
  pcd = o3d.geometry.PointCloud()
52
  pcd.points = o3d.utility.Vector3dVector(points)
53
+ pcd.colors = o3d.utility.Vector3dVector(rgb)
54
 
55
+ # RE-CENTER: This is the fix for the "Blank Viewer"
56
+ # It ensures the model is exactly at 0,0,0
57
+ pcd.translate(-pcd.get_center())
58
 
59
+ # DENSITY: Downsample to make points "thicker" and load faster
60
+ pcd = pcd.voxel_down_sample(voxel_size=0.02)
61
 
62
+ # --- 5. EXPORT ---
63
  temp_dir = tempfile.gettempdir()
64
  output_path = os.path.join(temp_dir, "model.ply")
65
 
66
+ # write_ascii=False is required for Binary PLY (Colors work best here)
67
  o3d.io.write_point_cloud(output_path, pcd, write_ascii=False)
68
 
69
  return output_path, output_path
70
 
71
+ # --- 6. UI ---
72
+ with gr.Blocks() as demo:
73
+ gr.Markdown("## 🪐 3D Splat View (Color-Matched)")
74
 
75
  with gr.Row():
76
+ with gr.Column():
77
+ img_in = gr.Image(type="pil", label="Upload Photo")
78
+ btn = gr.Button("🔨 Generate 3D", variant="primary")
79
 
80
+ with gr.Column():
81
+ # radius=10 starts the camera at the perfect zoom level
82
+ v3d = gr.Model3D(
83
  label="3D Viewport",
84
+ display_mode="solid",
85
+ camera_position=(0, 90, 10),
86
+ clear_color=(0.08, 0.08, 0.08, 1.0)
87
  )
88
+ dl = gr.DownloadButton("💾 Download .PLY")
 
89
 
90
+ btn.click(fn=process_to_3d, inputs=[img_in], outputs=[v3d, dl])
 
 
 
 
 
91
 
92
+ demo.launch()