mmmno committed on
Commit
63c21cb
·
verified ·
1 Parent(s): c927e0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -9,7 +9,6 @@ import os
9
 
10
  # --- 1. SETTINGS & MODEL ---
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
- # Using Depth Anything V2 for maximum compatibility
13
  CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"
14
 
15
  processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
@@ -23,7 +22,6 @@ def process_to_3d(input_image):
23
  inputs = processor(images=input_image, return_tensors="pt").to(DEVICE)
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
- # Resize depth map to match original image resolution
27
  depth = torch.nn.functional.interpolate(
28
  outputs.predicted_depth.unsqueeze(1),
29
  size=input_image.size[::-1],
@@ -35,10 +33,10 @@ def process_to_3d(input_image):
35
  rgb = np.array(input_image)
36
  x, y = np.meshgrid(np.arange(width), np.arange(height))
37
 
38
- # Scale depth to a standard 3D unit range
39
  z = (depth / depth.max()) * 10.0
40
 
41
- # Projection math (pinhole camera model)
42
  focal_length = width
43
  x_coords = (x - width / 2) * z / focal_length
44
  y_coords = (y - height / 2) * z / focal_length
@@ -46,50 +44,47 @@ def process_to_3d(input_image):
46
  points = np.stack((x_coords, y_coords, z), axis=-1).reshape(-1, 3)
47
  colors = rgb.reshape(-1, 3) / 255.0
48
 
49
- # --- 4. THE SPLAT TRICK (Open3D) ---
50
  pcd = o3d.geometry.PointCloud()
51
  pcd.points = o3d.utility.Vector3dVector(points)
52
  pcd.colors = o3d.utility.Vector3dVector(colors)
53
 
54
- # Centering: Move the model so its 3D center is at (0, 0, 0)
55
- # This ensures the camera rotates around the object, not the corner.
56
  center = pcd.get_center()
57
  pcd.translate(-center)
58
 
59
- # Voxelization: This merges tiny points into larger "Splats"
60
- # Adjust voxel_size to make the model more or less "dense"
61
- pcd = pcd.voxel_down_sample(voxel_size=0.05)
62
 
63
- # --- 5. EXPORT ---
64
  temp_dir = tempfile.gettempdir()
65
- # Saving as .ply (Gradio 5+ renders binary PLY as splats in Solid mode)
66
- output_path = os.path.join(temp_dir, "model_output.ply")
 
67
  o3d.io.write_point_cloud(output_path, pcd, write_ascii=False)
68
 
69
  return output_path, output_path
70
 
71
  # --- 6. GRADIO UI ---
72
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
73
- gr.Markdown("# 🌌 3D Gaussian Splat Generator")
74
- gr.Markdown("Transform any 2D image into a centered, solid-looking 3D Splat.")
75
 
76
  with gr.Row():
77
  with gr.Column(scale=1):
78
  img_input = gr.Image(type="pil", label="Input Image")
79
- run_btn = gr.Button("🔨 Build 3D Splat", variant="primary")
80
 
81
  with gr.Column(scale=2):
82
- # display_mode="solid" tells Gradio to render the points as Gaussians
83
- # camera_position=(alpha, beta, radius)
84
  view_3d = gr.Model3D(
85
  label="3D Viewport",
86
- display_mode="solid",
87
  camera_position=(0, 90, 15),
88
  clear_color=(0.0, 0.0, 0.0, 1.0)
89
  )
90
- dl_btn = gr.DownloadButton("💾 Download Model (.PLY)")
 
91
 
92
- # Define behavior
93
  run_btn.click(
94
  fn=process_to_3d,
95
  inputs=[img_input],
 
9
 
10
  # --- 1. SETTINGS & MODEL ---
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
12
  CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"
13
 
14
  processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
 
22
  inputs = processor(images=input_image, return_tensors="pt").to(DEVICE)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
 
25
  depth = torch.nn.functional.interpolate(
26
  outputs.predicted_depth.unsqueeze(1),
27
  size=input_image.size[::-1],
 
33
  rgb = np.array(input_image)
34
  x, y = np.meshgrid(np.arange(width), np.arange(height))
35
 
36
+ # Scale depth (Z-axis) for a clean 3D range
37
  z = (depth / depth.max()) * 10.0
38
 
39
+ # Projection math
40
  focal_length = width
41
  x_coords = (x - width / 2) * z / focal_length
42
  y_coords = (y - height / 2) * z / focal_length
 
44
  points = np.stack((x_coords, y_coords, z), axis=-1).reshape(-1, 3)
45
  colors = rgb.reshape(-1, 3) / 255.0
46
 
47
+ # --- 4. CENTERING & VOXELIZATION ---
48
  pcd = o3d.geometry.PointCloud()
49
  pcd.points = o3d.utility.Vector3dVector(points)
50
  pcd.colors = o3d.utility.Vector3dVector(colors)
51
 
52
+ # Centering: Critical for the camera to lock onto the model
 
53
  center = pcd.get_center()
54
  pcd.translate(-center)
55
 
56
+ # Voxelization: Merges points into larger "splats" for solid visibility
57
+ pcd = pcd.voxel_down_sample(voxel_size=0.04)
 
58
 
59
+ # --- 5. EXPORT AS .PLY ---
60
  temp_dir = tempfile.gettempdir()
61
+ output_path = os.path.join(temp_dir, "model.ply")
62
+
63
+ # write_ascii=False saves it in Binary format (required for fast web loading)
64
  o3d.io.write_point_cloud(output_path, pcd, write_ascii=False)
65
 
66
  return output_path, output_path
67
 
68
  # --- 6. GRADIO UI ---
69
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
70
+ gr.Markdown("# 🌊 Depth Anything Splat Creator")
 
71
 
72
  with gr.Row():
73
  with gr.Column(scale=1):
74
  img_input = gr.Image(type="pil", label="Input Image")
75
+ run_btn = gr.Button("🔨 Generate .PLY Splat", variant="primary")
76
 
77
  with gr.Column(scale=2):
 
 
78
  view_3d = gr.Model3D(
79
  label="3D Viewport",
80
+ display_mode="solid", # Renders PLY points as Gaussians
81
  camera_position=(0, 90, 15),
82
  clear_color=(0.0, 0.0, 0.0, 1.0)
83
  )
84
+ # Explicitly set the download button
85
+ dl_btn = gr.DownloadButton("💾 Download .PLY File")
86
 
87
+ # Link the logic
88
  run_btn.click(
89
  fn=process_to_3d,
90
  inputs=[img_input],