samrobertsondev committed on
Commit
8ffa450
·
verified ·
1 Parent(s): 82542be

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -47
app.py CHANGED
@@ -39,65 +39,100 @@ def run_moge_on_image(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
39
  image: HxWx3 RGB uint8 numpy array.
40
 
41
  Returns:
42
- points: (N, 3) float32 XYZ in some model-defined coordinates
43
- colors: (N, 3) uint8 RGB in [0, 255]
44
  """
45
 
46
  # Convert to float tensor [0, 1], CHW, batch
47
  img = image.astype(np.float32) / 255.0
48
  tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(DEVICE) # (1,3,H,W)
49
 
50
- # NOTE: This assumes MoGeModel.infer returns something like:
51
- # {"points": (1, N, 3), "colors": (1, N, 3)} or similar.
52
- # You may need to adapt this part to the actual MoGe API.
53
  out = MODEL.infer(tensor)
54
 
55
- # ----- Adapt this based on the actual return structure -----
56
- #
57
- # Common patterns:
58
- # out["points"]: (B, N, 3) point coordinates
59
- # out["colors"]: (B, N, 3) colors in [0, 1] or [0, 255]
60
- #
61
- # If your actual keys or shapes are different, adjust here.
62
-
63
- if "points" in out:
64
- points = out["points"]
65
- elif "point_cloud" in out:
66
- points = out["point_cloud"]
67
- else:
68
- raise RuntimeError(f"Cannot find point cloud in MoGe output keys: {list(out.keys())}")
69
-
70
- # remove batch dim
71
- if points.ndim == 3:
72
- points = points[0]
73
-
74
- points = points.detach().cpu().float().numpy() # (N,3)
75
-
76
- # Try to get colors if available, else default to white
77
- colors_raw = None
78
- for k in ["colors", "rgb", "point_colors"]:
79
- if k in out:
80
- colors_raw = out[k]
81
- break
82
-
83
- if colors_raw is not None:
84
- if colors_raw.ndim == 3:
85
- colors_raw = colors_raw[0]
86
- colors_np = colors_raw.detach().cpu().float().numpy()
87
- # Normalize to [0,255] if necessary
88
- if colors_np.max() <= 1.0:
89
- colors_np = (colors_np * 255.0).clip(0, 255)
90
- colors = colors_np.astype(np.uint8)
91
- else:
92
- # fallback: all white
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  colors = np.full_like(points, 255, dtype=np.uint8)
94
 
95
- # Ensure shapes
96
- assert points.shape[-1] == 3, f"Expected points (N,3), got {points.shape}"
97
- assert colors.shape[-1] == 3, f"Expected colors (N,3), got {colors.shape}"
 
98
 
99
- return points, colors
 
 
 
 
 
100
 
 
 
 
 
 
101
 
102
  # ---------- Helper: write PLY into memory ----------
103
 
 
39
  image: HxWx3 RGB uint8 numpy array.
40
 
41
  Returns:
42
+ points: (N, 3) float32 XYZ
43
+ colors: (N, 3) uint8 RGB
44
  """
45
 
46
  # Convert to float tensor [0, 1], CHW, batch
47
  img = image.astype(np.float32) / 255.0
48
  tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(DEVICE) # (1,3,H,W)
49
 
50
+ # --- Run MoGe ---
 
 
51
  out = MODEL.infer(tensor)
52
 
53
+ # --- DEBUG: log what MoGe actually returned ---
54
+ print("MoGe output keys:", list(out.keys()))
55
+ shaped = {}
56
+ for k, v in out.items():
57
+ if torch.is_tensor(v):
58
+ shaped[k] = (v.shape, v.dtype, float(v.min()), float(v.max()))
59
+ else:
60
+ shaped[k] = type(v).__name__
61
+ print("MoGe output summary:", shaped)
62
+
63
+ # --- Try several common patterns ---
64
+
65
+ points = None
66
+ colors = None
67
+
68
+ # 1) Single tensor with xyzrgb in last dim: (B, N, 6)
69
+ if "pcd" in out:
70
+ pcd = out["pcd"]
71
+ if pcd.ndim == 3 and pcd.shape[-1] >= 3:
72
+ # remove batch
73
+ if pcd.shape[0] == 1:
74
+ pcd = pcd[0]
75
+ pcd_np = pcd.detach().cpu().float().numpy() # (N, C)
76
+ points = pcd_np[:, :3]
77
+ if pcd_np.shape[1] >= 6:
78
+ cols = pcd_np[:, 3:6]
79
+ if cols.max() <= 1.0:
80
+ cols = (cols * 255.0).clip(0, 255)
81
+ colors = cols.astype(np.uint8)
82
+
83
+ # 2) Separate "points" and "colors"/"rgb"
84
+ if points is None:
85
+ if "points" in out:
86
+ pts = out["points"]
87
+ elif "point_cloud" in out:
88
+ pts = out["point_cloud"]
89
+ else:
90
+ pts = None
91
+
92
+ if pts is not None:
93
+ if pts.ndim == 3 and pts.shape[0] == 1:
94
+ pts = pts[0]
95
+ pts_np = pts.detach().cpu().float().numpy()
96
+ if pts_np.shape[-1] != 3:
97
+ raise RuntimeError(f"Expected points last dim=3, got {pts_np.shape}")
98
+ points = pts_np
99
+
100
+ # colors
101
+ col_tensor = None
102
+ for k in ["colors", "rgb", "point_colors"]:
103
+ if k in out:
104
+ col_tensor = out[k]
105
+ break
106
+
107
+ if col_tensor is not None:
108
+ if col_tensor.ndim == 3 and col_tensor.shape[0] == 1:
109
+ col_tensor = col_tensor[0]
110
+ col_np = col_tensor.detach().cpu().float().numpy()
111
+ if col_np.max() <= 1.0:
112
+ col_np = (col_np * 255.0).clip(0, 255)
113
+ colors = col_np.astype(np.uint8)
114
+
115
+ # 3) If still no colors, default to white
116
+ if points is not None and colors is None:
117
  colors = np.full_like(points, 255, dtype=np.uint8)
118
 
119
+ if points is None:
120
+ raise RuntimeError(
121
+ f"Could not find point cloud in MoGe output; keys: {list(out.keys())}"
122
+ )
123
 
124
+ # ensure 2D
125
+ points = points.reshape(-1, 3)
126
+ colors = colors.reshape(-1, 3)
127
+
128
+ n = points.shape[0]
129
+ print("MoGe point count:", n)
130
 
131
+ # sanity check: bail if the model gave us basically nothing
132
+ if n < 100:
133
+ raise RuntimeError(f"MoGe returned too few points (N={n}), refusing to write bogus PLY.")
134
+
135
+ return points, colors
136
 
137
  # ---------- Helper: write PLY into memory ----------
138