# Listing metadata carried over from the hosting page: file size 11,864 bytes, revision 1739a6b.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, glob, re, torch, numpy as np
from typing import List, Tuple
from safetensors.torch import load_file
import gradio as gr
import plotly.graph_objects as go
import torch.nn as nn
# =========================
# Config
# =========================
# Scene file that on_load() reads through load_scene_local().
LOCAL_FILE_DEFAULT = "assets/scene0073_00.safetensors" # local safetensors file
# Key under which the point map tensor is looked up inside the safetensors file.
PM_KEY_DEFAULT = "point_map"
# Initial value of the Top-K input and fallback when the user leaves it empty.
TOPK_VIEWS_DEFAULT = 3
# Voxel edge length passed to voxel_downsample_with_ids()
# (same units as the point coordinates — presumably metres; TODO confirm).
VOXEL_SIZE = 0.035
# Upper bound on the number of plotted points, enforced by hard_cap().
DOWNSAMPLE_N_MAX = 600_000
# Clamp range applied by adaptive_point_size() to the marker size.
POINT_SIZE_MIN = 1.2
POINT_SIZE_MAX = 2.0
# NOTE(review): CLR_RED is not referenced anywhere in the visible part of this file.
CLR_RED = "rgba(230,40,40,0.98)"
# Paper background colour of the plotly figures.
BG_COLOR = "#f7f9fb"
# NOTE(review): GRID_COLOR is not referenced anywhere in the visible part of this file.
GRID_COLOR = "#e6ecf2"
# Line colour of the scene bounding-box wireframe.
BOX_COLOR = "rgba(80,80,80,0.6)"
# Image shown in the "Top-Down Reference View" panel.
TOP_VIEW_IMAGE_PATH = "assets/scene0073_00.png"
# Default plotly 3D camera (z-up, perspective projection) used by both figure builders.
DEFAULT_CAM = dict(
eye=dict(x=1.35, y=1.35, z=0.95),
up=dict(x=0, y=0, z=1),
center=dict(x=0, y=0, z=0),
projection=dict(type="perspective"),
)
# =========================
# Load pretrained (your existing loader)
# =========================
def load_pretrain(model: nn.Module, pretrain_ckpt_path: str):
    """Load pretrained weights into ``model`` with ``strict=False``.

    Accepted inputs:
      * a directory: every ``model*.safetensors`` file inside it is merged
        (in sorted order) and loaded;
      * a single ``.safetensors`` file;
      * a ``.pth`` / ``.pt`` file readable by ``torch.load``. A checkpoint
        that nests its weights under a ``"state_dict"`` key is unwrapped, and
        a leading ``model.`` / ``target_model.`` prefix is stripped from every
        key if any key carries one.

    Because loading is non-strict, a checkpoint whose keys do not match the
    model would otherwise be accepted silently; a warning is printed when no
    key at all matches.

    Args:
        model: Module that receives the weights.
        pretrain_ckpt_path: Directory or file path of the checkpoint.

    Raises:
        FileNotFoundError: If the path does not exist, or if it is a file with
            an unsupported extension (exception type kept for compatibility
            with existing callers).
    """
    print(f"📂 Loading pretrained weights from: {str(pretrain_ckpt_path)}")
    prefixes = ("model.", "target_model.")

    def _warn_if_nothing_matched(state) -> None:
        # strict=False never fails on key mismatch, so surface the case loudly.
        if not set(model.state_dict().keys()).intersection(state.keys()):
            print("⚠️ No checkpoint key matched the model; weights were NOT updated.")

    weight_files: List[str] = []
    if os.path.isdir(pretrain_ckpt_path):
        pattern = os.path.join(pretrain_ckpt_path, "model*.safetensors")
        weight_files = sorted(glob.glob(pattern))
        if not weight_files:
            print(f"⚠️ No 'model*.safetensors' files found in: {pretrain_ckpt_path}")
    elif os.path.isfile(pretrain_ckpt_path):
        if pretrain_ckpt_path.endswith(".safetensors"):
            weight_files = [pretrain_ckpt_path]
        elif pretrain_ckpt_path.endswith((".pth", ".pt")):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources (or pass weights_only=True on
            # PyTorch versions that support it).
            state = torch.load(pretrain_ckpt_path, map_location="cpu")
            # Trainer-style checkpoints keep the weights under "state_dict".
            if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
                state = state["state_dict"]
            if isinstance(state, dict) and any(k.startswith(prefixes) for k in state.keys()):
                state = {
                    (k.split(".", 1)[1] if k.startswith(prefixes) else k): v
                    for k, v in state.items()
                }
            result = model.load_state_dict(state, strict=False)
            print(f"✅ Loaded .pth/.pt with {len(state)} keys | missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")
            _warn_if_nothing_matched(state)
            return
        else:
            raise FileNotFoundError(f"❌ Unsupported checkpoint extension: {pretrain_ckpt_path}")
    else:
        raise FileNotFoundError(f"❌ Path not found: {pretrain_ckpt_path}")

    # Later shards overwrite earlier ones on key collisions (dict.update).
    weights = {}
    for wf in weight_files:
        print(f"📥 Loading weights from: {wf}")
        weights.update(load_file(wf, device="cpu"))
    result = model.load_state_dict(weights, strict=False)
    model_keys = set(model.state_dict().keys())
    loaded_keys = model_keys.intersection(weights.keys())
    print(f"✅ Loaded keys: {len(loaded_keys)} / {len(model_keys)} | missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")
    _warn_if_nothing_matched(weights)
# =========================
# Representation model (fg-clip-base + LoRA)
# =========================
def build_model(device: torch.device):
    """Build the text/point-map representation model and put it in eval mode.

    The returned ``RepModel`` wraps the ``fg-clip-base`` checkpoint with LoRA
    adapters and exposes ``get_text_feature`` / ``get_image_feature``, both of
    which return L2-normalised float features.

    Args:
        device: Device the model is moved to.

    Returns:
        The ``RepModel`` instance on ``device``, in eval mode.
    """
    # These imports run only when a model is actually built.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model

    class RepModel(torch.nn.Module):
        """fg-clip-base backbone with LoRA adapters plus its tokenizer."""

        def __init__(self, model_root="fg-clip-base"):
            super().__init__()
            adapter_config = LoraConfig(
                r=32,
                lora_alpha=64,
                target_modules=["q_proj", "k_proj", "v_proj", "fc1", "fc2"],
                lora_dropout=0.05,
                bias="none",
                task_type="FEATURE_EXTRACTION",
            )
            backbone = AutoModelForCausalLM.from_pretrained(
                model_root, trust_remote_code=True
            )
            # Attribute names are part of the state-dict key layout; keep them.
            self.target_model = get_peft_model(backbone, adapter_config)
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_root, trust_remote_code=True, use_fast=True
            )

        @torch.no_grad()
        def get_text_feature(self, texts, device):
            """Encode ``texts`` and return unit-norm float features."""
            encoded = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=248,
                return_tensors="pt",
            ).to(device)
            raw = self.target_model.get_text_features(
                encoded["input_ids"], walk_short_pos=False
            )
            return torch.nn.functional.normalize(raw.float(), dim=-1)

        @torch.no_grad()
        def get_image_feature(self, pm_batched):
            """Encode a batch of point maps and return unit-norm float features."""
            # The backbone returns a pair; only its second element is used here.
            _, raw = self.target_model.get_image_features(pm_batched)
            return torch.nn.functional.normalize(raw.float(), dim=-1)

    rep_model = RepModel()
    rep_model = rep_model.to(device)
    rep_model = rep_model.eval()
    print("Using fg-clip-base RepModel.")
    return rep_model
# =========================
# Data loading & helpers
# =========================
def load_scene_local(path: str, pm_key: str = PM_KEY_DEFAULT) -> torch.Tensor:
    """Read a point map from a local safetensors file.

    Args:
        path: Location of the safetensors file.
        pm_key: Name of the tensor holding the point map, stored as (V, H, W, 3).

    Returns:
        The point map rearranged to (V, 3, H, W) in contiguous memory.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        KeyError: ``pm_key`` is not present in the file.
        ValueError: The stored tensor is not 4-D with a trailing size of 3.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Local file not found: {path}")

    tensors = load_file(path)
    if pm_key not in tensors:
        raise KeyError(f"Key '{pm_key}' not found in {list(tensors.keys())}")

    point_map = tensors[pm_key]  # (V,H,W,3)
    shape_ok = point_map.dim() == 4 and point_map.shape[-1] == 3
    if not shape_ok:
        raise ValueError(f"Invalid shape {tuple(point_map.shape)}, expected (V,H,W,3)")

    # Move the xyz channel in front of the spatial axes: -> (V,3,H,W)
    channels_first = point_map.permute(0, 3, 1, 2)
    return channels_first.contiguous()
def _xyz_to_numpy(xyz: torch.Tensor) -> np.ndarray:
    """Flatten one (3, H, W) point map into an (N, 3) float32 array.

    Rows containing any non-finite coordinate (NaN / inf) are dropped, so N
    may be smaller than H * W (and may be 0).
    """
    # detach + cast in torch first so tensors that track gradients or use a
    # dtype NumPy cannot represent (e.g. bfloat16) convert without error.
    flat = xyz.detach().permute(1, 2, 0).reshape(-1, 3).to(torch.float32)
    pts = flat.cpu().numpy().astype(np.float32, copy=False)
    finite_rows = np.isfinite(pts).all(axis=1)
    return pts[finite_rows]


def stack_views(pm: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Merge all views of a point map into a single point cloud.

    Args:
        pm: Point maps of shape (V, 3, H, W).

    Returns:
        ``(points, view_ids)`` where ``points`` is (N, 3) float32 and
        ``view_ids`` is (N,) int32 holding, for every point, the index of the
        view it came from. When no view contributes a finite point (or there
        are no views) both arrays are empty instead of raising.
    """
    pts_all, vid_all = [], []
    for v in range(pm.shape[0]):
        pts = _xyz_to_numpy(pm[v])
        if pts.size == 0:
            continue
        pts_all.append(pts)
        vid_all.append(np.full((pts.shape[0],), v, dtype=np.int32))

    if not pts_all:
        # np.concatenate rejects an empty list; return well-typed empties.
        return np.empty((0, 3), dtype=np.float32), np.empty((0,), dtype=np.int32)

    return np.concatenate(pts_all, axis=0), np.concatenate(vid_all, axis=0)
def voxel_downsample_with_ids(pts, vids, voxel: float):
    """Keep one point (and its view id) per occupied voxel.

    Points are binned on a regular grid of edge length ``voxel``; for every
    occupied cell the first point in input order is kept. The output is
    ordered by voxel index (lexicographically in x, y, z).

    Args:
        pts: (N, 3) array of coordinates.
        vids: (N,) array aligned with ``pts``.
        voxel: Grid cell edge length; must be positive.

    Returns:
        ``(pts_kept, vids_kept)``; the inputs are returned unchanged when
        ``pts`` is empty.

    Raises:
        ValueError: If ``voxel`` is not a positive number.
    """
    if pts.shape[0] == 0:
        return pts, vids
    if not voxel > 0:
        raise ValueError(f"voxel must be positive, got {voxel!r}")

    grid = np.floor(pts / voxel).astype(np.int64)
    # Row-wise unique via the public API; np.core.* is a private namespace
    # that NumPy 2.0 deprecated.
    _, uniq_idx = np.unique(grid, axis=0, return_index=True)
    return pts[uniq_idx], vids[uniq_idx]
def hard_cap(pts, vids, cap: int):
    """Limit the cloud to at most ``cap`` points by random subsampling.

    When the cloud already fits, the inputs are returned untouched. Otherwise
    ``cap`` distinct rows are drawn with NumPy's global random generator and
    the same rows are taken from ``pts`` and ``vids``.
    """
    total = pts.shape[0]
    if total > cap:
        keep = np.random.choice(total, size=cap, replace=False)
        pts, vids = pts[keep], vids[keep]
    return pts, vids
def adaptive_point_size(n: int) -> float:
    """Return a marker size that shrinks as the number of points grows.

    The size follows ``2.4 * (150_000 / n) ** 0.25`` (with ``n`` floored at
    10) and is clamped to ``[POINT_SIZE_MIN, POINT_SIZE_MAX]``.
    """
    effective_n = max(n, 10)  # avoids division by zero for empty clouds
    density_ratio = 150_000 / effective_n
    unclamped = 2.4 * density_ratio ** 0.25
    clamped = np.clip(unclamped, POINT_SIZE_MIN, POINT_SIZE_MAX)
    return float(clamped)
def scene_bbox(pts: np.ndarray):
    """Build wireframe coordinates for the axis-aligned bounding box of ``pts``.

    Returns:
        Three lists ``(xs, ys, zs)``. Each of the 12 box edges contributes its
        two endpoint coordinates followed by ``None``, which makes plotly draw
        the edges as separate line segments.
    """
    lower, upper = pts.min(axis=0), pts.max(axis=0)
    x0, y0, z0 = lower
    x1, y1, z1 = upper

    # Corners 0-3: bottom face (z0); corners 4-7: top face (z1).
    corners = np.array([
        [x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
        [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1],
    ])
    edges = [
        (0, 1), (1, 2), (2, 3), (3, 0),  # bottom ring
        (4, 5), (5, 6), (6, 7), (7, 4),  # top ring
        (0, 4), (1, 5), (2, 6), (3, 7),  # vertical pillars
    ]

    def _segments(axis):
        # start, end, None for every edge along one coordinate axis
        return [
            value
            for start, end in edges
            for value in (corners[start, axis], corners[end, axis], None)
        ]

    return _segments(0), _segments(1), _segments(2)
@torch.no_grad()
def rank_views_for_text(model, text, pm, device, topk: int):
    """Return the indices of the views that best match a text query.

    Every view in ``pm`` and the query ``text`` are embedded with ``model``;
    views are ordered by the dot product of their feature with the text
    feature, highest first.

    Args:
        model: Object providing ``get_image_feature`` and ``get_text_feature``.
        text: Query string.
        pm: Batch of point maps, one entry per view.
        device: Device used for inference.
        topk: Number of views to return; values below 1 are treated as 1.

    Returns:
        A list of view indices, best match first.
    """
    view_feats = model.get_image_feature(pm.float().to(device))
    query_feat = model.get_text_feature([text], device=device)[0]
    scores = torch.matmul(view_feats, query_feat)
    keep = max(1, int(topk))
    ranking = torch.argsort(scores, descending=True)
    return ranking[:keep].tolist()
# =========================
# Visualization
# =========================
def depth_values(pts: np.ndarray) -> np.ndarray:
    """Rescale the z coordinate of every point to the range [0, 1].

    A small epsilon in the denominator keeps the result finite (all zeros)
    when every point has the same height.
    """
    heights = pts[:, 2]
    lowest = heights.min()
    highest = heights.max()
    span = highest - lowest + 1e-9
    return (heights - lowest) / span
def base_figure_gray_depth(pts: np.ndarray, point_size: float, camera=DEFAULT_CAM) -> go.Figure:
    """Render the point cloud in grayscale, shaded by normalised height.

    Args:
        pts: (N, 3) point coordinates.
        point_size: Marker size for the points.
        camera: Plotly scene camera specification.

    Returns:
        A figure with two traces: the shaded points and the bounding-box
        wireframe of the cloud.
    """
    cloud_trace = go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode="markers",
        marker={
            "size": point_size,
            "color": depth_values(pts),
            "colorscale": "Greys",
            "reversescale": True,
            "opacity": 0.50,
        },
        hoverinfo="skip",
    )

    box_x, box_y, box_z = scene_bbox(pts)
    box_trace = go.Scatter3d(
        x=box_x,
        y=box_y,
        z=box_z,
        mode="lines",
        line={"color": BOX_COLOR, "width": 2},
        hoverinfo="skip",
    )

    fig = go.Figure(cloud_trace)
    fig.add_trace(box_trace)
    fig.update_layout(
        # aspectmode="data" keeps the axes in proportion to the data ranges
        scene={"aspectmode": "data", "camera": camera},
        margin={"l": 0, "r": 0, "b": 0, "t": 0},
        paper_bgcolor=BG_COLOR,
        showlegend=False,
    )
    return fig
def highlight_views_3d(pts, view_ids, selected, point_size, camera=DEFAULT_CAM):
    """Render the point cloud with the points of selected views painted red.

    Args:
        pts: (N, 3) point coordinates.
        view_ids: (N,) view index of every point.
        selected: View indices to highlight; nothing is highlighted when empty.
        point_size: Marker size for the points.
        camera: Plotly scene camera specification.

    Returns:
        A figure with the coloured points and the bounding-box wireframe.
        Non-selected points are gray, with intensity given by normalised height.
    """
    gray = depth_values(pts)
    rgb = np.repeat(gray[:, None], 3, axis=1)  # identical value in R, G and B
    if selected:
        chosen = np.isin(view_ids, np.array(selected, dtype=np.int32))
        rgb[chosen] = np.array([1, 0, 0])

    # Scale to 0-255 and truncate, then format one CSS colour string per point.
    channels = (rgb * 255).astype(np.int64)
    marker_colors = [f"rgb({r},{g},{b})" for r, g, b in channels.tolist()]

    cloud_trace = go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode="markers",
        marker={"size": point_size, "color": marker_colors, "opacity": 0.98},
        hoverinfo="skip",
    )

    box_x, box_y, box_z = scene_bbox(pts)
    box_trace = go.Scatter3d(
        x=box_x,
        y=box_y,
        z=box_z,
        mode="lines",
        line={"color": BOX_COLOR, "width": 2},
        hoverinfo="skip",
    )

    fig = go.Figure(cloud_trace)
    fig.add_trace(box_trace)
    fig.update_layout(
        scene={"aspectmode": "data", "camera": camera},
        margin={"l": 0, "r": 0, "b": 0, "t": 0},
        paper_bgcolor=BG_COLOR,
        showlegend=False,
    )
    return fig
# =========================
# App setup
# =========================
# Module-level setup: these statements run when the file is imported/executed.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the representation model once and share it across all requests.
model = build_model(device)
# Load the checkpoint into the model (non-strict; see load_pretrain).
load_pretrain(model, "assets/ckpt_100.pth")
# Upper bound for the Top-K selector, shared by the widget and the
# server-side clamp in on_submit so the two cannot drift apart.
TOPK_VIEWS_MAX = 12

# Gradio UI: a text query + Top-K selector, a 3D plot of the scene, and a
# static top-down reference image. The scene is loaded once per session and
# kept in gr.State objects so each query only has to re-rank the views.
with gr.Blocks(
    title="POMA-3D: Text-conditioned 3D Scene Visualization",
    css="#plot3d, #img_ref {height: 450px !important;}",
) as demo:
    gr.Markdown("### POMA-3D: The Point Map Way to 3D Scene Understanding\n"
                "Enter agent's situation text and choose **Top-K**; the most relevant views will turn **red**.")
    with gr.Row():
        text_in = gr.Textbox(label="Text query", value="I am sleeping on the bed.", scale=4)
        topk_in = gr.Number(label="Top-K views", value=TOPK_VIEWS_DEFAULT, precision=0,
                            minimum=1, maximum=TOPK_VIEWS_MAX)
        submit_btn = gr.Button("Locate", variant="primary")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=500):
            plot3d = gr.Plot(label="3D Point Cloud (rotatable)", elem_id="plot3d")
        with gr.Column(scale=1, min_width=500):
            img_ref = gr.Image(label="Top-Down Reference View", value=TOP_VIEW_IMAGE_PATH, elem_id="img_ref")
    status = gr.Markdown()

    # Per-session state: raw point maps, downsampled points, and their view ids.
    pm_state = gr.State(None)
    pts_state = gr.State(None)
    vids_state = gr.State(None)

    # Load scene automatically from LOCAL_FILE_DEFAULT
    def on_load():
        """Load the default scene, downsample it and draw the grayscale cloud."""
        pm = load_scene_local(LOCAL_FILE_DEFAULT)
        pts_all, vids_all = stack_views(pm)
        pts_vx, vids_vx = voxel_downsample_with_ids(pts_all, vids_all, VOXEL_SIZE)
        pts_vx, vids_vx = hard_cap(pts_vx, vids_vx, DOWNSAMPLE_N_MAX)
        ps = adaptive_point_size(pts_vx.shape[0])
        fig3d = base_figure_gray_depth(pts_vx, ps, camera=DEFAULT_CAM)
        msg = f"✅ Loaded {os.path.basename(LOCAL_FILE_DEFAULT)} | Views: {pm.shape[0]} | Points: {pts_vx.shape[0]:,}"
        return fig3d, TOP_VIEW_IMAGE_PATH, msg, pm, pts_vx, vids_vx

    def on_submit(text, topk, pm, pts_vx, vids_vx):
        """Rank the views against the query and redraw with the top-K in red."""
        if pm is None:
            return gr.update(), TOP_VIEW_IMAGE_PATH, "⚠️ Scene not loaded yet."
        if not (text or "").strip():
            # Ranking against an empty query is meaningless; keep the current plot.
            return gr.update(), TOP_VIEW_IMAGE_PATH, "⚠️ Please enter a text query."
        # An empty / zero Top-K falls back to the default; otherwise clamp to the widget range.
        k = int(max(1, min(TOPK_VIEWS_MAX, int(topk)))) if topk else TOPK_VIEWS_DEFAULT
        top_views = rank_views_for_text(model, text, pm, device, topk=k)
        ps = adaptive_point_size(pts_vx.shape[0])
        fig = highlight_views_3d(pts_vx, vids_vx, top_views, ps, camera=DEFAULT_CAM)
        msg = f"Highlighted views (top-{k}): {top_views}"
        return fig, TOP_VIEW_IMAGE_PATH, msg

    demo.load(on_load, inputs=[], outputs=[plot3d, img_ref, status, pm_state, pts_state, vids_state])
    submit_btn.click(on_submit, inputs=[text_in, topk_in, pm_state, pts_state, vids_state],
                     outputs=[plot3d, img_ref, status])
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()