Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Simplified Gradio demo for Search-TTA evaluation. | |
| """ | |
| # ββββββββββββββββββββββββββ imports βββββββββββββββββββββββββββββββββββ | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg", force=True) | |
| import gradio as gr | |
| import ctypes # for safely stopping background threads | |
| import os, glob, threading, time | |
| import torch | |
| from PIL import Image | |
| import json | |
| import shutil | |
| import spaces # integration with ZeroGPU on hf | |
| from planner.test_parameter import * | |
| from planner.model import PolicyNet | |
| from planner.test_worker import TestWorker | |
| from taxabind_avs.satbind.clip_seg_tta import ClipSegTTA | |
| # Helper to kill a Python thread by injecting SystemExit | |
| def _stop_thread(thread: threading.Thread): | |
| """Forcefully raise SystemExit in the given thread (best-effort).""" | |
| if thread is None or not thread.is_alive(): | |
| return | |
| tid = thread.ident | |
| if tid is None: | |
| return | |
| # Ask CPython to raise SystemExit in the thread context | |
| res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(SystemExit)) | |
| if res > 1: | |
| # If it returned >1, cleanup and fail safe | |
| ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None) | |
# ──────────── Thread Registry for Cleanup on Tab Switch ─────────────
# Lock used around shared thread bookkeeping.
_running_threads_lock = threading.Lock()
# Module-wide list of worker threads.
_running_threads: list[threading.Thread] = list()
# Worker thread -> the ClipSegTTA instance it drives; lets the UI polling
# loop read that instance's `executing_tta` flag.
_thread_clip_map: dict[threading.Thread, ClipSegTTA] = dict()

# ──────────── Run directory rotation ─────────────
# Maximum number of timestamped run directories retained per instance.
RUN_HISTORY_LIMIT = 30
def _prune_old_run_dirs(base_dir: str, limit: int = RUN_HISTORY_LIMIT):
    """Delete the oldest run directories in *base_dir*, keeping the newest *limit*.

    Sub-directory names are sorted lexicographically and the first ones are
    removed, which matches chronological order for names that start with a
    ``%Y%m%d_%H%M%S`` timestamp. Plain files in *base_dir* are ignored.

    This is housekeeping only. Problems are reported with ``print`` and never
    raised, so a pruning failure cannot abort a run.

    Args:
        base_dir: Directory holding the run directories.
        limit: Number of newest directories to keep; values <= 0 remove all.
    """
    try:
        entries = os.listdir(base_dir)
    except OSError as exc:
        print(f"[prune] could not list {base_dir!r}: {exc}")
        return
    run_dirs = sorted(d for d in entries if os.path.isdir(os.path.join(base_dir, d)))
    # Slice by count instead of ``[:-limit]``: ``[:-0]`` is empty, which would
    # make limit=0 keep every directory instead of none.
    n_obsolete = max(len(run_dirs) - max(limit, 0), 0)
    for obsolete in run_dirs[:n_obsolete]:
        shutil.rmtree(os.path.join(base_dir, obsolete), ignore_errors=True)
# CHANGE ME!
POLL_INTERVAL = 1.0  # seconds between polls of the frame directories (visualization refresh)

# Prepare the model
device = torch.device('cuda') if USE_GPU and torch.cuda.is_available() else torch.device('cpu')
policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
script_dir = Path(__file__).resolve().parent
print("real_script_dir: ", script_dir)
# map_location lets a checkpoint saved from a GPU load on a CPU-only host
# (USE_GPU disabled or CUDA unavailable) instead of failing to deserialize.
checkpoint = torch.load(f'{MODEL_PATH}/{MODEL_NAME}', map_location=device)
policy_net.load_state_dict(checkpoint['policy_model'])
print('Model loaded!')

# Load metadata json (taxonomy name -> entry with "target_positions").
# The context manager closes the file instead of leaking the handle.
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
with open(tgts_metadata_json_path, encoding="utf-8") as _metadata_fh:
    tgts_metadata = json.load(_metadata_fh)
# ────────────────────────── Gradio process fn ─────────────────────────


def _frame_sort_key(path: str) -> tuple:
    """Sort key for frame images named ``<int>.png``.

    Numeric stems sort numerically (``10.png`` after ``9.png``). Any other
    stem sorts after them by name instead of raising ``ValueError``.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return (0, int(stem), "")
    except ValueError:
        return (1, 0, stem)


def _sorted_frames(frames_dir: str) -> list[str]:
    """Return every ``*.png`` in *frames_dir*, ordered by frame index."""
    return sorted(glob.glob(os.path.join(frames_dir, "*.png")), key=_frame_sort_key)


def _frame_is_ready(path: str) -> bool:
    """Heuristic: True when *path* is non-empty and readable (not mid-write)."""
    try:
        if os.path.getsize(path) == 0:
            return False
        with open(path, "rb") as fh:
            fh.read(1)
    except OSError:
        return False
    return True


def _collect_new_frames(frames_dir: str, sent: set[str], last: str | None) -> tuple[str | None, bool]:
    """Add ready, not-yet-sent frames of *frames_dir* to *sent* (in place).

    Returns:
        ``(newest_frame, added_any)``. *newest_frame* is the last newly added
        frame in index order, or *last* when nothing new was ready.
    """
    added_any = False
    for fp in _sorted_frames(frames_dir):
        if fp in sent or not _frame_is_ready(fp):
            continue  # not-ready frames are retried on the next poll
        sent.add(fp)
        last = fp
        added_any = True
    return last, added_any


def _status_text(last_frame: str | None, thread_alive: bool, errored: bool, running_tta: bool = False) -> str:
    """Status line for one planner, based on its error flag, frames and liveness."""
    if errored:
        return "Error!"
    if last_frame is None:
        return "Initializing modelβ¦"
    if not thread_alive:
        return "Done."
    return "Executing TTA (Scheduling GPUs)β¦" if running_tta else "Executing Plannerβ¦"


def _reset_outputs(interactive: bool, status_tta: str, status_no: str, session_threads: list) -> tuple:
    """Build the 10-output tuple that clears both images and frame lists and hides sliders."""
    return (
        gr.update(interactive=interactive),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=status_tta, visible=True),
        gr.update(value=status_no, visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        [],
        [],
        session_threads,
    )


### integration with ZeroGPU on hf
# @spaces.GPU
def process_search_tta(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
    session_threads: list[threading.Thread] | None = None,
):
    """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps.

    Two background threads each build a ``TestWorker`` (one with TTA, one
    without) and run one episode. Each thread writes PNG frames into its own
    per-run directory under ``GIFS_PATH``. This generator polls those
    directories every ``POLL_INTERVAL`` seconds and streams new frames and
    status changes to the UI.

    Args:
        sat_path: Path to the satellite image (required).
        ground_path: Optional path to the ground-level image.
        taxonomy: Optional taxonomy name; used to look up target positions in
            ``tgts_metadata`` and as the species name.
        session_threads: Per-session list; the two worker threads are
            appended so other handlers can stop them, and it is cleared when
            the run ends.

    Yields:
        10-tuples in ``run_btn.click`` output order: run-button update, TTA
        image, no-TTA image, TTA status, no-TTA status, TTA slider, no-TTA
        slider, TTA frame list, no-TTA frame list, session thread list.
    """
    import traceback
    import uuid

    if session_threads is None:
        session_threads = []

    # Disable Run button and clear image/status outputs, hide sliders, clear frame states
    yield _reset_outputs(False, "Initializing modelβ¦", "Initializing modelβ¦", session_threads)

    # Bail early if satellite image missing
    if sat_path is None:
        yield _reset_outputs(True, "No satellite image provided.", "", session_threads)
        return

    # Prepare PIL images. On failure, re-enable the Run button; otherwise it
    # would stay disabled from the first yield above.
    try:
        sat_img = Image.open(sat_path).convert("RGB")
        ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
    except OSError:
        traceback.print_exc()
        yield _reset_outputs(True, "Error!", "Error!", session_threads)
        return

    # Lookup target positions metadata (may be empty)
    tgt_positions = []
    if taxonomy and taxonomy in tgts_metadata:
        tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]

    def build_planner(enable_tta: bool, save_dir: str, clip_obj):
        """Build a TestWorker writing frames to *save_dir*, with or without TTA.

        When LOAD_AVS_BENCH is set and *clip_obj* is None, a fresh ClipSegTTA
        is created. Any ClipSegTTA used is then loaded with this request's inputs.
        """
        local_clip = clip_obj
        if LOAD_AVS_BENCH and local_clip is None:
            local_clip = ClipSegTTA(
                img_dir=AVS_IMG_DIR,
                imo_dir=AVS_IMO_DIR,
                json_path=AVS_INAT_JSON_PATH,
                sat_to_img_ids_path=AVS_SAT_TO_IMG_IDS_PATH,
                sat_checkpoint_path=AVS_SAT_CHECKPOINT_PATH,
                load_pretrained_hf_ckpt=AVS_LOAD_PRETRAINED_HF_CHECKPOINT,
                blur_kernel=AVS_GAUSSIAN_BLUR_KERNEL,
                sample_index=-1,
                device=device,
                sat_to_img_ids_json_is_train_dict=False,
                tax_to_filter_val=QUERY_TAX,
                load_model=USE_CLIP_PREDS,
                query_modality=QUERY_MODALITY,
                sound_dir=AVS_SOUND_DIR,
                sound_checkpoint_path=AVS_SOUND_CHECKPOINT_PATH,
            )
        if local_clip is not None:
            # Feed this request's inputs into the ClipSegTTA instance
            local_clip.img_paths = [ground_path] if ground_path else []
            local_clip.imo_path = sat_path
            local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else [])
            local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
            local_clip.sounds = []
            local_clip.sound_ids = []
            local_clip.species_name = taxonomy or ""
            local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
            local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]
        planner = TestWorker(
            meta_agent_id=0,
            n_agent=1,
            policy_net=policy_net,
            global_step=-1,
            device=device,
            greedy=True,
            save_image=SAVE_GIFS,
            clip_seg_tta=local_clip,
        )
        planner.execute_tta = enable_tta
        planner.gifs_path = save_dir
        return planner

    # ────────────── Per-run output directories ──────────────
    os.makedirs(GIFS_PATH, exist_ok=True)
    # The timestamp prefix keeps lexicographic order chronological for
    # _prune_old_run_dirs. The random suffix stops concurrent sessions that
    # start in the same second from sharing (and mixing) frame directories.
    run_id = f"{time.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    run_root = os.path.join(GIFS_PATH, run_id)
    gifs_dir_tta = os.path.join(run_root, "with_tta")
    gifs_dir_no = os.path.join(run_root, "no_tta")
    os.makedirs(gifs_dir_tta, exist_ok=True)
    os.makedirs(gifs_dir_no, exist_ok=True)
    # House-keep old runs so we never keep more than RUN_HISTORY_LIMIT
    _prune_old_run_dirs(GIFS_PATH, RUN_HISTORY_LIMIT)

    # Set by a worker thread when its planner raises
    error_flags = {"tta": False, "no": False}

    def _planner_thread(enable_tta: bool, save_dir: str, clip_obj, key: str):
        """Thread target: build a planner, run one episode, record any crash under *key*."""
        me = threading.current_thread()
        try:
            planner = build_planner(enable_tta, save_dir, clip_obj)
            with _running_threads_lock:
                _thread_clip_map[me] = planner.clip_seg_tta
            planner.run_episode(0)
        except Exception:
            # Mark the crash for the UI and log the full traceback for developers
            error_flags[key] = True
            traceback.print_exc()
        finally:
            # Drop the map entry so the ClipSegTTA (and its tensors) can be freed
            with _running_threads_lock:
                _thread_clip_map.pop(me, None)

    # Launch both planners in background threads (model preparation included)
    thread_tta = threading.Thread(target=_planner_thread, args=(True, gifs_dir_tta, None, "tta"), daemon=True)
    thread_no = threading.Thread(target=_planner_thread, args=(False, gifs_dir_no, None, "no"), daemon=True)
    # Track threads for this user session
    session_threads.extend([thread_tta, thread_no])
    thread_tta.start()
    thread_no.start()

    sent_tta: set[str] = set()
    sent_no: set[str] = set()
    last_tta: str | None = None
    last_no: str | None = None
    # Statuses last sent to the UI, so a status-only change still triggers a yield
    prev_status_tta = "Initializing modelβ¦"
    prev_status_no = "Initializing modelβ¦"
    # Frame counts last pushed to each slider (None = slider not revealed yet).
    # Re-sending the slider on every poll would snap it back to the last frame
    # while the user scrubs one side and the other side is still running.
    slider_count_tta: int | None = None
    slider_count_no: int | None = None

    try:
        while thread_tta.is_alive() or thread_no.is_alive():
            last_tta, new_tta = _collect_new_frames(gifs_dir_tta, sent_tta, last_tta)
            last_no, new_no = _collect_new_frames(gifs_dir_no, sent_no, last_no)

            tta_alive = thread_tta.is_alive()
            no_alive = thread_no.is_alive()

            clip_obj = _thread_clip_map.get(thread_tta) if tta_alive else None
            exec_tta_flag = clip_obj is not None and bool(getattr(clip_obj, "executing_tta", False))
            status_tta = _status_text(last_tta, tta_alive, error_flags["tta"], exec_tta_flag)
            status_no = _status_text(last_no, no_alive, error_flags["no"], False)

            # Reveal or refresh a slider only once its thread has finished and
            # only when its frame count differs from what was last sent.
            reveal_tta = (not tta_alive) and last_tta is not None and slider_count_tta != len(sent_tta)
            reveal_no = (not no_alive) and last_no is not None and slider_count_no != len(sent_no)

            slider_tta_upd = gr.update()
            slider_no_upd = gr.update()
            frames_tta_upd = gr.update()
            frames_no_upd = gr.update()
            if reveal_tta:
                slider_count_tta = len(sent_tta)
                n_tta_frames = max(slider_count_tta, 1)
                slider_tta_upd = gr.update(visible=True, minimum=1, maximum=n_tta_frames, value=n_tta_frames)
                frames_tta_upd = sorted(sent_tta, key=_frame_sort_key)
            if reveal_no:
                slider_count_no = len(sent_no)
                n_no_frames = max(slider_count_no, 1)
                slider_no_upd = gr.update(visible=True, minimum=1, maximum=n_no_frames, value=n_no_frames)
                frames_no_upd = sorted(sent_no, key=_frame_sort_key)

            # Emit only when a frame, a status or a slider actually changed
            if (
                new_tta
                or new_no
                or reveal_tta
                or reveal_no
                or status_tta != prev_status_tta
                or status_no != prev_status_no
            ):
                yield (
                    gr.update(interactive=False),
                    last_tta if (new_tta or reveal_tta) else gr.update(),
                    last_no if (new_no or reveal_no) else gr.update(),
                    gr.update(value=status_tta, visible=True),
                    gr.update(value=status_no, visible=True),
                    slider_tta_upd,
                    slider_no_upd,
                    frames_tta_upd,
                    frames_no_upd,
                    session_threads,
                )
                prev_status_tta = status_tta
                prev_status_no = status_no
            time.sleep(POLL_INTERVAL)
    finally:
        # Also runs when the generator is closed (e.g. the user cancels):
        # stop any still-running workers.
        for th in (thread_tta, thread_no):
            if th.is_alive():
                _stop_thread(th)
                th.join(timeout=1)
        with _running_threads_lock:
            for th in (thread_tta, thread_no):
                _thread_clip_map.pop(th, None)
            session_threads.clear()

    # Small delay to ensure last frame files are fully flushed
    time.sleep(0.2)
    # Final rescan catches frames written after the last polling iteration
    frames_tta = _sorted_frames(gifs_dir_tta)
    frames_no = _sorted_frames(gifs_dir_no)
    # Show the newest frame so the image agrees with the slider set to its max
    last_tta = frames_tta[-1] if frames_tta else last_tta
    last_no = frames_no[-1] if frames_no else last_no
    n_tta = len(frames_tta) or 1  # prevent zero-range slider
    n_no = len(frames_no) or 1

    # Final emit: re-enable button; hide statuses unless a planner errored;
    # show sliders (when there are frames) set to the last frame.
    yield (
        gr.update(interactive=True),
        last_tta,
        last_no,
        gr.update(value="Error!", visible=True) if error_flags["tta"] else gr.update(visible=False),
        gr.update(value="Error!", visible=True) if error_flags["no"] else gr.update(visible=False),
        gr.update(visible=bool(frames_tta), minimum=1, maximum=n_tta, value=n_tta),
        gr.update(visible=bool(frames_no), minimum=1, maximum=n_no, value=n_no),
        frames_tta,
        frames_no,
        session_threads,
    )
# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        """
# Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the other tab above. <br>
Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. <br>
If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
<a href="https://search-tta.github.io">Project Website</a>
"""
    )
    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Markdown("### Model Inputs")
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            ground_input = gr.Image(
                label="Ground-level Image",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (not used)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )
            run_btn = gr.Button("Run Search-TTA", variant="primary")
        with gr.Column():
            gr.Markdown("### Live Heatmap Output")
            display_img_tta = gr.Image(label="Heatmap (TTA per 20 steps)", type="filepath", height=400)  # 512
            status_tta = gr.Markdown("")
            slider_tta = gr.Slider(label="TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)
            display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400)  # 512
            status_no_tta = gr.Markdown("")
            slider_no = gr.Slider(label="No-TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)

    # Per-session state: frame paths for each slider, and this session's worker threads
    frames_state_tta = gr.State([])
    frames_state_no = gr.State([])
    session_threads_state = gr.State([])

    def _show_frame(idx: int | float | None, frames: list[str]):
        """Return frame *idx* (1-indexed) from *frames*, or a no-op update if out of range."""
        if idx is None or not frames:
            return gr.update()
        # gr.Slider passes its value as a float; list indices must be ints.
        pos = int(idx)
        if 1 <= pos <= len(frames):
            return frames[pos - 1]
        return gr.update()

    # Slider callbacks (update the image when the user drags a slider)
    slider_tta.change(_show_frame, inputs=[slider_tta, frames_state_tta], outputs=display_img_tta)
    slider_no.change(_show_frame, inputs=[slider_no, frames_state_no], outputs=display_img_no_tta)

    # Outputs in the order yielded by process_search_tta (10 values per yield)
    run_outputs = [
        run_btn,
        display_img_tta,
        display_img_no_tta,
        status_tta,
        status_no_tta,
        slider_tta,
        slider_no,
        frames_state_tta,
        frames_state_no,
        session_threads_state,
    ]

    # EXAMPLES
    with gr.Row():
        gr.Markdown("### Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
                    "examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/938aab7b-4509-4de7-afad-2c8ea51f4799.jpg",
                    "Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            # Must list all 10 outputs so the arity matches each yield if fn is
            # ever run here (e.g. when example caching is enabled).
            outputs=run_outputs,
            fn=process_search_tta,
            cache_examples=False,
        )

    run_btn.click(
        fn=process_search_tta,
        inputs=[sat_input, ground_input, taxonomy_input, session_threads_state],
        outputs=run_outputs,
    )

    # Footer to point out to model and data from app page.
    gr.Markdown(
        """
The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
"""
    )
if __name__ == "__main__":
    # Compose both demos as tabs of one app. A tab switch is used to stop this
    # session's Search-TTA worker threads.
    from app_multimodal_inference import demo as multimodal_demo

    with gr.Blocks() as root:
        with gr.Tabs() as tabs:
            with gr.TabItem("Multimodal Inference"):
                multimodal_demo.render()
            with gr.TabItem("Search-TTA"):
                demo.render()

        # Invisible sink component: the select handler needs an output.
        _cleanup_status = gr.Textbox(visible=False)

        def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]):
            """Stop this session's live Search-TTA threads on switching to the multimodal tab."""
            # evt.value is the label of the newly selected tab
            if evt.value != "Multimodal Inference":
                return ""
            live_workers = [w for w in list(session_threads) if w is not None and w.is_alive()]
            for worker in live_workers:
                _stop_thread(worker)
                worker.join(timeout=1)
            session_threads.clear()
            return "Stopped running Search-TTA threads."

        tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=[_cleanup_status])

    root.queue(max_size=15)
    root.launch(share=True)