|
|
# Stdlib + third-party imports for the Gradio demo app.
import os, sys, argparse, threading, json, re, requests

from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn.functional as F
import faiss  # product-quantization index used to decode language features
import cv2  # image I/O for ground-truth views
import gradio as gr

# Startup breadcrumb: confirms which Gradio version the environment resolved.
print("Gradio version:", gr.__version__)

from huggingface_hub import snapshot_download  # dataset fetch at boot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local cache directory for the demo scene assets (images, checkpoints, ...).
ASSET_ROOT = Path("/data/teatime")
# Hugging Face dataset repo holding the assets; overridable via env for forks.
DATASET_ID = os.getenv("TEATIME_DATASET_ID", "remiii25/teatime_assets")

# OpenAI-compatible chat endpoint settings (defaults target Groq).
OA_BASE_URL = os.environ.get("OA_BASE_URL", "https://api.groq.com/openai/v1").rstrip("/")
OA_API_KEY = os.environ.get("OA_API_KEY", "")
OA_MODEL = os.environ.get("OA_MODEL", "llama-3.1-8b-instant")

# Make the ProFuse feature_registration package importable below.
REPO_ROOT = "/workspace/ProFuse/feature_registration"
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# External links rendered in the header link bar.
LINKS = {
    "github": "https://github.com/chiou1203/ProFuse",
    "project": "https://chiou1203.github.io/ProFuse/",
    "arxiv": "https://arxiv.org/abs/2601.04754",
    "hf_paper": "https://huggingface.co/papers/2601.04754",
}
|
|
|
|
|
# Header link-bar markup; the anchors are styled by the .linkbar rules in CSS.
# NOTE(review): the leading "π"/"π€" glyphs look like mojibake of emoji —
# confirm the intended characters before changing the runtime string.
LINK_BAR = f"""
<div class="linkbar">
<a href="{LINKS['github']}" target="_blank" rel="noopener">π GitHub</a>
<a href="{LINKS['project']}" target="_blank" rel="noopener">π Project</a>
<a href="{LINKS['arxiv']}" target="_blank" rel="noopener">π arXiv</a>
<a href="{LINKS['hf_paper']}" target="_blank" rel="noopener">π€ HF Paper</a>
</div>
"""

# Pill-button styling for the link bar above.
CSS = """
.linkbar { display:flex; gap:12px; align-items:center; flex-wrap:wrap; margin:10px 0 18px 0; }
.linkbar a {
padding:10px 16px; /* <-- bigger button */
font-size:16px; /* <-- bigger text */
border:1px solid rgba(255,255,255,0.18);
border-radius:999px;
text-decoration:none;
line-height:1; /* keeps height tight/clean */
}
.linkbar a:hover { border-color: rgba(255,255,255,0.40); }
"""
|
|
|
|
|
|
|
|
# Boot diagnostics: make the effective dataset configuration visible in logs.
print("RUNNING app.py from:", __file__)
print("DEFAULT DATASET_ID =", "remiii25/teatime_assets")
print("ENV TEATIME_DATASET_ID =", os.getenv("TEATIME_DATASET_ID"))
print("USING DATASET_ID =", DATASET_ID)

# ProFuse project imports (resolvable thanks to the sys.path entry above).
from scene import Scene
from arguments import ModelParams, PipelineParams, get_combined_args
from evaluation.openclip_encoder import OpenCLIPNetwork
from gaussian_renderer import GaussianModel, render

# Options namespace handed to render(); include_feature=False presumably
# disables the language-feature output path (RGB-only renders) — confirm
# against gaussian_renderer.render.
RENDER_OPT = SimpleNamespace(include_feature=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_assets():
    """Download the demo dataset into ASSET_ROOT unless it is already there."""
    ASSET_ROOT.mkdir(parents=True, exist_ok=True)

    # An existing images/ directory marks a previously completed download.
    if (ASSET_ROOT / "images").exists():
        return

    snapshot_download(
        repo_id=DATASET_ID,
        repo_type="dataset",
        local_dir=str(ASSET_ROOT),
        token=False,  # public dataset: never send a user token
    )
|
|
|
|
|
# Fetch assets before any of the paths below are consumed.
ensure_assets()

# Scene and model both resolve to the asset root; GT images live in images/.
SCENE_DIR = str(ASSET_ROOT)
MODEL_DIR = str(ASSET_ROOT)
IMAGES_DIR = str(ASSET_ROOT / "images")

# Pre-built FAISS product-quantization index for the language features.
PQ_INDEX = "/workspace/ProFuse/feature_registration/ckpts/pq_index.faiss"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ActivationSession:
    """Owns the GPU-side state for open-vocabulary queries on one scene.

    Wraps the Gaussian model, the CLIP text encoder and the FAISS PQ index,
    and renders RGB images in which only Gaussians whose decoded language
    feature is similar enough to a text label remain visible.
    """

    def __init__(self, args, mdl: ModelParams, pip: PipelineParams):
        self.args = args
        self.mdl = mdl
        self.pip = pip
        # Serializes the opacity-mutation + render critical section
        # (see render_selected_rgb_np).
        self.lock = threading.Lock()

        # Text encoder for label prompts; FAISS index that decodes the
        # product-quantized per-Gaussian language features.
        self.clip = OpenCLIPNetwork("cuda")
        self.idx = faiss.read_index(args.pq_index)

        self.ds = mdl.extract(args)
        self.pipe = pip.extract(args)

        self.g = GaussianModel(self.ds.sh_degree)
        self.scn = Scene(self.ds, self.g, shuffle=False)

        # Restore trained parameters from the checkpoint in the model dir.
        # NOTE(review): no map_location is given — assumes the checkpoint
        # was saved from a CUDA run; confirm if loading on a different device.
        ckpt = Path(args.model_path) / "chkpnt0.pth"
        state, _ = torch.load(ckpt)
        self.g.restore(state, args, mode="test")

        # Features that are -1 in every channel mark Gaussians without a
        # language feature; those are excluded from decoding.
        feats = self.g._language_feature.clone()
        self.zero = torch.all(feats == -1, dim=-1)
        decoded = torch.from_numpy(self.idx.sa_decode(feats[~self.zero].cpu().numpy())).to("cuda")
        # L2-normalized decoded features, ready for cosine-style similarity.
        self.d_norm = F.normalize(decoded, dim=-1)

        self.train_cams = self.scn.getTrainCameras()
        self.test_cams = self.scn.getTestCameras()
        # (label, rerank flag, canon_words) -> per-Gaussian similarity tensor.
        self._sim_cache = {}

    @staticmethod
    def build_args(scene_dir: str, model_dir: str, pq_index: str,
                   white_bg=True, skip_test=True, skip_train=False, rerank=False,
                   canon_words="object,thing,stuff,texture"):
        """Build the combined argparse namespace that ProFuse expects.

        Mimics the CLI entry point: registers ModelParams/PipelineParams and
        the demo-specific flags, then parses a synthetic argv. Returns
        (args, mdl, pip) suitable for the ActivationSession constructor.
        """

        p = argparse.ArgumentParser()
        mdl = ModelParams(p, sentinel=True)
        pip = PipelineParams(p)

        p.add_argument("--pq_index", required=True)
        p.add_argument("--img_label", required=True)
        p.add_argument("--threshold", type=float, default=0.55)
        p.add_argument("--frames", type=str, default=None)
        p.add_argument("--skip_train", action="store_true")
        p.add_argument("--skip_test", action="store_true")
        p.add_argument("--white_bg", action="store_true")
        p.add_argument("--quiet", action="store_true")
        p.add_argument("--rerank", action="store_true")
        p.add_argument("--canon_words", type=str, default=canon_words)
        p.add_argument("--rerank_tau", type=float, default=3.0)

        # Synthetic argv. img_label/threshold here are init placeholders;
        # the UI supplies the real label and threshold per query.
        arg_list = [
            "-s", scene_dir,
            "-m", model_dir,
            "--pq_index", pq_index,
            "--img_label", "init",
            "--threshold", "0.54",
        ]
        if white_bg: arg_list.append("--white_bg")
        if skip_test: arg_list.append("--skip_test")
        if skip_train: arg_list.append("--skip_train")
        if rerank: arg_list.append("--rerank")
        if canon_words:
            arg_list += ["--canon_words", canon_words]

        # get_combined_args reads sys.argv, so swap it in temporarily and
        # always restore the original.
        argv_backup = sys.argv
        try:
            sys.argv = ["session_init"] + arg_list
            args = get_combined_args(p)
        finally:
            sys.argv = argv_backup

        return args, mdl, pip

    def _compute_sim(self, label: str) -> torch.Tensor:
        """Per-Gaussian similarity for *label*, cached per query settings."""
        label = label.strip()
        cache_key = (label, bool(self.args.rerank), getattr(self.args, "canon_words", ""))

        if cache_key in self._sim_cache:
            return self._sim_cache[cache_key]

        if self.args.rerank:
            # Canonical words are added as extra positives — presumably used
            # by the encoder for re-ranking; confirm in OpenCLIPNetwork.
            canon_terms = [w.strip() for w in self.args.canon_words.split(",") if w.strip()]
            self.clip.set_positives([label] + canon_terms)
        else:
            self.clip.set_positives([label])

        # Gaussians without a language feature keep a similarity of 0.
        sim = torch.zeros((self.g._language_feature.shape[0], 1), device="cuda")
        sim[~self.zero] = self.clip.get_activation(self.d_norm, 0)

        self._sim_cache[cache_key] = sim
        return sim

    @torch.no_grad()
    def render_selected_rgb_np(self, label: str, tau: float, frame: str, split: str = "train") -> np.ndarray:
        """Render a view in which only Gaussians matching *label* are visible.

        Args:
            label: Open-vocabulary text query.
            tau: Similarity threshold; Gaussians scoring <= tau are hidden.
            frame: 1-based frame number as a string.
            split: "train" or "test" camera list.

        Returns:
            HxWx3 uint8 RGB image.
        """
        tau = float(tau)
        frame = str(frame)

        score = self._compute_sim(label)
        sel = (score.squeeze() > tau).nonzero(as_tuple=True)[0]

        cams = self.train_cams if split == "train" else self.test_cams
        idx = int(frame) - 1
        cam = cams[idx]

        bg = torch.ones(3, device=self.g._opacity.device) if self.args.white_bg else torch.zeros(3, device=self.g._opacity.device)

        # Temporarily zero the opacity of non-selected Gaussians, render,
        # then restore — guarded by the lock so concurrent requests never
        # observe each other's mutated opacities.
        with self.lock:
            opa = self.g._opacity
            backup = opa.detach().clone()

            mask = torch.zeros_like(opa, dtype=torch.bool, device=opa.device)
            mask[sel] = True
            # If stored opacities look like probabilities in [0, 1], hide
            # with 0.0; otherwise they are presumably logits, so use a large
            # negative value instead.
            opa.data[~mask] = 0.0 if float(backup.min()) >= -1e-6 and float(backup.max()) <= 1.0 + 1e-6 else -12.0

            pkg = render(cam, self.g, self.pipe, bg, RENDER_OPT)
            img = (pkg["render"].detach().clamp(0,1).permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)

            opa.data.copy_(backup)

        return img
|
|
|
|
|
# Process-wide session; built lazily because GPU setup is expensive.
SESSION = None


def get_session():
    """Return the singleton ActivationSession, creating it on first use."""
    global SESSION
    if SESSION is not None:
        return SESSION
    args, mdl, pip = ActivationSession.build_args(
        scene_dir=SCENE_DIR,
        model_dir=MODEL_DIR,
        pq_index=PQ_INDEX,
        white_bg=True,
        skip_test=True,
        rerank=False,
    )
    SESSION = ActivationSession(args, mdl, pip)
    return SESSION
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Image extensions probed when resolving a camera's ground-truth image.
EXTS = [".png", ".jpg", ".jpeg", ".webp"]
|
|
|
|
|
def _cam_stem(cam) -> str | None: |
|
|
for attr in ("image_name", "image_path", "image_file", "name", "uid"): |
|
|
if hasattr(cam, attr): |
|
|
v = getattr(cam, attr) |
|
|
if isinstance(v, str) and v.strip(): |
|
|
return Path(v).stem |
|
|
return None |
|
|
|
|
|
def _find_image_by_stem(stem: str, images_dir: str) -> str | None: |
|
|
p = Path(images_dir) |
|
|
for ext in EXTS: |
|
|
exact = p / f"{stem}{ext}" |
|
|
if exact.exists(): |
|
|
return str(exact) |
|
|
for ext in EXTS: |
|
|
cand = sorted(p.glob(f"*{stem}*{ext}"), key=lambda x: x.name) |
|
|
if cand: |
|
|
return str(cand[0]) |
|
|
return None |
|
|
|
|
|
def _find_image_by_5digits(fid5: str, images_dir: str) -> str | None: |
|
|
p = Path(images_dir) |
|
|
for ext in EXTS: |
|
|
cand = sorted(p.glob(f"*{fid5}*{ext}"), key=lambda x: x.name) |
|
|
if cand: |
|
|
return str(cand[0]) |
|
|
return None |
|
|
|
|
|
def build_train_view_mapping(session, images_dir: str):
    """Build dropdown choices and a ground-truth image map for train views.

    Returns:
        (choices, gt_map) where choices is a list of (display_label, key)
        pairs for the view dropdown and gt_map maps each zero-padded
        1-based key to a ground-truth image path (or None).
    """
    choices = []
    gt_map = {}
    for position, cam in enumerate(session.train_cams, start=1):
        key = f"{position:05d}"
        stem = _cam_stem(cam)
        # Prefer lookup by the camera's own filename stem; fall back to the
        # positional 5-digit key.
        gt = _find_image_by_stem(stem, images_dir) if stem else None
        if gt is None:
            gt = _find_image_by_5digits(key, images_dir)
        display = f"{key} ({stem})" if stem else key
        choices.append((display, key))
        gt_map[key] = gt
    return choices, gt_map
|
|
|
|
|
def load_image_rgb(path: str | None) -> np.ndarray | None: |
|
|
if not path: |
|
|
return None |
|
|
bgr = cv2.imread(path, cv2.IMREAD_COLOR) |
|
|
if bgr is None: |
|
|
return None |
|
|
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for the LLM router: instructs the model to extract the target
# object label (and optional threshold) as strict JSON, which
# extract_first_json() then parses.
ROUTER_SYSTEM = """You are a routing assistant for an open-vocabulary 3D scene query demo.
Extract the target object label the user wants to localize.
Return STRICT JSON only.

Schema:
{"label": string|null, "threshold": number|null, "reply": string}

Rules:
- If user did not specify an object, label=null and ask a short clarifying question in reply.
- If threshold not explicit, threshold=null.
- Output JSON ONLY.
"""
|
|
|
|
|
def extract_first_json(text: str) -> dict:
    """Parse the outermost {...} span found in *text* as JSON.

    Raises:
        ValueError: when no brace-delimited span is present.
    """
    span = re.search(r"\{.*\}", text, flags=re.S)
    if span is None:
        raise ValueError(f"No JSON found. Raw:\n{text[:400]}")
    return json.loads(span.group(0))
|
|
|
|
|
def sanitize_label(label: str) -> str:
    """Normalize a router-provided label.

    Lowercases, strips a leading English article, collapses control
    whitespace to spaces, and caps the length at 80 characters.
    """
    text = str(label).strip().lower()
    text = re.sub(r"^(a|an|the)\s+", "", text)
    text = re.sub(r"[\r\n\t]+", " ", text)
    return text[:80]
|
|
|
|
|
def clamp_tau(x) -> float:
    """Coerce *x* to float and clamp it into the valid range [0.05, 0.95]."""
    value = float(x)
    if value < 0.05:
        return 0.05
    if value > 0.95:
        return 0.95
    return value
|
|
|
|
|
def route_with_groq(clean_history):
    """Send the chat history to the OpenAI-compatible router endpoint.

    Returns the router's strict-JSON payload parsed into a dict.

    Raises:
        RuntimeError: when the API key is missing or the HTTP call fails.
    """
    if not OA_API_KEY:
        raise RuntimeError("OA_API_KEY is not set in Space secrets.")
    headers = {"Authorization": f"Bearer {OA_API_KEY}", "Content-Type": "application/json"}
    body = {
        "model": OA_MODEL,
        "messages": [{"role":"system","content":ROUTER_SYSTEM}] + clean_history,
        "temperature": 0.0,
        "max_completion_tokens": 120,
    }
    resp = requests.post(f"{OA_BASE_URL}/chat/completions", json=body, headers=headers, timeout=60)
    if resp.status_code != 200:
        raise RuntimeError(f"{resp.status_code} {resp.text[:500]}")
    content = resp.json()["choices"][0]["message"]["content"]
    return extract_first_json(content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Eagerly build the GPU session and the view -> ground-truth mapping at
# import time so the first user request does not pay the load cost.
sess = get_session()
view_choices, GT_MAP = build_train_view_mapping(sess, IMAGES_DIR)
|
|
|
|
|
def load_gt_np_for_view(view_key: str):
    """Ground-truth RGB image for *view_key*, or None when unmapped."""
    gt_path = GT_MAP.get(view_key)
    return load_image_rgb(gt_path)
|
|
|
|
|
def chat_and_render(user_text, display_history, llm_history, view_key, tau_slider):
    """Handle one chat turn: route via LLM, then render the selected object.

    Args:
        user_text: Raw textbox content.
        display_history: Chat messages shown in the Chatbot widget
            (OpenAI-style dicts).
        llm_history: Parallel message list forwarded to the router.
        view_key: Zero-padded 1-based train-view key (e.g. "00001").
        tau_slider: Current similarity-threshold slider value.

    Returns:
        Tuple matching the Gradio outputs: (display_history, llm_history,
        gt_image, selected_image, view_key, threshold, cleared_textbox).
    """
    if display_history is None:
        display_history = []
    if llm_history is None:
        llm_history = []

    user_text = (user_text or "").strip()
    if not user_text:
        # Empty submission: keep all state, just refresh the GT image.
        return display_history, llm_history, load_gt_np_for_view(view_key), None, view_key, tau_slider, ""

    display_history = display_history + [{"role":"user","content": user_text}]
    llm_history = llm_history + [{"role":"user","content": user_text}]

    try:
        # Only the most recent turns are routed to keep the prompt small.
        obj = route_with_groq(llm_history[-6:])
    except Exception as e:
        msg = f"[Router error] {e}"
        display_history += [{"role":"assistant","content": msg}]
        llm_history += [{"role":"assistant","content": msg}]
        return display_history, llm_history, load_gt_np_for_view(view_key), None, view_key, tau_slider, ""

    label = obj.get("label", None)
    thr = obj.get("threshold", None)
    reply = obj.get("reply", "") or ""

    if label is not None:
        label = sanitize_label(label)
        if not label:
            label = None

    # The slider supplies the default threshold; an explicit value from the
    # router overrides it when it parses as a float.
    tau_used = float(tau_slider)
    if thr is not None:
        try:
            tau_used = clamp_tau(thr)
        except (TypeError, ValueError):
            # Router returned a non-numeric threshold; keep the slider value.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            pass

    if label is None:
        assistant_msg = reply or "Which object should I look for?"
        display_history += [{"role":"assistant","content": assistant_msg}]
        llm_history += [{"role":"assistant","content": assistant_msg}]
        return display_history, llm_history, load_gt_np_for_view(view_key), None, view_key, tau_used, ""

    try:
        gt = load_gt_np_for_view(view_key)
        sel = sess.render_selected_rgb_np(label=label, tau=tau_used, frame=view_key, split="train")
        assistant_msg = reply or f"Showing {label} (tau={tau_used:.2f})."
        display_history += [{"role":"assistant","content": assistant_msg}]
        llm_history += [{"role":"assistant","content": assistant_msg}]
        return display_history, llm_history, gt, sel, view_key, tau_used, ""
    except Exception as e:
        assistant_msg = f"[Render error] {e}"
        display_history += [{"role":"assistant","content": assistant_msg}]
        llm_history += [{"role":"assistant","content": assistant_msg}]
        return display_history, llm_history, load_gt_np_for_view(view_key), None, view_key, tau_used, ""
|
|
|
|
|
|
|
|
# ---- Gradio UI wiring ----
# Fix: CSS was defined above (link-bar pill styling) but never passed to
# Blocks, so the .linkbar rules were silently dropped.
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Open-Vocabulary 3D Object Selection with ProFuse")
    gr.HTML(LINK_BAR)

    # Hidden state: raw message list forwarded to the LLM router.
    llm_state = gr.State([])

    with gr.Row():
        view_dd = gr.Dropdown(choices=view_choices, value=view_choices[0][1], label="View")
        tau_in = gr.Slider(0.05, 0.95, value=0.54, step=0.01, label="Threshold")

    with gr.Row():
        img_gt = gr.Image(type="numpy", label="Original")
        img_sel = gr.Image(type="numpy", label="Selected")

    chatbot = gr.Chatbot(type="messages", label="Chat")

    msg = gr.Textbox(placeholder="Ask: Where is the apple?", label="Message")
    send = gr.Button("Send")

    # Show the ground-truth image on load and whenever the view changes.
    demo.load(fn=load_gt_np_for_view, inputs=view_dd, outputs=img_gt)
    view_dd.change(fn=load_gt_np_for_view, inputs=view_dd, outputs=img_gt)

    # The Send button and Enter in the textbox trigger the same handler.
    send.click(
        fn=chat_and_render,
        inputs=[msg, chatbot, llm_state, view_dd, tau_in],
        outputs=[chatbot, llm_state, img_gt, img_sel, view_dd, tau_in, msg],
    )
    msg.submit(
        fn=chat_and_render,
        inputs=[msg, chatbot, llm_state, view_dd, tau_in],
        outputs=[chatbot, llm_state, img_gt, img_sel, view_dd, tau_in, msg],
    )
|
|
|
|
|
# Hosting platforms inject PORT; default to Gradio's standard 7860 locally.
PORT = int(os.environ.get("PORT", "7860"))
# queue() serializes requests, which matters because rendering mutates
# shared GPU state (guarded again by ActivationSession.lock).
demo.queue().launch(server_name="0.0.0.0", server_port=PORT, debug=True)
|
|
|