mmmno committed on
Commit
c927e0a
·
verified ·
1 Parent(s): 6864135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -36
app.py CHANGED
@@ -7,84 +7,94 @@ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
7
  import tempfile
8
  import os
9
 
10
# --- MODEL SETUP ---
# Hugging Face checkpoint shared by the image processor and the depth model.
CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"

# Run on the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Both objects are loaded once at import time and reused for every request.
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = AutoModelForDepthEstimation.from_pretrained(CHECKPOINT)
model = model.to(DEVICE)
16
 
17
def create_point_cloud(input_image):
    """Build a centered, colored point cloud from one image and save it as OBJ.

    Args:
        input_image: PIL image coming from the Gradio image input, or None
            when nothing has been uploaded yet.

    Returns:
        A pair ``(path, path)`` with the same OBJ file path twice (one entry
        for the 3D viewer, one for the download button), or ``(None, None)``
        when ``input_image`` is None.
    """
    if input_image is None:
        return None, None

    # Force three channels; RGBA / grayscale / palette images would otherwise
    # break the reshape(-1, 3) on the color array below.
    input_image = input_image.convert("RGB")

    # 1. Generate depth
    inputs = processor(images=input_image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height) while interpolate expects (height, width).
    # Indexing [0, 0] drops only the batch/channel axes, so images that are a
    # single pixel wide or tall keep their 2-D shape (squeeze() would not).
    depth = torch.nn.functional.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=input_image.size[::-1],
        mode="bicubic",
    )[0, 0].cpu().numpy()

    # 2. Projection
    width, height = input_image.size
    rgb = np.array(input_image)
    x, y = np.meshgrid(np.arange(width), np.arange(height))

    # Scale depth so the largest value maps to 150 units; guard against a
    # depth map whose maximum is not positive to avoid dividing by zero.
    # NOTE(review): the model output is used directly as z here - confirm
    # whether larger predicted values mean nearer or farther for this model.
    depth_max = float(depth.max())
    if depth_max > 0:
        z = depth / depth_max * 150.0
    else:
        z = np.zeros_like(depth)

    # Pinhole-style back-projection with the focal length set to the image
    # width; subtracting half the size puts the image center on the z axis.
    focal_length = width
    x_centered = (x - width / 2) * z / focal_length
    y_centered = (y - height / 2) * z / focal_length

    points = np.stack((x_centered, y_centered, z), axis=-1).reshape(-1, 3)
    colors = rgb.reshape(-1, 3) / 255.0

    # 3. Open3D point cloud processing
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    # Shift the cloud so its center sits at the origin (this also removes
    # the z offset).
    pcd.translate(-pcd.get_center())

    # Thin the cloud by collapsing the points inside each 0.4-unit voxel.
    pcd = pcd.voxel_down_sample(voxel_size=0.4)

    # 4. Save to OBJ
    # Open3D's point-cloud writer does not list .obj among its formats, so
    # the vertices are written directly as "v x y z r g b" records. A unique
    # file per call keeps concurrent requests from overwriting each other.
    vertices = np.hstack((np.asarray(pcd.points), np.asarray(pcd.colors)))
    with tempfile.NamedTemporaryFile(
        suffix=".obj", prefix="centered_model_", delete=False
    ) as tmp:
        output_path = tmp.name
    np.savetxt(output_path, vertices, fmt="v %.6f %.6f %.6f %.6f %.6f %.6f")

    return output_path, output_path
69
 
70
# --- GRADIO UI ---
# Two-column layout: image upload and trigger on the left, 3D preview and
# download on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🧊 Auto-Centered 3D Point Cloud")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil")
            run_btn = gr.Button("Generate 3D OBJ", variant="primary")

        with gr.Column():
            # The third camera value (200) is the starting orbit radius.
            view_3d = gr.Model3D(
                clear_color=(0.1, 0.1, 0.1, 1.0),
                camera_position=(0, 90, 200),
                label="3D Preview",
            )
            dl_file = gr.DownloadButton("Download .OBJ File")

    # The handler returns the same file path for the viewer and the button.
    run_btn.click(create_point_cloud, [img_input], [view_3d, dl_file])

demo.launch()
 
 
7
  import tempfile
8
  import os
9
 
10
# --- 1. SETTINGS & MODEL ---
# Depth Anything V2 (small) checkpoint, used for both the image processor
# and the depth-estimation model.
CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"

# Prefer CUDA when torch reports it as available; fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Loaded once at import time and shared by every request.
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = AutoModelForDepthEstimation.from_pretrained(CHECKPOINT)
model = model.to(DEVICE)
17
 
18
def process_to_3d(input_image):
    """Turn one image into a centered, colored point cloud saved as binary PLY.

    Args:
        input_image: PIL image coming from the Gradio image input, or None
            when nothing has been uploaded yet.

    Returns:
        A pair ``(path, path)`` with the same PLY file path twice (one entry
        for the 3D viewer, one for the download button), or ``(None, None)``
        when ``input_image`` is None.

    Raises:
        RuntimeError: If Open3D reports that the PLY file was not written.
    """
    if input_image is None:
        return None, None

    # Force three channels; RGBA / grayscale / palette images would otherwise
    # break the reshape(-1, 3) on the color array below.
    input_image = input_image.convert("RGB")

    # --- 2. DEPTH ESTIMATION ---
    inputs = processor(images=input_image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Resize the depth map to the original resolution. PIL reports
    # (width, height) while interpolate expects (height, width). Indexing
    # [0, 0] drops only the batch/channel axes, so images that are a single
    # pixel wide or tall keep their 2-D shape (squeeze() would not).
    depth = torch.nn.functional.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=input_image.size[::-1],
        mode="bicubic",
    )[0, 0].cpu().numpy()

    # --- 3. POINT CLOUD PROJECTION ---
    width, height = input_image.size
    rgb = np.array(input_image)
    x, y = np.meshgrid(np.arange(width), np.arange(height))

    # Scale depth so the largest value maps to 10 units; guard against a
    # depth map whose maximum is not positive to avoid dividing by zero.
    # NOTE(review): the model output is used directly as z here - confirm
    # whether larger predicted values mean nearer or farther for this model.
    depth_max = float(depth.max())
    if depth_max > 0:
        z = (depth / depth_max) * 10.0
    else:
        z = np.zeros_like(depth)

    # Pinhole-style back-projection with the focal length set to the image
    # width; subtracting half the size puts the image center on the z axis.
    focal_length = width
    x_coords = (x - width / 2) * z / focal_length
    y_coords = (y - height / 2) * z / focal_length

    points = np.stack((x_coords, y_coords, z), axis=-1).reshape(-1, 3)
    colors = rgb.reshape(-1, 3) / 255.0

    # --- 4. POINT CLOUD (Open3D) ---
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    # Centering: move the cloud so its center is at (0, 0, 0), which lets the
    # viewer camera orbit the object instead of a corner.
    pcd.translate(-pcd.get_center())

    # Voxelization: collapse the points inside each 0.05-unit voxel.
    # A larger voxel_size gives a sparser cloud, a smaller one a denser cloud.
    pcd = pcd.voxel_down_sample(voxel_size=0.05)

    # --- 5. EXPORT ---
    # A unique file per call keeps concurrent requests from overwriting each
    # other's output (a fixed name in the shared temp dir would collide).
    with tempfile.NamedTemporaryFile(
        suffix=".ply", prefix="model_output_", delete=False
    ) as tmp:
        output_path = tmp.name

    # Binary PLY output.
    # NOTE(review): how the Gradio viewer renders a plain point-cloud PLY
    # should be confirmed against the Model3D documentation.
    if not o3d.io.write_point_cloud(output_path, pcd, write_ascii=False):
        raise RuntimeError(f"Failed to write point cloud to {output_path}")

    return output_path, output_path
70
 
71
# --- 6. GRADIO UI ---
# Two-column layout: narrow input column on the left, wider 3D viewport and
# download button on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌌 3D Gaussian Splat Generator")
    gr.Markdown("Transform any 2D image into a centered, solid-looking 3D Splat.")

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(label="Input Image", type="pil")
            run_btn = gr.Button("🔨 Build 3D Splat", variant="primary")

        with gr.Column(scale=2):
            # camera_position is given as (alpha, beta, radius).
            # NOTE(review): confirm what display_mode="solid" does for PLY
            # point clouds in the installed Gradio version.
            view_3d = gr.Model3D(
                clear_color=(0.0, 0.0, 0.0, 1.0),
                camera_position=(0, 90, 15),
                display_mode="solid",
                label="3D Viewport",
            )
            dl_btn = gr.DownloadButton("💾 Download Model (.PLY)")

    # Wire the button: the handler returns the same file path for the viewer
    # and for the download button.
    run_btn.click(process_to_3d, [img_input], [view_3d, dl_btn])

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()