# remiii25's picture
# Update app.py
# e9ad0c1 verified
import os, sys, argparse, threading, json, re, requests
from pathlib import Path
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn.functional as F
import faiss
import cv2
import gradio as gr
print("Gradio version:", gr.__version__)
from huggingface_hub import snapshot_download
# ----------------------------
# 0) Config
# ----------------------------
# Where to put downloaded assets inside the container
ASSET_ROOT = Path("/data/teatime")
# HF dataset repo that holds the demo scene assets; overridable via env.
DATASET_ID = os.getenv("TEATIME_DATASET_ID", "remiii25/teatime_assets")
#HF_TOKEN = os.getenv("HF_TOKEN", None) # only needed if dataset is private
# Hosted LLM router (Groq) -- an OpenAI-compatible chat-completions endpoint.
OA_BASE_URL = os.environ.get("OA_BASE_URL", "https://api.groq.com/openai/v1").rstrip("/")
OA_API_KEY = os.environ.get("OA_API_KEY", "")  # supplied as a Space secret
OA_MODEL = os.environ.get("OA_MODEL", "llama-3.1-8b-instant")
# ProFuse code location inside Docker image; added to sys.path so the
# `from scene import ...` imports below resolve.
REPO_ROOT = "/workspace/ProFuse/feature_registration"
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)
# Put these near the top of app.py (easy to edit)
LINKS = {
    "github": "https://github.com/chiou1203/ProFuse",
    "project": "https://chiou1203.github.io/ProFuse/",
    "arxiv": "https://arxiv.org/abs/2601.04754",
    "hf_paper": "https://huggingface.co/papers/2601.04754",
}
LINK_BAR = f"""
<div class="linkbar">
<a href="{LINKS['github']}" target="_blank" rel="noopener">🐙 GitHub</a>
<a href="{LINKS['project']}" target="_blank" rel="noopener">🌐 Project</a>
<a href="{LINKS['arxiv']}" target="_blank" rel="noopener">📄 arXiv</a>
<a href="{LINKS['hf_paper']}" target="_blank" rel="noopener">🤗 HF Paper</a>
</div>
"""
# Stylesheet for the link bar; intended to be handed to gr.Blocks(css=...).
CSS = """
.linkbar { display:flex; gap:12px; align-items:center; flex-wrap:wrap; margin:10px 0 18px 0; }
.linkbar a {
padding:10px 16px; /* <-- bigger button */
font-size:16px; /* <-- bigger text */
border:1px solid rgba(255,255,255,0.18);
border-radius:999px;
text-decoration:none;
line-height:1; /* keeps height tight/clean */
}
.linkbar a:hover { border-color: rgba(255,255,255,0.40); }
"""
# Startup diagnostics: confirm which dataset id is actually in use.
print("RUNNING app.py from:", __file__)
print("DEFAULT DATASET_ID =", "remiii25/teatime_assets")
print("ENV TEATIME_DATASET_ID =", os.getenv("TEATIME_DATASET_ID"))
print("USING DATASET_ID =", DATASET_ID)
# ProFuse imports (resolved via the REPO_ROOT sys.path entry above)
from scene import Scene
from arguments import ModelParams, PipelineParams, get_combined_args
from evaluation.openclip_encoder import OpenCLIPNetwork
from gaussian_renderer import GaussianModel, render
# Options namespace passed to render(); include_feature=False presumably
# skips the language-feature channel and renders RGB only -- confirm in repo.
RENDER_OPT = SimpleNamespace(include_feature=False)
# ----------------------------
# 1) Download assets if missing
# ----------------------------
def ensure_assets():
    """Download the demo asset dataset into ASSET_ROOT on first run.

    The presence of an ``images`` sub-directory is treated as the
    "already downloaded" marker, so repeated calls are cheap no-ops.
    """
    ASSET_ROOT.mkdir(parents=True, exist_ok=True)
    if (ASSET_ROOT / "images").exists():
        # Assets were materialised by a previous run; nothing to do.
        return
    # Pull the public dataset snapshot. token=False forces an anonymous
    # request even when a HF token happens to be configured in the env.
    snapshot_download(
        repo_id=DATASET_ID,
        repo_type="dataset",
        local_dir=str(ASSET_ROOT),
        token=False,
    )
# Materialise the dataset before any of the paths below are consumed.
ensure_assets()
# Paths used by your session code
SCENE_DIR = str(ASSET_ROOT)               # scene source dir (-s)
MODEL_DIR = str(ASSET_ROOT)               # trained model dir (-m)
IMAGES_DIR = str(ASSET_ROOT / "images")   # ground-truth RGB frames
# PQ index is inside your ProFuse repo (as in Colab)
PQ_INDEX = "/workspace/ProFuse/feature_registration/ckpts/pq_index.faiss"
# ----------------------------
# 2) Cached Activation Session (your code)
# ----------------------------
class ActivationSession:
    """Shared query session over a loaded ProFuse scene.

    Holds the CLIP text encoder, the FAISS PQ index used to decode the
    per-Gaussian compressed language features, the restored Gaussian
    model, and the train/test camera lists. One instance serves all
    Gradio requests; renders are serialised with ``self.lock`` because
    they temporarily mutate the shared Gaussian opacities.
    """

    def __init__(self, args, mdl: ModelParams, pip: PipelineParams):
        """Load scene, checkpoint, and decoded language features to CUDA.

        args: namespace produced by ``build_args``.
        mdl/pip: the parameter groups used to extract dataset/pipeline
        sub-configs from ``args``.
        """
        self.args = args
        self.mdl = mdl
        self.pip = pip
        # Guards the opacity swap performed in render_selected_rgb_np.
        self.lock = threading.Lock()
        self.clip = OpenCLIPNetwork("cuda")
        # Product-quantisation codebook for decoding language features.
        self.idx = faiss.read_index(args.pq_index)
        self.ds = mdl.extract(args)
        self.pipe = pip.extract(args)
        self.g = GaussianModel(self.ds.sh_degree)
        self.scn = Scene(self.ds, self.g, shuffle=False)
        ckpt = Path(args.model_path) / "chkpnt0.pth"
        state, _ = torch.load(ckpt)
        self.g.restore(state, args, mode="test")
        feats = self.g._language_feature.clone()
        # Gaussians whose feature vector is all -1 carry no language code;
        # they are excluded from decoding and always score 0 similarity.
        self.zero = torch.all(feats == -1, dim=-1)
        decoded = torch.from_numpy(self.idx.sa_decode(feats[~self.zero].cpu().numpy())).to("cuda")
        self.d_norm = F.normalize(decoded, dim=-1)
        self.train_cams = self.scn.getTrainCameras()
        self.test_cams = self.scn.getTestCameras()
        # (label, rerank, canon_words) -> per-Gaussian similarity tensor.
        self._sim_cache = {}

    @staticmethod
    def build_args(scene_dir: str, model_dir: str, pq_index: str,
                   white_bg=True, skip_test=True, skip_train=False, rerank=False,
                   canon_words="object,thing,stuff,texture"):
        """Build the argparse namespace ProFuse expects, without a real CLI.

        Reconstructs the parser the offline scripts use, then feeds it a
        synthetic argv (temporarily swapping sys.argv, because
        get_combined_args reads sys.argv directly).
        Returns (args, mdl, pip).
        """
        p = argparse.ArgumentParser()
        mdl = ModelParams(p, sentinel=True)
        pip = PipelineParams(p)
        p.add_argument("--pq_index", required=True)
        p.add_argument("--img_label", required=True)
        p.add_argument("--threshold", type=float, default=0.55)
        p.add_argument("--frames", type=str, default=None)
        p.add_argument("--skip_train", action="store_true")
        p.add_argument("--skip_test", action="store_true")
        p.add_argument("--white_bg", action="store_true")
        p.add_argument("--quiet", action="store_true")
        p.add_argument("--rerank", action="store_true")
        p.add_argument("--canon_words", type=str, default=canon_words)
        p.add_argument("--rerank_tau", type=float, default=3.0)
        # NOTE(review): threshold is pinned to "0.54" here even though the
        # parser default above is 0.55 -- presumably intentional; confirm.
        arg_list = [
            "-s", scene_dir,
            "-m", model_dir,
            "--pq_index", pq_index,
            "--img_label", "init",
            "--threshold", "0.54",
        ]
        if white_bg: arg_list.append("--white_bg")
        if skip_test: arg_list.append("--skip_test")
        if skip_train: arg_list.append("--skip_train")
        if rerank: arg_list.append("--rerank")
        if canon_words:
            arg_list += ["--canon_words", canon_words]
        argv_backup = sys.argv
        try:
            sys.argv = ["session_init"] + arg_list
            args = get_combined_args(p)
        finally:
            # Always restore the real argv, even if parsing raises.
            sys.argv = argv_backup
        return args, mdl, pip

    def _compute_sim(self, label: str) -> torch.Tensor:
        """Return per-Gaussian CLIP similarity for *label*, with caching.

        Result has shape (num_gaussians, 1); Gaussians without a language
        code (self.zero) keep a score of 0.
        """
        label = label.strip()
        cache_key = (label, bool(self.args.rerank), getattr(self.args, "canon_words", ""))
        if cache_key in self._sim_cache:
            return self._sim_cache[cache_key]
        if self.args.rerank:
            # Rerank mode scores the label against canonical distractors.
            canon_terms = [w.strip() for w in self.args.canon_words.split(",") if w.strip()]
            self.clip.set_positives([label] + canon_terms)
        else:
            self.clip.set_positives([label])
        sim = torch.zeros((self.g._language_feature.shape[0], 1), device="cuda")
        sim[~self.zero] = self.clip.get_activation(self.d_norm, 0)
        self._sim_cache[cache_key] = sim
        return sim

    @torch.no_grad()
    def render_selected_rgb_np(self, label: str, tau: float, frame: str, split: str = "train") -> np.ndarray:
        """Render view *frame* showing only Gaussians matching *label*.

        frame: 1-based, zero-padded view id ("00001", ...).
        tau: similarity threshold; Gaussians scoring <= tau are hidden.
        Returns an HxWx3 uint8 RGB array. Non-selected Gaussians are hidden
        by temporarily overwriting their opacity under self.lock; the
        original opacities are restored before returning.
        """
        tau = float(tau)
        frame = str(frame)
        score = self._compute_sim(label)
        sel = (score.squeeze() > tau).nonzero(as_tuple=True)[0]
        cams = self.train_cams if split == "train" else self.test_cams
        idx = int(frame) - 1
        cam = cams[idx]
        bg = torch.ones(3, device=self.g._opacity.device) if self.args.white_bg else torch.zeros(3, device=self.g._opacity.device)
        with self.lock:
            opa = self.g._opacity
            backup = opa.detach().clone()
            mask = torch.zeros_like(opa, dtype=torch.bool, device=opa.device)
            mask[sel] = True
            # Opacity may be stored activated (values in [0, 1]) or as raw
            # pre-sigmoid logits; choose the "invisible" fill accordingly
            # (0.0 for activated values, a large negative logit otherwise).
            opa.data[~mask] = 0.0 if float(backup.min()) >= -1e-6 and float(backup.max()) <= 1.0 + 1e-6 else -12.0
            pkg = render(cam, self.g, self.pipe, bg, RENDER_OPT)
            img = (pkg["render"].detach().clamp(0,1).permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
            opa.data.copy_(backup)
        return img
SESSION = None  # lazily-built process-wide ActivationSession singleton
def get_session():
    """Return the shared ActivationSession, constructing it on first use."""
    global SESSION
    if SESSION is not None:
        return SESSION
    args, mdl, pip = ActivationSession.build_args(
        scene_dir=SCENE_DIR,
        model_dir=MODEL_DIR,
        pq_index=PQ_INDEX,
        white_bg=True,
        skip_test=True,
        rerank=False,
    )
    SESSION = ActivationSession(args, mdl, pip)
    return SESSION
# ----------------------------
# 3) View mapping (your code)
# ----------------------------
EXTS = [".png", ".jpg", ".jpeg", ".webp"]
def _cam_stem(cam) -> str | None:
for attr in ("image_name", "image_path", "image_file", "name", "uid"):
if hasattr(cam, attr):
v = getattr(cam, attr)
if isinstance(v, str) and v.strip():
return Path(v).stem
return None
def _find_image_by_stem(stem: str, images_dir: str) -> str | None:
p = Path(images_dir)
for ext in EXTS:
exact = p / f"{stem}{ext}"
if exact.exists():
return str(exact)
for ext in EXTS:
cand = sorted(p.glob(f"*{stem}*{ext}"), key=lambda x: x.name)
if cand:
return str(cand[0])
return None
def _find_image_by_5digits(fid5: str, images_dir: str) -> str | None:
p = Path(images_dir)
for ext in EXTS:
cand = sorted(p.glob(f"*{fid5}*{ext}"), key=lambda x: x.name)
if cand:
return str(cand[0])
return None
def build_train_view_mapping(session, images_dir: str):
    """Build the dropdown choices and view-key -> ground-truth-path map.

    Every training camera i gets a 1-based, zero-padded key ("00001", ...).
    The ground-truth image is looked up first by the camera's filename
    stem, then by the padded key as a filename substring; a missing image
    maps to None.
    Returns (choices, gt_map) where choices is a list of (display, key)
    pairs suitable for gr.Dropdown.
    """
    choices = []
    gt_map = {}
    for index, cam in enumerate(session.train_cams, start=1):
        key = f"{index:05d}"
        stem = _cam_stem(cam)
        gt_path = _find_image_by_stem(stem, images_dir) if stem else None
        if gt_path is None:
            gt_path = _find_image_by_5digits(key, images_dir)
        display = f"{key} ({stem})" if stem else key
        choices.append((display, key))
        gt_map[key] = gt_path
    return choices, gt_map
def load_image_rgb(path: str | None) -> np.ndarray | None:
    """Read an image from disk and return it as an RGB array.

    Returns None for a falsy path or when OpenCV cannot decode the file
    (cv2.imread signals failure by returning None instead of raising).
    """
    if not path:
        return None
    bgr_img = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr_img is None:
        return None
    # OpenCV loads BGR; Gradio expects RGB.
    return cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
# ----------------------------
# 4) Groq router (hosted LLM)
# ----------------------------
# System prompt for the hosted router: constrains the model to emit one
# strict-JSON object so the reply can be parsed mechanically downstream.
ROUTER_SYSTEM = """You are a routing assistant for an open-vocabulary 3D scene query demo.
Extract the target object label the user wants to localize.
Return STRICT JSON only.
Schema:
{"label": string|null, "threshold": number|null, "reply": string}
Rules:
- If user did not specify an object, label=null and ask a short clarifying question in reply.
- If threshold not explicit, threshold=null.
- Output JSON ONLY.
"""
def extract_first_json(text: str) -> dict:
    """Return the first JSON object embedded anywhere in *text*.

    The previous greedy regex (r"\\{.*\\}") captured from the first "{" to
    the LAST "}" in the text, so any brace after the object (a second JSON
    blob, stray prose) made json.loads fail. Instead, scan each "{" and let
    json.JSONDecoder.raw_decode find a complete, valid object.

    Raises ValueError (with a truncated excerpt of *text*) when no JSON
    object can be decoded.
    """
    decoder = json.JSONDecoder()
    for m in re.finditer(r"\{", text):
        try:
            obj, _ = decoder.raw_decode(text, m.start())
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    raise ValueError(f"No JSON found. Raw:\n{text[:400]}")
def sanitize_label(label: str) -> str:
    """Normalise an LLM-extracted object label for CLIP querying.

    Lower-cases and trims the text, removes a leading English article,
    collapses embedded newlines/tabs into single spaces, and caps the
    result at 80 characters.
    """
    cleaned = str(label).strip().lower()
    cleaned = re.sub(r"^(a|an|the)\s+", "", cleaned)
    cleaned = re.sub(r"[\r\n\t]+", " ", cleaned)
    return cleaned[:80]
def clamp_tau(x) -> float:
    """Clamp a threshold into the slider's valid [0.05, 0.95] range.

    Accepts anything float() accepts (e.g. numeric strings); raises
    TypeError/ValueError for values float() rejects.
    """
    value = float(x)
    if value < 0.05:
        return 0.05
    if value > 0.95:
        return 0.95
    return value
def route_with_groq(clean_history):
    """Route the recent chat history through the hosted Groq LLM.

    Prepends ROUTER_SYSTEM to *clean_history*, posts to the
    OpenAI-compatible chat-completions endpoint, and returns the first
    JSON object parsed from the model's reply.
    Raises RuntimeError when the API key is missing or the HTTP call
    does not return 200.
    """
    if not OA_API_KEY:
        raise RuntimeError("OA_API_KEY is not set in Space secrets.")
    messages = [{"role": "system", "content": ROUTER_SYSTEM}]
    messages.extend(clean_history)
    body = {
        "model": OA_MODEL,
        "messages": messages,
        "temperature": 0.0,
        "max_completion_tokens": 120,
    }
    auth_headers = {
        "Authorization": f"Bearer {OA_API_KEY}",
        "Content-Type": "application/json",
    }
    resp = requests.post(
        f"{OA_BASE_URL}/chat/completions",
        json=body,
        headers=auth_headers,
        timeout=60,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"{resp.status_code} {resp.text[:500]}")
    content = resp.json()["choices"][0]["message"]["content"]
    return extract_first_json(content)
# ----------------------------
# 5) Gradio app
# ----------------------------
# Build the (heavy) session and the view-dropdown data once at import time,
# so the first user request does not pay the model-loading cost.
sess = get_session()
view_choices, GT_MAP = build_train_view_mapping(sess, IMAGES_DIR)
def load_gt_np_for_view(view_key: str):
    """Load the ground-truth RGB image mapped to *view_key*, or None."""
    gt_path = GT_MAP.get(view_key)
    return load_image_rgb(gt_path)
def chat_and_render(user_text, display_history, llm_history, view_key, tau_slider):
    """Handle one chat turn: route the message to a label, then render.

    Fixes over the previous version: the bare ``except:`` around
    clamp_tau (which swallowed every exception, including
    KeyboardInterrupt) is narrowed to the errors float() can raise, the
    triplicated history-append/return boilerplate is factored into one
    closure, and dead commented-out code is removed.

    Parameters
    ----------
    user_text : str
        Raw message from the textbox.
    display_history / llm_history : list[dict] | None
        Chatbot-visible and router-visible histories of
        ``{"role": ..., "content": ...}`` messages.
    view_key : str
        Zero-padded train-view id ("00001", ...).
    tau_slider : float
        Current slider threshold; used unless the router supplies one.

    Returns
    -------
    tuple
        (display_history, llm_history, gt_image, selected_image,
         view_key, tau, cleared_textbox) matching the Gradio outputs.
    """
    display_history = list(display_history or [])
    llm_history = list(llm_history or [])
    user_text = (user_text or "").strip()
    if not user_text:
        # Nothing typed: just refresh the current ground-truth view.
        return display_history, llm_history, load_gt_np_for_view(view_key), None, view_key, tau_slider, ""

    display_history.append({"role": "user", "content": user_text})
    llm_history.append({"role": "user", "content": user_text})

    def respond(message, gt, sel, tau):
        # Append the assistant turn to both histories and build the
        # 7-tuple the Gradio outputs expect ("" clears the textbox).
        display_history.append({"role": "assistant", "content": message})
        llm_history.append({"role": "assistant", "content": message})
        return display_history, llm_history, gt, sel, view_key, tau, ""

    try:
        # Only the most recent turns are sent to keep the prompt short.
        obj = route_with_groq(llm_history[-6:])
    except Exception as e:
        return respond(f"[Router error] {e}", load_gt_np_for_view(view_key), None, tau_slider)

    label = obj.get("label", None)
    thr = obj.get("threshold", None)
    reply = obj.get("reply", "") or ""

    if label is not None:
        label = sanitize_label(label)
        if not label:
            # Sanitising reduced the label to nothing; treat as missing.
            label = None

    tau_used = float(tau_slider)
    if thr is not None:
        try:
            tau_used = clamp_tau(thr)
        except (TypeError, ValueError):
            # Router returned a non-numeric threshold; keep slider value.
            pass

    if label is None:
        return respond(reply or "Which object should I look for?", load_gt_np_for_view(view_key), None, tau_used)

    try:
        gt = load_gt_np_for_view(view_key)
        sel = sess.render_selected_rgb_np(label=label, tau=tau_used, frame=view_key, split="train")
        return respond(reply or f"Showing {label} (tau={tau_used:.2f}).", gt, sel, tau_used)
    except Exception as e:
        return respond(f"[Render error] {e}", load_gt_np_for_view(view_key), None, tau_used)
# Fix: the CSS constant defined at the top of the file was never handed to
# gr.Blocks, so the .linkbar styling silently never applied.
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Open-Vocabulary 3D Object Selection with ProFuse")
    gr.HTML(LINK_BAR)
    # Raw role/content history forwarded to the LLM router (not rendered).
    llm_state = gr.State([])
    with gr.Row():
        view_dd = gr.Dropdown(choices=view_choices, value=view_choices[0][1], label="View")
        tau_in = gr.Slider(0.05, 0.95, value=0.54, step=0.01, label="Threshold")
    with gr.Row():
        img_gt = gr.Image(type="numpy", label="Original")
        img_sel = gr.Image(type="numpy", label="Selected")
    chatbot = gr.Chatbot(type="messages", label="Chat")
    msg = gr.Textbox(placeholder="Ask: Where is the apple?", label="Message")
    send = gr.Button("Send")
    # Show the ground-truth image for the initial / newly selected view.
    demo.load(fn=load_gt_np_for_view, inputs=view_dd, outputs=img_gt)
    view_dd.change(fn=load_gt_np_for_view, inputs=view_dd, outputs=img_gt)
    # The Send button and pressing Enter both run the same chat handler.
    send.click(
        fn=chat_and_render,
        inputs=[msg, chatbot, llm_state, view_dd, tau_in],
        outputs=[chatbot, llm_state, img_gt, img_sel, view_dd, tau_in, msg],
    )
    msg.submit(
        fn=chat_and_render,
        inputs=[msg, chatbot, llm_state, view_dd, tau_in],
        outputs=[chatbot, llm_state, img_gt, img_sel, view_dd, tau_in, msg],
    )
# Hugging Face Spaces provides PORT; fall back to Gradio's standard 7860.
PORT = int(os.environ.get("PORT", "7860"))
demo.queue().launch(server_name="0.0.0.0", server_port=PORT, debug=True)