Spaces:
Running
Running
File size: 14,087 Bytes
4016bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, glob, re, torch, numpy as np
from typing import List, Tuple
from safetensors.torch import load_file
import gradio as gr
import plotly.graph_objects as go
import torch.nn as nn
from huggingface_hub import hf_hub_download, list_repo_files
# =========================
# Config
# =========================
# Scene inputs
LOCAL_FILE_DEFAULT = "assets/scene0073_00.safetensors"  # point-map file loaded at startup
TOP_VIEW_IMAGE_PATH = "assets/scene0073_00.png"  # image shown in the reference panel
PM_KEY_DEFAULT = "point_map"  # tensor name holding the point map inside the file

# Retrieval / downsampling
TOPK_VIEWS_DEFAULT = 3  # views highlighted when the Top-K box is empty
VOXEL_SIZE = 0.02  # voxel edge length used to thin the cloud before plotting
DOWNSAMPLE_N_MAX = 600_000  # hard cap on plotted points after voxelization

# Rendering
POINT_SIZE_MIN, POINT_SIZE_MAX = 1.2, 2.0  # clamp range for adaptive marker size
# NOTE(review): CLR_RED and GRID_COLOR do not appear to be referenced in this file — confirm.
CLR_RED = "rgba(230,40,40,0.98)"
BG_COLOR = "#f7f9fb"
GRID_COLOR = "#e6ecf2"
BOX_COLOR = "rgba(80,80,80,0.6)"  # bounding-box wireframe colour
DEFAULT_CAM = {
    "eye": {"x": 1.35, "y": 1.35, "z": 0.95},
    "up": {"x": 0, "y": 0, "z": 1},
    "center": {"x": 0, "y": 0, "z": 0},
    "projection": {"type": "perspective"},
}
def _merge_safetensors_dicts(paths: List[str]):
merged = {}
for p in paths:
sd = load_file(p, device="cpu")
merged.update(sd)
return merged
def _local_all_under(path: str) -> List[str]:
out = []
if os.path.isfile(path):
return [path]
for root, _, files in os.walk(path):
for f in files:
out.append(os.path.join(root, f))
return sorted(out)
# =========================
# Load pretrained (your existing loader)
# =========================
def load_pretrain(
    model: torch.nn.Module,
    ckpt_path: str,  # e.g. "assets/ckpt_100.pth" or "assets/model.safetensors"
    repo_id: str = "MatchLab/poma3d-demo",
    revision: str = "main",
    allow_local_fallback: bool = True,
):
    """Load checkpoint weights into ``model`` (non-strict) and print a summary.

    Resolution:
      1. If ``allow_local_fallback`` is set and ``ckpt_path`` exists locally
         (file or directory), every file under it is a candidate.
      2. Otherwise ``ckpt_path`` is looked up in the Hugging Face Hub model
         repo ``repo_id`` at ``revision``. It is matched first as an exact
         file, then as a folder prefix, and the matches are downloaded.

    Loading:
      - All ``.safetensors`` candidates are merged into one state dict.
      - If there are none, the largest ``.pth``/``.pt`` candidate is loaded
        with ``torch.load``.
      - That ``.pth``/``.pt`` state is unwrapped from a nested
        ``"state_dict"`` entry if present.
      - Its keys are then stripped of a leading ``model.`` /
        ``target_model.`` prefix.

    Returns:
        The object returned by ``model.load_state_dict(state, strict=False)``.

    Raises:
        FileNotFoundError: nothing matches ``ckpt_path`` in the remote repo,
            or no ``.safetensors`` / ``.pth`` / ``.pt`` file is among the
            candidates.
    """
    local_files = _resolve_checkpoint_files(ckpt_path, repo_id, revision, allow_local_fallback)
    state = _read_checkpoint_state(local_files)

    result = model.load_state_dict(state, strict=False)

    # Report: keys present in both the checkpoint and the model were loaded.
    weight_keys = set(state.keys()) if isinstance(state, dict) else set()
    model_keys = set(model.state_dict().keys())
    loaded_keys = model_keys.intersection(weight_keys)
    print("✅ Weights loaded")
    print(f" • Loaded keys: {len(loaded_keys)}")
    print(f" • Missing keys: {len(result.missing_keys)}")
    print(f" • Unexpected keys: {len(result.unexpected_keys)}")
    return result


def _resolve_checkpoint_files(ckpt_path: str, repo_id: str, revision: str,
                              allow_local_fallback: bool) -> List[str]:
    """Return the sorted local paths of all candidate checkpoint files for ``ckpt_path``."""
    if allow_local_fallback and (os.path.isfile(ckpt_path) or os.path.isdir(ckpt_path)):
        print(f"📂 Using local checkpoint(s): {ckpt_path}")
        return sorted(_local_all_under(ckpt_path))

    # Remote: resolve the file list from the Hub repo.
    print(f"📦 Resolving from HF Space: {repo_id}/{ckpt_path} (rev={revision})")
    files = list_repo_files(repo_id=repo_id, repo_type='model', revision=revision)
    if ckpt_path in files:
        # Exact file hit.
        to_fetch = [ckpt_path]
    else:
        # Treat ckpt_path as a folder prefix (trailing slash so "ckpt" does not match "ckpt2/...").
        prefix = ckpt_path if ckpt_path.endswith("/") else ckpt_path + "/"
        to_fetch = [f for f in files if f.startswith(prefix)]
    if not to_fetch:
        preview = "\n".join(files[:100])
        raise FileNotFoundError(
            f"'{ckpt_path}' not found in Space '{repo_id}' (rev='{revision}').\n"
            f"Files present (first 100):\n{preview}"
        )
    downloaded = [
        hf_hub_download(repo_id=repo_id, filename=rel, repo_type='model', revision=revision)
        for rel in to_fetch
    ]
    return sorted(downloaded)


def _read_checkpoint_state(local_files: List[str]):
    """Build a state dict from candidate files; ``.safetensors`` take priority over ``.pth``/``.pt``."""
    safes = [p for p in local_files if p.endswith(".safetensors")]
    if safes:
        print(f"🧩 Found {len(safes)} .safetensors shard(s); merging…")
        return _merge_safetensors_dicts(safes)

    pths = [p for p in local_files if p.endswith((".pth", ".pt"))]
    if not pths:
        raise FileNotFoundError(
            "No loadable checkpoint found. Expecting one or more of: "
            ".safetensors or .pth/.pt under the given path."
        )
    # Pick the largest .pth/.pt to avoid optimizer/state variants; among
    # equally-sized files the first in (sorted) path order wins.
    pick = max(pths, key=os.path.getsize)
    print(f"🧩 Using .pth/.pt: {os.path.basename(pick)} (largest of {len(pths)} candidates)")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state = torch.load(pick, map_location="cpu")
    return _normalize_state_keys(state)


def _normalize_state_keys(state):
    """Unwrap a nested ``"state_dict"`` entry, then strip a leading ``model.``/``target_model.`` key prefix."""
    # Unwrap BEFORE stripping. Otherwise the strip runs on the outer wrapper
    # dict, and the inner weights come back with their prefixes intact; with
    # strict=False those would silently fail to load.
    if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
        state = state["state_dict"]
    prefixes = ("model.", "target_model.")
    if isinstance(state, dict) and any(k.startswith(prefixes) for k in state):
        state = {
            (k.split(".", 1)[1] if k.startswith(prefixes) else k): v
            for k, v in state.items()
        }
    return state
# =========================
# Representation model (fg-clip-base + LoRA)
# =========================
def build_model(device: torch.device):
    """Construct the fg-clip-base representation model wrapped with LoRA adapters.

    The backbone is built with ``from_config``, i.e. from its configuration
    only, without downloading pretrained weights. Trained weights are
    expected to be loaded afterwards; this file does so with ``load_pretrain``.

    Args:
        device: device the model is moved to (the app passes the string "cpu").

    Returns:
        The ``RepModel`` instance, moved to ``device`` and switched to eval mode.
    """
    # Deferred imports: transformers/peft are only imported when a model is built.
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
    from peft import LoraConfig, get_peft_model

    class RepModel(torch.nn.Module):
        """fg-clip-base backbone under a PEFT LoRA wrapper, plus its tokenizer."""

        def __init__(self, model_root="qihoo360/fg-clip-base"):
            super().__init__()
            # LoRA adapters on the modules named q/k/v_proj and fc1/fc2.
            adapter_config = LoraConfig(
                r=32,
                lora_alpha=64,
                target_modules=["q_proj", "k_proj", "v_proj", "fc1", "fc2"],
                lora_dropout=0.05,
                bias="none",
                task_type="FEATURE_EXTRACTION",
            )
            backbone_config = AutoConfig.from_pretrained(model_root, trust_remote_code=True)
            backbone = AutoModelForCausalLM.from_config(backbone_config, trust_remote_code=True)
            self.target_model = get_peft_model(backbone, adapter_config)
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_root, trust_remote_code=True, use_fast=True
            )

        @torch.no_grad()
        def get_text_feature(self, texts, device):
            """Encode a list of strings into L2-normalised float32 text embeddings."""
            encoded = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=248,
                return_tensors="pt",
            ).to(device)
            # walk_short_pos is an argument of the remote-code FG-CLIP get_text_features.
            embeddings = self.target_model.get_text_features(
                encoded["input_ids"], walk_short_pos=False
            )
            return torch.nn.functional.normalize(embeddings.float(), dim=-1)

        @torch.no_grad()
        def get_image_feature(self, pm_batched):
            """Encode a batch of point maps into L2-normalised float32 embeddings."""
            embeddings = self.target_model.get_image_features(pm_batched)
            return torch.nn.functional.normalize(embeddings.float(), dim=-1)

    rep_model = RepModel()
    rep_model = rep_model.to(device)
    rep_model.eval()
    print("Using fg-clip-base RepModel.")
    return rep_model
# =========================
# Data loading & helpers
# =========================
def load_scene_local(path: str, pm_key: str = PM_KEY_DEFAULT) -> torch.Tensor:
    """Read a point map from a local safetensors file as a channels-first tensor.

    Args:
        path: path of the ``.safetensors`` file.
        pm_key: name of the tensor to read; it must have shape ``(V, H, W, 3)``.

    Returns:
        A contiguous tensor of shape ``(V, 3, H, W)``.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        KeyError: ``pm_key`` is not stored in the file.
        ValueError: the tensor is not 4-D with a trailing dimension of 3.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Local file not found: {path}")
    tensors = load_file(path)
    if pm_key not in tensors:
        raise KeyError(f"Key '{pm_key}' not found in {list(tensors.keys())}")
    point_map = tensors[pm_key]
    shape_ok = point_map.dim() == 4 and point_map.shape[-1] == 3
    if not shape_ok:
        raise ValueError(f"Invalid shape {tuple(point_map.shape)}, expected (V,H,W,3)")
    # Move the xyz channel axis from last to position 1: (V,H,W,3) -> (V,3,H,W).
    return torch.movedim(point_map, -1, 1).contiguous()
def _xyz_to_numpy(xyz: torch.Tensor) -> np.ndarray:
pts = xyz.permute(1, 2, 0).reshape(-1, 3).cpu().numpy().astype(np.float32)
mask = np.isfinite(pts).all(axis=1)
return pts[mask]
def stack_views(pm: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Flatten every view of a point map into one point array plus per-point view ids.

    Args:
        pm: point-map tensor indexed as ``pm[v]`` per view. Each view is
            converted by ``_xyz_to_numpy``, which expects channels-first
            ``(3, H, W)``.

    Returns:
        ``(points, view_ids)``:
          - ``points``: the finite points as a float32 ``(N, 3)`` array.
          - ``view_ids``: an int32 ``(N,)`` array giving the source view of
            each point.
        Both arrays are empty (shapes ``(0, 3)`` and ``(0,)``) when no view
        contributes a finite point.
    """
    pts_chunks, vid_chunks = [], []
    for v in range(pm.shape[0]):
        pts = _xyz_to_numpy(pm[v])
        if pts.size == 0:
            continue  # this view has no finite points
        pts_chunks.append(pts)
        vid_chunks.append(np.full((pts.shape[0],), v, dtype=np.int32))
    if not pts_chunks:
        # np.concatenate([]) raises ValueError. Return well-typed empty arrays
        # instead, so downstream helpers such as voxel_downsample_with_ids can
        # take their own empty-input paths.
        return np.empty((0, 3), dtype=np.float32), np.empty((0,), dtype=np.int32)
    return np.concatenate(pts_chunks, axis=0), np.concatenate(vid_chunks, axis=0)
def voxel_downsample_with_ids(pts, vids, voxel: float):
    """Keep one point per occupied voxel of edge length ``voxel``, with its view id.

    Points are binned by ``floor(pts / voxel)``. For every occupied voxel, the
    first point in input order is kept. Output is ordered by voxel index
    (lexicographic on x, y, z), not by input order.

    Args:
        pts: ``(N, 3)`` point coordinates.
        vids: ``(N,)`` per-point view ids, kept aligned with ``pts``.
        voxel: voxel edge length; must be positive.

    Returns:
        ``(pts_kept, vids_kept)``. Empty inputs are returned unchanged.

    Raises:
        ValueError: if ``voxel`` is not positive (and the input is non-empty).
    """
    if pts.shape[0] == 0:
        return pts, vids
    if not voxel > 0:
        raise ValueError(f"voxel must be positive, got {voxel!r}")
    grid = np.floor(pts / voxel).astype(np.int64)
    # Row-wise unique replaces the old np.core.records key; the np.core
    # namespace is deprecated in NumPy 2.x. With return_index, np.unique
    # reports the first occurrence of each voxel, in the same lexicographic
    # order as before.
    _, first_idx = np.unique(grid, axis=0, return_index=True)
    return pts[first_idx], vids[first_idx]
def hard_cap(pts, vids, cap: int):
    """Randomly subsample to at most ``cap`` points, keeping ``pts``/``vids`` aligned.

    Inputs at or under the cap are returned as-is (same objects). Otherwise
    ``cap`` distinct rows are drawn without replacement from NumPy's global
    RNG.
    """
    total = pts.shape[0]
    if total > cap:
        keep = np.random.choice(total, size=cap, replace=False)
        pts, vids = pts[keep], vids[keep]
    return pts, vids
def adaptive_point_size(n: int) -> float:
    """Pick a marker size that shrinks with the fourth root of the point count.

    The unclamped size is 2.4 at 150,000 points. Counts below 10 are treated
    as 10. The result is clamped to [POINT_SIZE_MIN, POINT_SIZE_MAX].
    """
    density_ratio = 150_000 / max(n, 10)
    raw_size = 2.4 * density_ratio ** 0.25
    return float(min(max(raw_size, POINT_SIZE_MIN), POINT_SIZE_MAX))
def scene_bbox(pts: np.ndarray):
    """Return Plotly line coordinates tracing the axis-aligned bounding box of ``pts``.

    Each of the 12 box edges contributes ``[start, end, None]`` to every
    coordinate list. The ``None`` breaks the polyline, so Plotly draws the
    edges as separate segments.

    Returns:
        ``(xs, ys, zs)``, three lists of length 36.
    """
    lo = pts.min(axis=0)
    hi = pts.max(axis=0)
    # Which axes take the max value for each corner.
    # Corners 0-3 form the bottom face (z = min); 4-7 are the top face above them.
    use_max = np.array([
        [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
        [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1],
    ], dtype=bool)
    corners = np.where(use_max, hi, lo)
    bottom = [(0, 1), (1, 2), (2, 3), (3, 0)]
    top = [(4, 5), (5, 6), (6, 7), (7, 4)]
    pillars = [(0, 4), (1, 5), (2, 6), (3, 7)]
    xs, ys, zs = [], [], []
    for start, end in bottom + top + pillars:
        for axis, coords in enumerate((xs, ys, zs)):
            coords.extend((corners[start, axis], corners[end, axis], None))
    return xs, ys, zs
@torch.no_grad()
def rank_views_for_text(model, text, pm, device, topk: int):
    """Return the indices of the ``topk`` views whose embedding best matches ``text``.

    Each view's score is the dot product of its image feature
    (``model.get_image_feature``) with the text feature
    (``model.get_text_feature``). For the RepModel in this file both are
    L2-normalised, so the score is a cosine similarity.

    Args:
        model: object exposing ``get_image_feature`` and ``get_text_feature``.
        text: query string.
        pm: batch of per-view point maps, converted to float and moved to ``device``.
        device: device used for feature extraction.
        topk: number of views to return; values below 1 are treated as 1.

    Returns:
        A list of view indices, best match first.
    """
    view_feats = model.get_image_feature(pm.float().to(device))
    query_feat = model.get_text_feature([text], device=device)[0]
    scores = view_feats @ query_feat
    k = max(1, int(topk))
    ranked = scores.argsort(descending=True)
    return ranked[:k].tolist()
# =========================
# Visualization
# =========================
def depth_values(pts: np.ndarray) -> np.ndarray:
    """Min-max normalise the z coordinate of ``(N, 3)`` points to approximately [0, 1].

    The 1e-9 added to the range avoids division by zero for a flat cloud
    (all z equal); every value is then 0.
    """
    heights = pts[:, 2]
    lowest = heights.min()
    span = heights.max() - lowest
    return (heights - lowest) / (span + 1e-9)
def base_figure_gray_depth(pts: np.ndarray, point_size: float, camera=DEFAULT_CAM) -> go.Figure:
    """Render the whole cloud in gray shaded by normalised height, plus its bounding box.

    Args:
        pts: ``(N, 3)`` point coordinates.
        point_size: marker size.
        camera: Plotly scene camera dict.

    Returns:
        A ``go.Figure`` with two traces: the semi-transparent point cloud
        (opacity 0.5) and the bounding-box wireframe.
    """
    shade = depth_values(pts)
    cloud = go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode="markers",
        marker=dict(
            size=point_size,
            color=shade,
            colorscale="Greys",
            reversescale=True,
            opacity=0.50,
        ),
        hoverinfo="skip",
    )
    bx, by, bz = scene_bbox(pts)
    box = go.Scatter3d(
        x=bx,
        y=by,
        z=bz,
        mode="lines",
        line=dict(color=BOX_COLOR, width=2),
        hoverinfo="skip",
    )
    fig = go.Figure(data=[cloud, box])
    fig.update_layout(
        scene=dict(aspectmode="data", camera=camera),
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor=BG_COLOR,
        showlegend=False,
    )
    return fig
def highlight_views_3d(pts, view_ids, selected, point_size, camera=DEFAULT_CAM):
    """Render the cloud in height-graded gray with points from ``selected`` views in red.

    Args:
        pts: ``(N, 3)`` point coordinates.
        view_ids: ``(N,)`` view id of each point, aligned with ``pts``.
        selected: view ids to paint pure red; a falsy value highlights nothing.
        point_size: marker size.
        camera: Plotly scene camera dict.

    Returns:
        A ``go.Figure`` with the coloured point cloud and the bounding-box wireframe.
    """
    shade = depth_values(pts)
    rgb = np.stack([shade] * 3, axis=1)  # gray: equal R, G, B from normalised height
    if selected:
        hit = np.isin(view_ids, np.array(selected, dtype=np.int32))
        rgb[hit] = np.array([1, 0, 0])
    # One CSS colour string per point; each channel is scaled to 0-255 and truncated by int().
    marker_colors = [
        f"rgb({int(r * 255)},{int(g * 255)},{int(b * 255)})"
        for r, g, b in rgb.tolist()
    ]
    cloud = go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode="markers",
        marker=dict(size=point_size, color=marker_colors, opacity=0.98),
        hoverinfo="skip",
    )
    bx, by, bz = scene_bbox(pts)
    box = go.Scatter3d(
        x=bx,
        y=by,
        z=bz,
        mode="lines",
        line=dict(color=BOX_COLOR, width=2),
        hoverinfo="skip",
    )
    fig = go.Figure(data=[cloud, box])
    fig.update_layout(
        scene=dict(aspectmode="data", camera=camera),
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor=BG_COLOR,
        showlegend=False,
    )
    return fig
# =========================
# App setup
# =========================
# The model is built once at import time, on the CPU, and the checkpoint is
# loaded into it before the UI is constructed.
device = "cpu"
model = build_model(device)
load_pretrain(model, ckpt_path="assets/ckpt_100.pth")
# Gradio UI: query controls, the 3D plot next to a top-down reference image,
# a status line, and per-session state holding the loaded scene.
with gr.Blocks(
    title="POMA-3D: Text-conditioned 3D Scene Visualization",
    css="#plot3d, #img_ref {height: 450px !important;}"
) as demo:
    gr.Markdown("### POMA-3D: The Point Map Way to 3D Scene Understanding - Embodied Localization Demo\n"
                "Enter agent's situation text and choose **Top-K**; the most relevant views will turn **red**.")
    with gr.Row():
        text_in = gr.Textbox(label="Text query", value="I am sleeping on the bed.", scale=4)
        topk_in = gr.Number(label="Top-K views", value=TOPK_VIEWS_DEFAULT, precision=0, minimum=1, maximum=12)
        submit_btn = gr.Button("Locate", variant="primary")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=500):
            plot3d = gr.Plot(label="3D Point Cloud (rotatable)", elem_id="plot3d")
        with gr.Column(scale=1, min_width=500):
            img_ref = gr.Image(label="Top-Down Reference View", value=TOP_VIEW_IMAGE_PATH, elem_id="img_ref")
    status = gr.Markdown()
    # Session state: the (V,3,H,W) point map plus the downsampled points and
    # their per-point view ids, filled by on_load and read by on_submit.
    pm_state = gr.State(None)
    pts_state = gr.State(None)
    vids_state = gr.State(None)

    # Load scene automatically from LOCAL_FILE_DEFAULT
    def on_load():
        """Load the default scene, downsample it and build the initial gray figure.

        Returns values for (plot3d, img_ref, status, pm_state, pts_state,
        vids_state), in the order wired up in ``demo.load`` below.
        """
        pm = load_scene_local(LOCAL_FILE_DEFAULT)
        pts_all, vids_all = stack_views(pm)
        # Thin the cloud to one point per VOXEL_SIZE voxel, then randomly cap
        # it at DOWNSAMPLE_N_MAX points for rendering.
        pts_vx, vids_vx = voxel_downsample_with_ids(pts_all, vids_all, VOXEL_SIZE)
        pts_vx, vids_vx = hard_cap(pts_vx, vids_vx, DOWNSAMPLE_N_MAX)
        ps = adaptive_point_size(pts_vx.shape[0])
        fig3d = base_figure_gray_depth(pts_vx, ps, camera=DEFAULT_CAM)
        msg = f"✅ Loaded {os.path.basename(LOCAL_FILE_DEFAULT)} | Views: {pm.shape[0]} | Points: {pts_vx.shape[0]:,}"
        return fig3d, TOP_VIEW_IMAGE_PATH, msg, pm, pts_vx, vids_vx

    def on_submit(text, topk, pm, pts_vx, vids_vx):
        """Rank views for ``text`` and redraw the cloud with the top-k views in red.

        Returns values for (plot3d, img_ref, status).
        """
        if pm is None:
            return gr.update(), TOP_VIEW_IMAGE_PATH, "⚠️ Scene not loaded yet."
        # Clamp to [1, 12]; a falsy topk (None or 0) falls back to TOPK_VIEWS_DEFAULT.
        k = int(max(1, min(12, int(topk)))) if topk else TOPK_VIEWS_DEFAULT
        top_views = rank_views_for_text(model, text, pm, device, topk=k)
        ps = adaptive_point_size(pts_vx.shape[0])
        fig = highlight_views_3d(pts_vx, vids_vx, top_views, ps, camera=DEFAULT_CAM)
        msg = f"Highlighted views (top-{k}): {top_views}"
        return fig, TOP_VIEW_IMAGE_PATH, msg

    # Wire events: populate the scene on page load; re-rank on button click.
    demo.load(on_load, inputs=[], outputs=[plot3d, img_ref, status, pm_state, pts_state, vids_state])
    submit_btn.click(on_submit, inputs=[text_in, topk_in, pm_state, pts_state, vids_state],
                     outputs=[plot3d, img_ref, status])
if __name__ == "__main__":
    # Start the Gradio server when run as a script. (The stray trailing " |"
    # on this call was removed: it is a SyntaxError.)
    demo.launch()