Yebulabula committed on
Commit 1739a6b · 1 Parent(s): ce6ba12

Initial commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/*.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,22 @@
+---
+title: POMA-3D Demo
+emoji: 🧠
+colorFrom: gray
+colorTo: red
+sdk: gradio
+sdk_version: 4.44.1
+app_file: app.py
+pinned: false
+---
+
+# POMA-3D: The Point Map Way to 3D Scene Understanding
+
+- Rotatable Plotly 3D point cloud (voxel downsampled)
+- Depth shading for structure, plus a **red** recolor for the Top-K retrieved views
+- Uses the local `assets/scene0073_00.safetensors` (no downloads)
+
+**How to use**
+1. Wait for the scene to load (a depth-shaded point cloud appears).
+2. Enter the agent's situation text.
+3. Choose Top-K (e.g., 3).
+4. Click **Locate**; the most relevant views turn **red**.
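
For orientation, here is a minimal sketch of reading the bundled scene file outside the app; it assumes the `point_map` key and the (V, H, W, 3) layout that `app.py` expects:

```python
# Minimal sketch, not part of the Space: inspect the bundled scene tensor.
# Assumes the "point_map" key and (V, H, W, 3) layout that app.py expects.
from safetensors.torch import load_file

sd = load_file("assets/scene0073_00.safetensors")
pm = sd["point_map"]   # (V, H, W, 3): V views of per-pixel XYZ coordinates
print(pm.shape, pm.dtype)
```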
app.py ADDED
@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os, glob, torch, numpy as np
+from typing import List, Tuple
+from safetensors.torch import load_file
+import gradio as gr
+import plotly.graph_objects as go
+import torch.nn as nn
+
+# =========================
+# Config
+# =========================
+LOCAL_FILE_DEFAULT = "assets/scene0073_00.safetensors"  # local safetensors file
+PM_KEY_DEFAULT = "point_map"
+TOPK_VIEWS_DEFAULT = 3
+
+VOXEL_SIZE = 0.035  # voxel edge length used for downsampling
+DOWNSAMPLE_N_MAX = 600_000  # hard cap on rendered points
+POINT_SIZE_MIN = 1.2
+POINT_SIZE_MAX = 2.0
+
+CLR_RED = "rgba(230,40,40,0.98)"
+BG_COLOR = "#f7f9fb"
+GRID_COLOR = "#e6ecf2"
+BOX_COLOR = "rgba(80,80,80,0.6)"
+
+TOP_VIEW_IMAGE_PATH = "assets/scene0073_00.png"
+
+DEFAULT_CAM = dict(
+    eye=dict(x=1.35, y=1.35, z=0.95),
+    up=dict(x=0, y=0, z=1),
+    center=dict(x=0, y=0, z=0),
+    projection=dict(type="perspective"),
+)
+
+# =========================
+# Load pretrained weights
+# =========================
+def load_pretrain(model: nn.Module, pretrain_ckpt_path: str):
+    print(f"📂 Loading pretrained weights from: {str(pretrain_ckpt_path)}")
+    weight_files: List[str] = []
+    if os.path.isdir(pretrain_ckpt_path):
+        weight_files = sorted(glob.glob(os.path.join(pretrain_ckpt_path, "model*.safetensors")))
+    elif os.path.isfile(pretrain_ckpt_path):
+        if pretrain_ckpt_path.endswith(".safetensors"):
+            weight_files = [pretrain_ckpt_path]
+        elif pretrain_ckpt_path.endswith((".pth", ".pt")):
+            state = torch.load(pretrain_ckpt_path, map_location="cpu")
+            if isinstance(state, dict) and any(k.startswith(("model.", "target_model.")) for k in state.keys()):
+                state = {k.split(".", 1)[1] if k.startswith(("model.", "target_model.")) else k: v for k, v in state.items()}  # strip trainer prefixes
+            result = model.load_state_dict(state, strict=False)
+            print(f"✅ Loaded .pth/.pt with {len(state)} keys | missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")
+            return
+        else:
+            raise ValueError(f"❌ Unsupported checkpoint extension: {pretrain_ckpt_path}")
+    else:
+        raise FileNotFoundError(f"❌ Path not found: {pretrain_ckpt_path}")
+
+    weights = {}
+    for wf in weight_files:
+        print(f"📥 Loading weights from: {wf}")
+        weights.update(load_file(wf, device="cpu"))
+
+    result = model.load_state_dict(weights, strict=False)
+    model_keys = set(model.state_dict().keys())
+    loaded_keys = model_keys.intersection(weights.keys())
+    print(f"✅ Loaded keys: {len(loaded_keys)} / {len(model_keys)} | missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")
+
+# =========================
+# Representation model (fg-clip-base + LoRA)
+# =========================
+def build_model(device: torch.device):
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from peft import LoraConfig, get_peft_model
+
+    class RepModel(torch.nn.Module):
+        def __init__(self, model_root="fg-clip-base"):
+            super().__init__()
+            lora_cfg = LoraConfig(
+                r=32, lora_alpha=64,
+                target_modules=["q_proj", "k_proj", "v_proj", "fc1", "fc2"],
+                lora_dropout=0.05, bias="none",
+                task_type="FEATURE_EXTRACTION"
+            )
+            base = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
+            self.target_model = get_peft_model(base, lora_cfg)
+            self.tokenizer = AutoTokenizer.from_pretrained(model_root, trust_remote_code=True, use_fast=True)
+
+        @torch.no_grad()
+        def get_text_feature(self, texts, device):
+            tok = self.tokenizer(texts, padding="max_length", truncation=True, max_length=248, return_tensors="pt").to(device)
+            feats = self.target_model.get_text_features(tok["input_ids"], walk_short_pos=False)
+            feats = torch.nn.functional.normalize(feats.float(), dim=-1)  # L2-normalize for cosine similarity
+            return feats
+
+        @torch.no_grad()
+        def get_image_feature(self, pm_batched):
+            _, feats = self.target_model.get_image_features(pm_batched)
+            feats = torch.nn.functional.normalize(feats.float(), dim=-1)
+            return feats
+
+    m = RepModel().to(device).eval()
+    print("Using fg-clip-base RepModel.")
+    return m
+
+# =========================
+# Data loading & helpers
+# =========================
+def load_scene_local(path: str, pm_key: str = PM_KEY_DEFAULT) -> torch.Tensor:
+    if not os.path.exists(path):
+        raise FileNotFoundError(f"Local file not found: {path}")
+    sd = load_file(path)
+    if pm_key not in sd:
+        raise KeyError(f"Key '{pm_key}' not found in {list(sd.keys())}")
+    pm = sd[pm_key]  # (V,H,W,3)
+    if pm.dim() != 4 or pm.shape[-1] != 3:
+        raise ValueError(f"Invalid shape {tuple(pm.shape)}, expected (V,H,W,3)")
+    return pm.permute(0, 3, 1, 2).contiguous()  # -> (V,3,H,W)
+
+def _xyz_to_numpy(xyz: torch.Tensor) -> np.ndarray:
+    pts = xyz.permute(1, 2, 0).reshape(-1, 3).cpu().numpy().astype(np.float32)
+    mask = np.isfinite(pts).all(axis=1)  # drop NaN/Inf points
+    return pts[mask]
+
+def stack_views(pm: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
+    pts_all, vid_all = [], []
+    for v in range(pm.shape[0]):
+        pts = _xyz_to_numpy(pm[v])
+        if pts.size == 0: continue
+        pts_all.append(pts)
+        vid_all.append(np.full((pts.shape[0],), v, dtype=np.int32))
+    pts_all = np.concatenate(pts_all, axis=0)
+    vid_all = np.concatenate(vid_all, axis=0)
+    return pts_all, vid_all
+
+def voxel_downsample_with_ids(pts, vids, voxel: float):
+    if pts.shape[0] == 0: return pts, vids
+    grid = np.floor(pts / voxel).astype(np.int64)
+    key = np.core.records.fromarrays(grid.T, names="x,y,z", formats="i8,i8,i8")  # one record per voxel cell
+    _, uniq_idx = np.unique(key, return_index=True)  # keep the first point in each occupied voxel
+    return pts[uniq_idx], vids[uniq_idx]
+
+def hard_cap(pts, vids, cap: int):
+    N = pts.shape[0]
+    if N <= cap: return pts, vids
+    idx = np.random.choice(N, size=cap, replace=False)
+    return pts[idx], vids[idx]
+
+def adaptive_point_size(n: int) -> float:
+    ps = 2.4 * (150_000 / max(n, 10)) ** 0.25  # shrink markers as the cloud grows
+    return float(np.clip(ps, POINT_SIZE_MIN, POINT_SIZE_MAX))
+
+def scene_bbox(pts: np.ndarray):
+    mn, mx = pts.min(axis=0), pts.max(axis=0)
+    x0, y0, z0 = mn; x1, y1, z1 = mx
+    corners = np.array([
+        [x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
+        [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1]
+    ])
+    edges = [(0,1),(1,2),(2,3),(3,0),(4,5),(5,6),(6,7),(7,4),(0,4),(1,5),(2,6),(3,7)]  # 12 box edges
+    xs, ys, zs = [], [], []
+    for a, b in edges:
+        xs += [corners[a, 0], corners[b, 0], None]  # None breaks the line between edges
+        ys += [corners[a, 1], corners[b, 1], None]
+        zs += [corners[a, 2], corners[b, 2], None]
+    return xs, ys, zs
+
+@torch.no_grad()
+def rank_views_for_text(model, text, pm, device, topk: int):
+    img_feats = model.get_image_feature(pm.float().to(device))
+    txt_feat = model.get_text_feature([text], device=device)[0]
+    sims = torch.matmul(img_feats, txt_feat)  # cosine similarity (features are L2-normalized)
+    order = torch.argsort(sims, descending=True)[:max(1, int(topk))]
+    return order.tolist()
+
+# =========================
+# Visualization
+# =========================
+def depth_values(pts: np.ndarray) -> np.ndarray:
+    z = pts[:, 2]
+    z_min, z_max = z.min(), z.max()
+    return (z - z_min) / (z_max - z_min + 1e-9)  # normalize z to [0, 1]
+
+def base_figure_gray_depth(pts: np.ndarray, point_size: float, camera=DEFAULT_CAM) -> go.Figure:
+    depth = depth_values(pts)
+    fig = go.Figure(go.Scatter3d(
+        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
+        mode="markers",
+        marker=dict(size=point_size, color=depth, colorscale="Greys", reversescale=True, opacity=0.50),
+        hoverinfo="skip"
+    ))
+    bx, by, bz = scene_bbox(pts)
+    fig.add_trace(go.Scatter3d(x=bx, y=by, z=bz, mode="lines", line=dict(color=BOX_COLOR, width=2), hoverinfo="skip"))
+    fig.update_layout(scene=dict(aspectmode="data", camera=camera),
+                      margin=dict(l=0, r=0, b=0, t=0),
+                      paper_bgcolor=BG_COLOR,
+                      showlegend=False)
+    return fig
+
+def highlight_views_3d(pts, view_ids, selected, point_size, camera=DEFAULT_CAM):
+    depth = depth_values(pts)
+    colors = np.stack([depth, depth, depth], axis=1)  # grayscale RGB from depth
+    if selected:
+        sel_mask = np.isin(view_ids, np.array(selected, dtype=np.int32))
+        colors[sel_mask] = np.array([1, 0, 0])  # recolor retrieved views red
+    fig = go.Figure(go.Scatter3d(
+        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
+        mode="markers",
+        marker=dict(size=point_size,
+                    color=[f"rgb({int(r*255)},{int(g*255)},{int(b*255)})"
+                           for r, g, b in colors],
+                    opacity=0.98),
+        hoverinfo="skip"
+    ))
+    bx, by, bz = scene_bbox(pts)
+    fig.add_trace(go.Scatter3d(x=bx, y=by, z=bz, mode="lines",
+                               line=dict(color=BOX_COLOR, width=2), hoverinfo="skip"))
+    fig.update_layout(scene=dict(aspectmode="data", camera=camera),
+                      margin=dict(l=0, r=0, b=0, t=0),
+                      paper_bgcolor=BG_COLOR,
+                      showlegend=False)
+    return fig
+
+# =========================
+# App setup
+# =========================
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = build_model(device)
+load_pretrain(model, "assets/ckpt_100.pth")
+
+with gr.Blocks(
+    title="POMA-3D: Text-conditioned 3D Scene Visualization",
+    css="#plot3d, #img_ref {height: 450px !important;}"
+) as demo:
+    gr.Markdown("### POMA-3D: The Point Map Way to 3D Scene Understanding\n"
+                "Enter the agent's situation text and choose **Top-K**; the most relevant views will turn **red**.")
+
+    with gr.Row():
+        text_in = gr.Textbox(label="Text query", value="I am sleeping on the bed.", scale=4)
+        topk_in = gr.Number(label="Top-K views", value=TOPK_VIEWS_DEFAULT, precision=0, minimum=1, maximum=12)
+        submit_btn = gr.Button("Locate", variant="primary")
+
+    with gr.Row(equal_height=True):
+        with gr.Column(scale=1, min_width=500):
+            plot3d = gr.Plot(label="3D Point Cloud (rotatable)", elem_id="plot3d")
+        with gr.Column(scale=1, min_width=500):
+            img_ref = gr.Image(label="Top-Down Reference View", value=TOP_VIEW_IMAGE_PATH, elem_id="img_ref")
+
+    status = gr.Markdown()
+
+    pm_state = gr.State(None)
+    pts_state = gr.State(None)
+    vids_state = gr.State(None)
+
+    # Load the scene automatically from LOCAL_FILE_DEFAULT
+    def on_load():
+        pm = load_scene_local(LOCAL_FILE_DEFAULT)
+        pts_all, vids_all = stack_views(pm)
+        pts_vx, vids_vx = voxel_downsample_with_ids(pts_all, vids_all, VOXEL_SIZE)
+        pts_vx, vids_vx = hard_cap(pts_vx, vids_vx, DOWNSAMPLE_N_MAX)
+        ps = adaptive_point_size(pts_vx.shape[0])
+        fig3d = base_figure_gray_depth(pts_vx, ps, camera=DEFAULT_CAM)
+        msg = f"✅ Loaded {os.path.basename(LOCAL_FILE_DEFAULT)} | Views: {pm.shape[0]} | Points: {pts_vx.shape[0]:,}"
+        return fig3d, TOP_VIEW_IMAGE_PATH, msg, pm, pts_vx, vids_vx

+    def on_submit(text, topk, pm, pts_vx, vids_vx):
+        if pm is None:
+            return gr.update(), TOP_VIEW_IMAGE_PATH, "⚠️ Scene not loaded yet."
+        k = int(max(1, min(12, int(topk)))) if topk else TOPK_VIEWS_DEFAULT  # clamp Top-K to [1, 12]
+        top_views = rank_views_for_text(model, text, pm, device, topk=k)
+        ps = adaptive_point_size(pts_vx.shape[0])
+        fig = highlight_views_3d(pts_vx, vids_vx, top_views, ps, camera=DEFAULT_CAM)
+        msg = f"Highlighted views (top-{k}): {top_views}"
+        return fig, TOP_VIEW_IMAGE_PATH, msg
+
+    demo.load(on_load, inputs=[], outputs=[plot3d, img_ref, status, pm_state, pts_state, vids_state])
+    submit_btn.click(on_submit, inputs=[text_in, topk_in, pm_state, pts_state, vids_state],
+                     outputs=[plot3d, img_ref, status])
+
+if __name__ == "__main__":
+    demo.launch()
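
For local debugging, a hypothetical smoke test of the retrieval path is sketched below; it assumes the assets in this commit are present and imports `app`, which builds the model and loads the checkpoint as a side effect, so it is heavyweight:

```python
# Hypothetical smoke test; assumes the assets above are present locally.
# Importing app triggers build_model() and load_pretrain() at import time.
from app import (load_scene_local, stack_views, voxel_downsample_with_ids,
                 rank_views_for_text, model, device, VOXEL_SIZE)

pm = load_scene_local("assets/scene0073_00.safetensors")      # (V, 3, H, W)
pts, vids = stack_views(pm)                                   # flatten all views
pts, vids = voxel_downsample_with_ids(pts, vids, VOXEL_SIZE)  # thin the cloud
print(f"{pm.shape[0]} views, {pts.shape[0]:,} points after downsampling")
print(rank_views_for_text(model, "I am sleeping on the bed.", pm, device, topk=3))
```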
assets/ckpt_100.pth/custom_checkpoint_0.pkl ADDED
Binary file (1.04 kB)
assets/ckpt_100.pth/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1a61ca716a4213a494971e6d482f3337a03669be322584477b86099dbd56f83
+size 1926257444
assets/ckpt_100.pth/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60d95b10b6e140a9626a7058d5038528f2ff80148dc4569b881db56052046509
+size 40
assets/ckpt_100.pth/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa11ac4f90f2bdfe8f6fca4e1519a3cca7e107b48cfeb83504b08f98e9d6322a
+size 847312239
assets/ckpt_100.pth/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d63fd0a7fcb6f13000479f1e5225f2858947b004b0a552e9117c11a8cb8b1d7a
+size 16100
assets/ckpt_100.pth/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c6477bba6676da431d35912e3ec747ec6841cbf03d3fef95159d46fb68108f5
+size 1064
assets/scene0000_00.png ADDED

Git LFS Details

  • SHA256: 33e118b387d97928b62eeb69131a99eb7de3458efc865ab0f2283a5cdba4c05d
  • Pointer size: 131 Bytes
  • Size of remote file: 807 kB
assets/scene0073_00.png ADDED

Git LFS Details

  • SHA256: fd22cac52198821f19797ccfbba79c26fee3390fd30017d4a05bf68ce894b1a8
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
assets/scene0073_00.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f99163e32be5d157c264a54d35b2514b57377713150c2da5fd1dd71eb4ad957
+size 169390352
requirements.txt ADDED
@@ -0,0 +1,9 @@
+gradio==4.44.1
+plotly==5.24.1
+numpy==1.26.4
+torch==2.3.1
+safetensors==0.4.4
+transformers==4.43.3
+peft==0.11.1
+accelerate==0.33.0
+huggingface_hub==0.24.6
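
A small optional check that the pinned versions above are what the environment actually resolved (assumes the standard top-level module names for each package):

```python
# Optional sanity check: print installed versions of the pinned packages.
import gradio, plotly, numpy, torch, safetensors, transformers, peft, accelerate, huggingface_hub

for mod in (gradio, plotly, numpy, torch, safetensors,
            transformers, peft, accelerate, huggingface_hub):
    print(f"{mod.__name__}=={mod.__version__}")
```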