import argparse
import html
import json
import os
import re
import sys
import threading
from collections import OrderedDict
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import requests
import torch
import torch.nn.functional as F
import faiss
import cv2
import gradio as gr

print("Gradio version:", gr.__version__)

from huggingface_hub import snapshot_download

# ----------------------------
# 0) Config
# ----------------------------
# Where to put downloaded assets inside the container
ASSET_ROOT = Path("/data/teatime")
DATASET_ID = os.getenv("TEATIME_DATASET_ID", "remiii25/teatime_assets")
# NOTE: the dataset is fetched anonymously (see ensure_assets); an HF token
# would only be needed if the dataset were private.

# Hosted LLM router (Groq, OpenAI-compatible chat/completions endpoint)
OA_BASE_URL = os.environ.get("OA_BASE_URL", "https://api.groq.com/openai/v1").rstrip("/")
OA_API_KEY = os.environ.get("OA_API_KEY", "")
OA_MODEL = os.environ.get("OA_MODEL", "llama-3.1-8b-instant")

# ProFuse code location inside Docker image
REPO_ROOT = "/workspace/ProFuse/feature_registration"
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# Project links shown at the top of the page (easy to edit).
LINKS = {
    "github": "https://github.com/chiou1203/ProFuse",
    "project": "https://chiou1203.github.io/ProFuse/",
    "arxiv": "https://arxiv.org/abs/2601.04754",
    "hf_paper": "https://huggingface.co/papers/2601.04754",
}

# Button text for each LINKS entry; unknown keys fall back to the key itself.
_LINK_LABELS = {
    "github": "GitHub",
    "project": "Project Page",
    "arxiv": "arXiv",
    "hf_paper": "HF Paper",
}

# Row of pill-style links, styled by the `.linkbar` rules in CSS below.
# Built from LINKS so the two can never drift apart.
LINK_BAR = '<div class="linkbar">{}</div>'.format(
    "".join(
        f'<a href="{html.escape(url)}" target="_blank" rel="noopener noreferrer">'
        f"{html.escape(_LINK_LABELS.get(key, key))}</a>"
        for key, url in LINKS.items()
    )
)

CSS = """
.linkbar { display:flex; gap:12px; align-items:center; flex-wrap:wrap; margin:10px 0 18px 0; }
.linkbar a {
  padding:10px 16px;          /* <-- bigger button */
  font-size:16px;             /* <-- bigger text */
  border:1px solid rgba(255,255,255,0.18);
  border-radius:999px;
  text-decoration:none;
  line-height:1;              /* keeps height tight/clean */
}
.linkbar a:hover { border-color: rgba(255,255,255,0.40); }
"""

print("RUNNING app.py from:", __file__)
print("DEFAULT DATASET_ID =", "remiii25/teatime_assets")
print("ENV TEATIME_DATASET_ID =", os.getenv("TEATIME_DATASET_ID"))
print("USING DATASET_ID =", DATASET_ID)

# ProFuse imports (resolvable only after REPO_ROOT is on sys.path)
from scene import Scene
from arguments import ModelParams, PipelineParams, get_combined_args
from evaluation.openclip_encoder import OpenCLIPNetwork
from gaussian_renderer import GaussianModel, render

# Extra options object passed as the last argument to `render`;
# include_feature=False presumably turns off feature rendering (TODO confirm).
RENDER_OPT = SimpleNamespace(include_feature=False)


# ----------------------------
# 1) Download assets if missing
# ----------------------------
def ensure_assets():
    """Download the scene dataset snapshot into ASSET_ROOT unless already present.

    The download is skipped only when both ``images/`` and the checkpoint
    ``chkpnt0.pth`` (loaded by ActivationSession) exist. That way an
    interrupted earlier download that left ``images/`` behind is retried
    instead of crashing at startup.
    """
    ASSET_ROOT.mkdir(parents=True, exist_ok=True)

    if (ASSET_ROOT / "images").exists() and (ASSET_ROOT / "chkpnt0.pth").exists():
        return

    # Public dataset: force an anonymous download (token=False) so that a stray
    # HF token in the environment is not used.
    snapshot_download(
        repo_id=DATASET_ID,
        repo_type="dataset",
        local_dir=str(ASSET_ROOT),
        token=False,
    )


ensure_assets()

# Paths used by the session code
SCENE_DIR = str(ASSET_ROOT)
MODEL_DIR = str(ASSET_ROOT)
IMAGES_DIR = str(ASSET_ROOT / "images")

# PQ index ships inside the ProFuse repo
PQ_INDEX = "/workspace/ProFuse/feature_registration/ckpts/pq_index.faiss"


# ----------------------------
# 2) Cached Activation Session
# ----------------------------
class ActivationSession:
    """Hold a loaded ProFuse scene and render per-label selections from it.

    On construction this loads the OpenCLIP encoder, the FAISS PQ index and the
    Gaussian model checkpoint. It then decodes every Gaussian's PQ-coded
    language feature once, so later queries only need one CLIP activation.
    """

    # Maximum number of per-label similarity tensors kept (they live on CUDA
    # and have one row per Gaussian, so the cache must stay bounded).
    SIM_CACHE_MAX = 32

    def __init__(self, args, mdl: ModelParams, pip: PipelineParams):
        self.args = args
        self.mdl = mdl
        self.pip = pip
        # Serialises renders: opacity is mutated in place during a render, and
        # the CLIP encoder's positives are shared state.
        self.lock = threading.Lock()

        self.clip = OpenCLIPNetwork("cuda")
        self.idx = faiss.read_index(args.pq_index)

        self.ds = mdl.extract(args)
        self.pipe = pip.extract(args)
        self.g = GaussianModel(self.ds.sh_degree)
        self.scn = Scene(self.ds, self.g, shuffle=False)

        ckpt = Path(args.model_path) / "chkpnt0.pth"
        state, _ = torch.load(ckpt)
        self.g.restore(state, args, mode="test")

        # Gaussians whose feature row is entirely -1 carry no language feature;
        # they are excluded from decoding and always get similarity 0.
        feats = self.g._language_feature.clone()
        self.zero = torch.all(feats == -1, dim=-1)
        decoded = torch.from_numpy(self.idx.sa_decode(feats[~self.zero].cpu().numpy())).to("cuda")
        self.d_norm = F.normalize(decoded, dim=-1)

        self.train_cams = self.scn.getTrainCameras()
        self.test_cams = self.scn.getTestCameras()
        # LRU cache: (label, rerank, canon_words) -> similarity tensor.
        self._sim_cache = OrderedDict()

    @staticmethod
    def build_args(scene_dir: str, model_dir: str, pq_index: str, white_bg=True,
                   skip_test=True, skip_train=False, rerank=False,
                   canon_words="object,thing,stuff,texture"):
        """Build the ProFuse argument namespace as if parsed from a command line.

        ``get_combined_args`` reads ``sys.argv``, so ``sys.argv`` is swapped
        for the generated argument list and restored afterwards.

        Returns:
            ``(args, mdl, pip)``: the parsed namespace plus the ModelParams and
            PipelineParams groups registered on the parser.
        """
        p = argparse.ArgumentParser()
        mdl = ModelParams(p, sentinel=True)
        pip = PipelineParams(p)
        p.add_argument("--pq_index", required=True)
        p.add_argument("--img_label", required=True)
        p.add_argument("--threshold", type=float, default=0.55)
        p.add_argument("--frames", type=str, default=None)
        p.add_argument("--skip_train", action="store_true")
        p.add_argument("--skip_test", action="store_true")
        p.add_argument("--white_bg", action="store_true")
        p.add_argument("--quiet", action="store_true")
        p.add_argument("--rerank", action="store_true")
        p.add_argument("--canon_words", type=str, default=canon_words)
        p.add_argument("--rerank_tau", type=float, default=3.0)

        arg_list = [
            "-s", scene_dir,
            "-m", model_dir,
            "--pq_index", pq_index,
            "--img_label", "init",
            "--threshold", "0.54",
        ]
        if white_bg:
            arg_list.append("--white_bg")
        if skip_test:
            arg_list.append("--skip_test")
        if skip_train:
            arg_list.append("--skip_train")
        if rerank:
            arg_list.append("--rerank")
        if canon_words:
            arg_list += ["--canon_words", canon_words]

        argv_backup = sys.argv
        try:
            sys.argv = ["session_init"] + arg_list
            args = get_combined_args(p)
        finally:
            sys.argv = argv_backup
        return args, mdl, pip

    def _compute_sim(self, label: str) -> torch.Tensor:
        """Return a ``(num_gaussians, 1)`` CUDA tensor of CLIP activations for ``label``.

        With ``--rerank``, the canonical words are added as extra positives.
        Results are memoised in a bounded LRU cache. Callers should hold
        ``self.lock``, because ``set_positives`` mutates the shared encoder.
        """
        label = label.strip()
        cache_key = (label, bool(self.args.rerank), getattr(self.args, "canon_words", ""))
        cached = self._sim_cache.get(cache_key)
        if cached is not None:
            self._sim_cache.move_to_end(cache_key)
            return cached

        if self.args.rerank:
            canon_terms = [w.strip() for w in self.args.canon_words.split(",") if w.strip()]
            self.clip.set_positives([label] + canon_terms)
        else:
            self.clip.set_positives([label])

        sim = torch.zeros((self.g._language_feature.shape[0], 1), device="cuda")
        sim[~self.zero] = self.clip.get_activation(self.d_norm, 0)

        self._sim_cache[cache_key] = sim
        while len(self._sim_cache) > self.SIM_CACHE_MAX:
            self._sim_cache.popitem(last=False)  # evict least recently used
        return sim

    @torch.no_grad()
    def render_selected_rgb_np(self, label: str, tau: float, frame: str, split: str = "train") -> np.ndarray:
        """Render only the Gaussians whose activation for ``label`` exceeds ``tau``.

        Args:
            label: free-text object query.
            tau: activation threshold; Gaussians with score > tau are kept.
            frame: 1-based camera index as a string, e.g. ``"00001"``.
            split: ``"train"`` selects training cameras; anything else selects
                test cameras.

        Returns:
            An ``(H, W, 3)`` uint8 RGB image.

        Raises:
            IndexError: if ``frame`` is outside ``1..len(cameras)``.
        """
        tau = float(tau)
        frame = str(frame)

        cams = self.train_cams if split == "train" else self.test_cams
        idx = int(frame) - 1
        # Reject out-of-range frames explicitly; a negative index would
        # otherwise silently select a camera from the end of the list.
        if not 0 <= idx < len(cams):
            raise IndexError(f"frame {frame!r} out of range for {split} split ({len(cams)} cameras, 1-based)")
        cam = cams[idx]

        device = self.g._opacity.device
        bg = torch.ones(3, device=device) if self.args.white_bg else torch.zeros(3, device=device)

        with self.lock:
            score = self._compute_sim(label)
            sel = (score.squeeze() > tau).nonzero(as_tuple=True)[0]

            opa = self.g._opacity
            backup = opa.detach().clone()
            mask = torch.zeros_like(opa, dtype=torch.bool, device=opa.device)
            mask[sel] = True

            # If stored opacities already lie in [0, 1], zero hides a Gaussian.
            # Otherwise they are presumably pre-activation logits, and -12 makes
            # them effectively transparent.
            in_unit_range = float(backup.min()) >= -1e-6 and float(backup.max()) <= 1.0 + 1e-6
            hidden_value = 0.0 if in_unit_range else -12.0

            # Always restore opacity, even if rendering fails; the model is
            # shared across requests.
            try:
                opa.data[~mask] = hidden_value
                pkg = render(cam, self.g, self.pipe, bg, RENDER_OPT)
                img = (pkg["render"].detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            finally:
                opa.data.copy_(backup)
        return img


SESSION = None


def get_session():
    """Return the process-wide ActivationSession, creating it on first call."""
    global SESSION
    if SESSION is None:
        args, mdl, pip = ActivationSession.build_args(
            scene_dir=SCENE_DIR,
            model_dir=MODEL_DIR,
            pq_index=PQ_INDEX,
            white_bg=True,
            skip_test=True,
            rerank=False,
        )
        SESSION = ActivationSession(args, mdl, pip)
    return SESSION


# ----------------------------
# 3) View mapping
# ----------------------------
EXTS = [".png", ".jpg", ".jpeg", ".webp"]


def _cam_stem(cam) -> str | None:
    """Return the file stem of the first non-empty string name attribute on ``cam``."""
    for attr in ("image_name", "image_path", "image_file", "name", "uid"):
        if hasattr(cam, attr):
            v = getattr(cam, attr)
            if isinstance(v, str) and v.strip():
                return Path(v).stem
    return None


def _find_image_by_stem(stem: str, images_dir: str) -> str | None:
    """Find an image for ``stem``: exact ``stem+ext`` first, then any name containing it."""
    p = Path(images_dir)
    for ext in EXTS:
        exact = p / f"{stem}{ext}"
        if exact.exists():
            return str(exact)
    for ext in EXTS:
        cand = sorted(p.glob(f"*{stem}*{ext}"), key=lambda x: x.name)
        if cand:
            return str(cand[0])
    return None


def _find_image_by_5digits(fid5: str, images_dir: str) -> str | None:
    """Fallback lookup: first image (by name) whose filename contains ``fid5``."""
    p = Path(images_dir)
    for ext in EXTS:
        cand = sorted(p.glob(f"*{fid5}*{ext}"), key=lambda x: x.name)
        if cand:
            return str(cand[0])
    return None


def build_train_view_mapping(session, images_dir: str):
    """Map each training camera to a dropdown entry and a ground-truth image path.

    Returns:
        ``(choices, gt_map)``. ``choices`` is a list of ``(display, key)``
        tuples, where ``key`` is the 1-based index zero-padded to 5 digits.
        ``gt_map`` maps each key to an image path, or None if none was found.
    """
    choices, gt_map = [], {}
    for i, cam in enumerate(session.train_cams):
        key = f"{i+1:05d}"
        stem = _cam_stem(cam)
        gt = _find_image_by_stem(stem, images_dir) if stem else None
        if gt is None:
            gt = _find_image_by_5digits(key, images_dir)
        disp = key if not stem else f"{key} ({stem})"
        choices.append((disp, key))
        gt_map[key] = gt
    return choices, gt_map


def load_image_rgb(path: str | None) -> np.ndarray | None:
    """Read an image as an RGB uint8 array; return None if missing or unreadable."""
    if not path:
        return None
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# ----------------------------
# 4) Groq router (hosted LLM)
# ----------------------------
ROUTER_SYSTEM = """You are a routing assistant for an open-vocabulary 3D scene query demo.
Extract the target object label the user wants to localize.
Return STRICT JSON only.
Schema:
{"label": string|null, "threshold": number|null, "reply": string}
Rules:
- If user did not specify an object, label=null and ask a short clarifying question in reply.
- If threshold not explicit, threshold=null.
- Output JSON ONLY.
"""

_JSON_DECODER = json.JSONDecoder()


def extract_first_json(text: str) -> dict:
    """Return the first decodable JSON object embedded in ``text``.

    LLM output may surround the JSON with prose or code fences, or contain
    more than one brace group. So each ``{`` is tried in turn as the start of
    a complete object, instead of greedily matching from the first ``{`` to
    the last ``}``.

    Raises:
        ValueError: if no JSON object can be decoded.
    """
    text = text or ""
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = _JSON_DECODER.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    raise ValueError(f"No JSON found. Raw:\n{text[:400]}")


def sanitize_label(label: str) -> str:
    """Normalise an LLM-provided label.

    Lower-cases it, drops a leading article, flattens control whitespace, and
    truncates to 80 characters.
    """
    label = str(label).strip().lower()
    label = re.sub(r"^(a|an|the)\s+", "", label)
    label = re.sub(r"[\r\n\t]+", " ", label)
    return label[:80]


def clamp_tau(x) -> float:
    """Convert ``x`` to float and clamp it to the slider range [0.05, 0.95]."""
    return max(0.05, min(0.95, float(x)))


def route_with_groq(clean_history):
    """Ask the hosted LLM to extract ``{label, threshold, reply}`` from the chat.

    Args:
        clean_history: list of ``{"role", "content"}`` message dicts.

    Raises:
        RuntimeError: if OA_API_KEY is unset or the API returns a non-200 status.
        ValueError: if the reply contains no JSON object.
    """
    if not OA_API_KEY:
        raise RuntimeError("OA_API_KEY is not set in Space secrets.")
    payload = {
        "model": OA_MODEL,
        "messages": [{"role": "system", "content": ROUTER_SYSTEM}] + clean_history,
        "temperature": 0.0,
        "max_completion_tokens": 120,
    }
    headers = {"Authorization": f"Bearer {OA_API_KEY}", "Content-Type": "application/json"}
    r = requests.post(f"{OA_BASE_URL}/chat/completions", json=payload, headers=headers, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"{r.status_code} {r.text[:500]}")
    return extract_first_json(r.json()["choices"][0]["message"]["content"])


# ----------------------------
# 5) Gradio app
# ----------------------------
sess = get_session()
view_choices, GT_MAP = build_train_view_mapping(sess, IMAGES_DIR)


def load_gt_np_for_view(view_key: str):
    """Return the ground-truth RGB image for a dropdown key, or None."""
    return load_image_rgb(GT_MAP.get(view_key))


def chat_and_render(user_text, display_history, llm_history, view_key, tau_slider):
    """Handle one chat turn: route the message through the LLM, then render.

    Returns a 7-tuple matching the Gradio outputs
    ``[chatbot, llm_state, img_gt, img_sel, view_dd, tau_in, msg]``.
    The last element is always ``""`` so the message box is cleared.
    """
    # Work on copies so the incoming Gradio values are never mutated.
    display_history = list(display_history or [])
    llm_history = list(llm_history or [])

    def reply_with(assistant_msg, gt, sel, tau):
        # Record the assistant turn in both histories and build the outputs.
        display_history.append({"role": "assistant", "content": assistant_msg})
        llm_history.append({"role": "assistant", "content": assistant_msg})
        return display_history, llm_history, gt, sel, view_key, tau, ""

    user_text = (user_text or "").strip()
    if not user_text:
        return display_history, llm_history, load_gt_np_for_view(view_key), None, view_key, tau_slider, ""

    display_history.append({"role": "user", "content": user_text})
    llm_history.append({"role": "user", "content": user_text})

    try:
        # Only the last few turns are sent, to keep the router prompt short.
        obj = route_with_groq(llm_history[-6:])
    except Exception as e:
        return reply_with(f"[Router error] {e}", load_gt_np_for_view(view_key), None, tau_slider)

    label = obj.get("label", None)
    thr = obj.get("threshold", None)
    reply = obj.get("reply", "") or ""

    if label is not None:
        label = sanitize_label(label) or None

    tau_used = float(tau_slider)
    if thr is not None:
        try:
            tau_used = clamp_tau(thr)
        except (TypeError, ValueError, OverflowError):
            pass  # non-numeric threshold from the LLM: keep the slider value

    if label is None:
        return reply_with(reply or "Which object should I look for?", load_gt_np_for_view(view_key), None, tau_used)

    try:
        gt = load_gt_np_for_view(view_key)
        sel = sess.render_selected_rgb_np(label=label, tau=tau_used, frame=view_key, split="train")
    except Exception as e:
        return reply_with(f"[Render error] {e}", load_gt_np_for_view(view_key), None, tau_used)
    return reply_with(reply or f"Showing {label} (tau={tau_used:.2f}).", gt, sel, tau_used)


with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Open-Vocabulary 3D Object Selection with ProFuse")
    gr.HTML(LINK_BAR)

    llm_state = gr.State([])

    with gr.Row():
        view_dd = gr.Dropdown(
            choices=view_choices,
            value=view_choices[0][1] if view_choices else None,
            label="View",
        )
        tau_in = gr.Slider(0.05, 0.95, value=0.54, step=0.01, label="Threshold")

    with gr.Row():
        img_gt = gr.Image(type="numpy", label="Original")
        img_sel = gr.Image(type="numpy", label="Selected")

    chatbot = gr.Chatbot(type="messages", label="Chat")
    msg = gr.Textbox(placeholder="Ask: Where is the apple?", label="Message")
    send = gr.Button("Send")

    demo.load(fn=load_gt_np_for_view, inputs=view_dd, outputs=img_gt)
    view_dd.change(fn=load_gt_np_for_view, inputs=view_dd, outputs=img_gt)

    # The Send button and pressing Enter in the textbox share one handler.
    chat_event = dict(
        fn=chat_and_render,
        inputs=[msg, chatbot, llm_state, view_dd, tau_in],
        outputs=[chatbot, llm_state, img_gt, img_sel, view_dd, tau_in, msg],
    )
    send.click(**chat_event)
    msg.submit(**chat_event)

PORT = int(os.environ.get("PORT", "7860"))
demo.queue().launch(server_name="0.0.0.0", server_port=PORT, debug=True)