#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Gradio demo for LocateAnything: boxes and points on images and videos."""
import gradio as gr
import cv2
import numpy as np
import os
import tempfile
import re
import time
import base64
import gc
import io
import json
import uuid
from pathlib import Path

import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModel, AutoTokenizer
from huggingface_hub import CommitScheduler
import spaces

_FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf")


def _get_first_env(*names):
    """Return the first non-blank environment value among ``names`` (stripped), else None."""
    for name in names:
        value = os.environ.get(name)
        if value and value.strip():
            return value.strip()
    return None


def _configure_hf_auth():
    """Resolve the Hub tokens used for model download and for dataset logging.

    Each token prefers its own variable, then the other one, then the generic
    HF token variables. If either token is found, it is exported under all
    generic HF variable names so libraries pick it up implicitly.

    Returns:
        ``(model_token, log_token)``; either may be None.
    """
    model_token = _get_first_env(
        "MODEL_HF_TOKEN",
        "LOG_HF_TOKEN",
        "HF_TOKEN",
        "HUGGINGFACE_HUB_TOKEN",
        "HUGGINGFACEHUB_API_TOKEN",
    )
    log_token = _get_first_env(
        "LOG_HF_TOKEN",
        "MODEL_HF_TOKEN",
        "HF_TOKEN",
        "HUGGINGFACE_HUB_TOKEN",
        "HUGGINGFACEHUB_API_TOKEN",
    )
    shared_token = model_token or log_token
    if shared_token:
        for name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN"):
            os.environ[name] = shared_token
    return model_token, log_token


MODEL_HF_TOKEN, LOG_HF_TOKEN = _configure_hf_auth()


def _load_font(size=20):
    """Load the bundled label font, falling back to DejaVu and then PIL's default font."""
    if os.path.exists(_FONT_PATH):
        try:
            return ImageFont.truetype(_FONT_PATH, size)
        except Exception:
            pass  # Corrupt/unreadable font file: fall through to the system fallbacks.
    try:
        return ImageFont.truetype("DejaVuSans-Bold.ttf", size)
    except Exception:
        return ImageFont.load_default()


# ============================================================
# Color / Parsing / Rendering Operations
# ============================================================

def get_color_for_label(label):
    """Map a label to a stable RGB colour (the same label always gets the same colour)."""
    colors = [
        (8, 145, 178),
        (220, 38, 38),
        (22, 163, 74),
        (37, 99, 235),
        (217, 119, 6),
        (147, 51, 234),
    ]
    idx = sum(ord(c) for c in label)
    return colors[idx % len(colors)]


# NOTE(review): the recovered copy of this file had every tag-shaped literal
# stripped (leaving `split("")`, `startswith("")` and empty regexes behind).
# The values below are reconstructed from how the parser uses them -- confirm
# them against the model's prompt / output format.
_CATEGORY_SEP = "<sep>"
_REF_OR_BOX_PATTERN = r"<ref>(.*?)</ref>|<box>(.*?)</box>"
_BOX_PATTERN = r"<box>(.*?)</box>"
# Coordinate tokens look like "<123>" or "<12.5>"; values are on a 0-1000 grid.
_COORD_PATTERN = r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>"


def _split_categories(category_str):
    """Split a category string on commas or the model separator; lowercase and drop blanks."""
    pattern = r",|" + re.escape(_CATEGORY_SEP)
    return [c.strip().lower() for c in re.split(pattern, category_str or "") if c.strip()]


def _coords_to_result(coords, label):
    """Build a box (4 coords) or point (2 coords) result dict; None for any other arity."""
    if len(coords) == 4:
        return {"type": "box", "coords": coords, "label": label}
    if len(coords) == 2:
        return {"type": "point", "coords": coords, "label": label}
    return None


def parse_mixed_results(text, category_str=""):
    """Parse model output into box/point results with labels.

    The structured form ``<ref>label</ref><box><x1><y1><x2><y2></box>`` is tried
    first. A box takes the label of the most recent ``<ref>``, otherwise the
    first expected category, otherwise ``"object"``. If no box with
    coordinates is found that way, a fallback scans bare ``<box>`` spans and
    labels each one with the first expected category that appears in the text
    immediately preceding it.

    Args:
        text: raw model output (None is treated as empty).
        category_str: expected categories, separated by commas or the model separator.

    Returns:
        List of ``{"type": "box"|"point", "coords": [...], "label": str}`` with
        coordinates still on the model's 0-1000 grid.
    """
    results = []
    text = text or ""
    expected_cats = _split_categories(category_str)
    default_label = expected_cats[0] if expected_cats else "object"

    current_label = None
    found_structured = False
    for m in re.finditer(_REF_OR_BOX_PATTERN, text, flags=re.IGNORECASE | re.DOTALL):
        ref_content, box_content = m.group(1), m.group(2)
        if ref_content is not None:
            label_raw = re.sub(r"</?ref>", "", ref_content, flags=re.IGNORECASE).strip()
            if label_raw:
                current_label = label_raw
            continue
        coords = [float(n) for n in re.findall(_COORD_PATTERN, box_content)]
        if not coords:
            continue
        label = current_label if current_label is not None else default_label
        item = _coords_to_result(coords, label)
        if item is not None:
            results.append(item)
        # Any box carrying coordinates counts as structured output, even with an odd arity.
        found_structured = True
    if found_structured:
        return results

    # Fallback: with one capture group, re.split alternates [text, box, text, box, ...].
    parts = re.split(_BOX_PATTERN, text)
    for i in range(1, len(parts), 2):
        preceding_text = parts[i - 1].lower()
        label = next((cat for cat in expected_cats if cat in preceding_text), default_label)
        coords = [float(n) for n in re.findall(_COORD_PATTERN, parts[i])]
        item = _coords_to_result(coords, label)
        if item is not None:
            results.append(item)
    return results


def resize_image_short_side(image, short_side_size):
    """Resize a PIL image so its shorter side equals ``short_side_size`` (bilinear).

    The aspect ratio is kept (the long side is truncated to int). Small images
    are upscaled as well.

    Returns:
        ``(resized_image, scale_factor)``.
    """
    w, h = image.size
    if w <= h:
        new_w = short_side_size
        scale_factor = new_w / w
        new_h = int(h * scale_factor)
    else:
        new_h = short_side_size
        scale_factor = new_h / h
        new_w = int(w * scale_factor)
    resized_image = image.resize((new_w, new_h), Image.BILINEAR)
    return resized_image, scale_factor


def draw_on_frame(frame_bgr, results, draw_label=True):
    """Draw boxes and points (plus optional label tags) onto a BGR frame.

    Args:
        frame_bgr: BGR uint8 image as produced by OpenCV.
        results: dicts from :func:`parse_mixed_results`, using 0-1000 normalised
            coordinates. Alternatively a dict may carry ``"is_pixel"``, in which
            case ``coords`` is pixel ``[x, y, w, h]``.
        draw_label: whether to draw a filled label tag next to each item.

    Returns:
        A new BGR ndarray. Boxes are drawn with a semi-transparent fill.
    """
    pil_img = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    img_draw = pil_img.convert("RGBA")
    overlay = Image.new("RGBA", img_draw.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)
    font = _load_font(20)
    w_img, h_img = pil_img.size

    parsed = []
    for res in results:
        label = res.get("label", "object")
        color = get_color_for_label(label)
        if res.get("type") == "point":
            c = res["coords"]
            cx = max(0, min(w_img, c[0] * w_img / 1000))
            cy = max(0, min(h_img, c[1] * h_img / 1000))
            parsed.append(("point", label, color, cx, cy))
            continue
        if "is_pixel" in res:
            x1, y1, bw, bh = res["coords"]
            x2, y2 = x1 + bw, y1 + bh
        else:
            c = res["coords"]
            if len(c) < 4:
                continue
            x1 = c[0] * w_img / 1000
            y1 = c[1] * h_img / 1000
            x2 = c[2] * w_img / 1000
            y2 = c[3] * h_img / 1000
        # Clamp every edge into the image so nothing is drawn off-canvas, then
        # order the corners (PIL rejects rectangles with x1 < x0).
        x1, x2 = (max(0, min(w_img, v)) for v in (x1, x2))
        y1, y2 = (max(0, min(h_img, v)) for v in (y1, y2))
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        parsed.append(("box", label, color, x1, y1, x2, y2))

    # Shapes first, labels second, so a tag is never covered by another item's fill.
    for item in parsed:
        if item[0] == "box":
            _, _, color, x1, y1, x2, y2 = item
            fill_color = color + (65,)
            draw.rectangle([x1, y1, x2, y2], fill=fill_color, outline=color, width=4)
        elif item[0] == "point":
            _, _, color, cx, cy = item
            r = 10
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=color, outline="white", width=2)

    if draw_label:
        for item in parsed:
            if item[0] == "box":
                _, label, color, x1, y1, x2, y2 = item
                if not label:
                    continue
                t_box = draw.textbbox((0, 0), label, font=font)
                th = t_box[3] - t_box[1]
                tw = t_box[2] - t_box[0]
                pad_x, pad_y = 7, 4
                tag_h = th + pad_y * 2
                tag_w = tw + pad_x * 2
                tag_y = y1 - tag_h - 2
                if tag_y < 0:
                    tag_y = y2 + 2  # No room above the box: put the tag below it.
                draw.rectangle([x1, tag_y, x1 + tag_w, tag_y + tag_h], fill=color)
                draw.text((x1 + pad_x, tag_y + pad_y), label, fill="white", font=font)
            elif item[0] == "point":
                _, label, color, cx, cy = item
                if not label:
                    continue
                t_box = draw.textbbox((0, 0), label, font=font)
                th, tw = t_box[3] - t_box[1], t_box[2] - t_box[0]
                tx, ty = cx + 14, cy - th // 2
                draw.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color)
                draw.text((tx + 2, ty), label, fill="white", font=font)

    combined = Image.alpha_composite(img_draw, overlay).convert("RGB")
    return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR)
import html  # escapes model output before it is embedded into gr.HTML markup


# ============================================================
# Model Runner Component
# ============================================================

# NOTE(review): tag-shaped literals were stripped from the recovered copy of
# this file (prompts showed "carpedestrian"). The separator that category
# lists are joined with is reconstructed here as a best guess -- confirm it
# against the model's prompt format.
_PROMPT_CATEGORY_SEP = "<sep>"


class EagleWorker:
    """Load the LocateAnything checkpoint once and run single-image generation."""

    def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
        """Load tokenizer, processor and bf16 model weights (remote code is trusted).

        Args:
            model_path: Hub repo id or local path.
            device: device the model and inputs are moved to.
            generation_mode: default decoding mode passed to ``model.generate``
                when a call does not override it.
        """
        self.model_id = model_path
        self.device = device
        self.dtype = torch.bfloat16
        self.generation_mode = generation_mode
        self.hf_token = MODEL_HF_TOKEN
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=self.hf_token,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=self.hf_token,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=self.dtype,
            _attn_implementation="sdpa",
            trust_remote_code=True,
            token=self.hf_token,
        ).to(device).eval()
        print("Model Engine Loaded Safely.")

    def build_messages(self, image, categories, question_override=None):
        """Build a one-turn chat message: the image plus either the override text or a locate prompt."""
        if question_override is not None:
            user_text = question_override
        else:
            category_set_str = _PROMPT_CATEGORY_SEP.join(categories)
            user_text = f"Locate all the instances that matches the following description: {category_set_str}."
        return [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_text},
        ]}]

    def generate(self, image, categories, generation_mode=None, max_new_tokens=4096, temp=0.7, top_p=0.9,
                 top_k=50, question_override=None):
        """Run sampled generation for one image.

        Returns:
            ``(output_text, token_sequence, out_info)``. When the model returns a
            plain value instead of a tuple, that value is ``output_text``, the
            sequence is ``[]`` and the info is ``""``. ``token_sequence`` is
            presumably a list of ``(decode_type, token_text)`` pairs, which is
            how the trace renderer consumes it.
        """
        messages = self.build_messages(image, categories, question_override=question_override)
        text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)
        mode = generation_mode if generation_mode is not None else self.generation_mode

        # NOTE(review): `top_k` is accepted (and exposed in the UI) but is not
        # forwarded; forward it only after confirming the remote generate()
        # accepts it. The UI also allows temp=0.0 with do_sample=True, and
        # whether the remote generate() tolerates that is not visible here.
        with torch.inference_mode():
            result = self.model.generate(
                pixel_values=inputs["pixel_values"].to(self.dtype),
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                image_grid_hws=inputs.get("image_grid_hws", None),
                tokenizer=self.tokenizer,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                generation_mode=mode,
                temperature=temp,
                do_sample=True,
                top_p=top_p,
                repetition_penalty=1.1,
                verbose=True,
            )

        if isinstance(result, tuple) and len(result) >= 3:
            # Slice: a longer tuple would otherwise break the 3-way unpack.
            output_text, token_sequence, out_info = result[:3]
            token_sequence = list(token_sequence or [])
            if mode == "slow" and token_sequence:
                # Kept from the original: in slow mode the final trace entry is
                # relabelled "ar" (reason not visible here).
                token_sequence[-1] = ("ar", token_sequence[-1][1])
        else:
            output_text, token_sequence, out_info = result, [], ""
        return output_text, token_sequence, out_info


# ============================================================
# Post-Processing UI Helpers
# ============================================================

def _postprocess_detections(detections, w, h):
    """Convert normalised boxes to clamped pixel ``[x, y, w, h]`` boxes flagged ``is_pixel``.

    Boxes that collapse to zero width or height are dropped. Points are passed
    through unchanged, still on the normalised grid.
    """
    valid = []
    for det in detections:
        if det["type"] == "box":
            c = det["coords"]
            rx1 = max(0, min(w - 1, int(c[0] * w / 1000)))
            ry1 = max(0, min(h - 1, int(c[1] * h / 1000)))
            rx2 = max(0, min(w - 1, int(c[2] * w / 1000)))
            ry2 = max(0, min(h - 1, int(c[3] * h / 1000)))
            box_w, box_h = rx2 - rx1, ry2 - ry1
            if box_w <= 0 or box_h <= 0:
                continue
            valid.append({"type": "box", "coords": [rx1, ry1, box_w, box_h], "is_pixel": True, "label": det["label"]})
        elif det["type"] == "point":
            valid.append(det)
    return valid


def _parse_out_info_dict(out_info: str) -> dict:
    """Parse ``"Static Info, k=v; k=v"`` (the "Stastic" misspelling is tolerated) into a dict."""
    stats = {}
    if not out_info:
        return stats
    cleaned = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip())
    for part in cleaned.split(";"):
        part = part.strip()
        if "=" in part:
            k, v = part.split("=", 1)
            stats[k.strip()] = v.strip()
    return stats


def _notice_html(message):
    """Return a small warning card for the trace panel; ``message`` is HTML-escaped."""
    return (
        '<div style="padding:12px;border-radius:8px;background:#fef2f2;color:#b91c1c;">'
        f"{html.escape(message)}</div>"
    )


def generate_dynamic_html(token_sequence, out_info, raw_text):
    """Render the decoding-trace panel.

    The panel has a header with an MTP/AR legend and one fading-in chip per
    decoded token, coloured by decode type. It also shows an optional stats
    line and a collapsible raw response. All model-provided text is
    HTML-escaped before it is embedded.

    NOTE(review): the original markup and CSS were lost in the recovered copy
    (tags stripped). This is a minimal reconstruction with the same structure,
    class names and animation delays.
    """
    # Per-render id so the CSS of several panels on one page cannot collide.
    uid = f"a{int(time.time() * 1000)}"
    css = f"""
<style>
.trace-{uid} {{ font-family: ui-sans-serif, system-ui, sans-serif; border: 1px solid #e5e7eb; border-radius: 10px; padding: 12px; }}
.trace-{uid} .trace-head {{ display: flex; justify-content: space-between; align-items: center; font-weight: 700; margin-bottom: 8px; }}
.trace-{uid} .trace-legend span {{ display: inline-block; margin-left: 6px; padding: 1px 8px; border-radius: 6px; font-size: 12px; color: #fff; }}
.tk-mtp-{uid}, .tk-ar-{uid} {{ display: inline-block; margin: 2px 1px; padding: 1px 5px; border-radius: 5px; font-family: ui-monospace, monospace; font-size: 12px; color: #fff; opacity: 0; animation: fade-{uid} 0.25s ease forwards; }}
.tk-mtp-{uid} {{ background: #0891b2; }}
.tk-ar-{uid} {{ background: #d97706; }}
.stats-{uid} {{ margin-top: 10px; font-size: 13px; opacity: 0; animation: fade-{uid} 0.3s ease forwards; }}
.trace-{uid} details {{ margin-top: 10px; }}
.trace-{uid} pre {{ white-space: pre-wrap; word-break: break-word; font-size: 12px; }}
@keyframes fade-{uid} {{ from {{ opacity: 0; }} to {{ opacity: 1; }} }}
</style>
"""
    parts = [css, f'<div class="trace-{uid}">']
    parts.append(
        '<div class="trace-head">'
        '<span>LocateAnything Decoding Trace</span>'
        '<span class="trace-legend">'
        '<span style="background:#0891b2">MTP</span>'
        '<span style="background:#d97706">AR</span>'
        '</span>'
        '</div>'
    )

    parts.append('<div class="trace-tokens">')
    tok_idx = 0
    for item in token_sequence or []:
        if not isinstance(item, (list, tuple)) or len(item) < 2:
            continue
        decode_type = str(item[0]).lower()
        safe = html.escape(str(item[1]))
        delay = f"{tok_idx * 0.04:.2f}s"  # Staggered fade-in, one step per token.
        cls = f"tk-ar-{uid}" if decode_type == "ar" else f"tk-mtp-{uid}"
        parts.append(f'<span class="{cls}" style="animation-delay:{delay}">{safe}</span> ')
        tok_idx += 1
    parts.append('</div>')

    if out_info:
        info_text = str(out_info)
        stats = _parse_out_info_dict(info_text)
        keys = (("forward_step", "steps"), ("num_tokens", "tokens"), ("num_boxes", "boxes"),
                ("ar_step", "AR steps"), ("tps", "tok/s"))
        bits = [f"{stats[key]} {name}" for key, name in keys if key in stats]
        summary = " · ".join(bits) if bits else info_text.strip()
        stat_delay = f"{tok_idx * 0.04 + 0.2:.2f}s"  # Appear just after the last token chip.
        parts.append(
            f'<div class="stats-{uid}" style="animation-delay:{stat_delay}">⚡ {html.escape(summary)}</div>'
        )

    if raw_text:
        safe_raw = html.escape(str(raw_text))
        parts.append(f'<details><summary>📄 Show Raw Response</summary><pre>{safe_raw}</pre></details>')

    parts.append('</div>')
    return "".join(parts)


def generate_raw_prompt(task_type, category):
    """Build the prompt sent to the model for a UI task and a comma-separated category string."""
    if not category:
        category = "objects"
    cats = _PROMPT_CATEGORY_SEP.join(c.strip() for c in category.split(",") if c.strip())
    if task_type == "Detection":
        return f"Locate all the instances that matches the following description: {cats}."
    elif task_type == "Grounding":
        return f"Locate all the instances that match the following description: {cats}."
    elif task_type == "OCR":
        return "Detect all the text in box format."
    elif task_type == "GUI":
        return f"Locate the region that matches the following description: {cats}."
    elif task_type == "Pointing":
        return f"Point to: {cats}."
    return f"Locate all the instances that matches the following description: {cats}."


# ============================================================
# Dynamic Model Safety Initialization
# ============================================================

MODEL_PATH = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B")
print(f"Loading Base Weight Layer Model Matrix via: {MODEL_PATH}")
GLOBAL_WORKER = EagleWorker(MODEL_PATH)

LOG_DATASET_REPO = os.environ.get("LOG_DATASET_REPO")
_LOG_DIR = Path(tempfile.mkdtemp(prefix="hf_log_"))
_SESSION_ID = uuid.uuid4().hex[:8]
_log_scheduler = None
# Remote logging is opt-in: it needs both a target dataset repo and a token.
if LOG_DATASET_REPO and LOG_HF_TOKEN:
    try:
        _log_scheduler = CommitScheduler(
            repo_id=LOG_DATASET_REPO,
            repo_type="dataset",
            folder_path=str(_LOG_DIR),
            path_in_repo="data",
            every=5,
            token=LOG_HF_TOKEN,
            squash_history=False,
        )
        print(f"[LOG] System Scheduler initialized successfully context workspace mapping tracking.")
    except Exception as e:
        print(f"[LOG] Remote logging skipped or unauthorized setup boundary: {e}")


def _pil_to_b64(pil_img):
    """Encode a PIL image as base64 PNG text."""
    buf = io.BytesIO()
    pil_img.save(buf, "PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


def _log_to_dataset(input_type, category, model_mode, raw_prompt, output_text="", input_image=None,
                    output_image=None):
    """Best-effort: write one JSONL record into the folder the CommitScheduler uploads.

    This is a no-op when logging is disabled. Failures are printed, never raised,
    so logging cannot break inference.
    """
    if _log_scheduler is None:
        return
    try:
        entry_id = f"{int(time.time())}_{uuid.uuid4().hex[:6]}"
        record = {
            "id": entry_id,
            "session_id": _SESSION_ID,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "input_type": input_type,
            "category": category,
            "model_mode": model_mode,
            "raw_prompt": raw_prompt,
            "output_text": output_text,
            "input_image_b64": _pil_to_b64(input_image) if input_image is not None else None,
            "output_image_b64": _pil_to_b64(output_image) if output_image is not None else None,
        }
        day_dir = _LOG_DIR / time.strftime("%Y-%m-%d", time.gmtime())
        day_dir.mkdir(parents=True, exist_ok=True)
        # Hold the scheduler lock so a background commit never sees a half-written file.
        with _log_scheduler.lock:
            with open(day_dir / f"{_SESSION_ID}__{entry_id}.jsonl", "w", encoding="utf-8") as f:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception as e:
        print(f"[LOG] Write failure: {e}")


def _prepare_image_for_model(pil_img, short_size):
    """Copy and resize an image for the model; the short side is capped at 1024.

    A positive ``short_size`` resizes the short side to ``min(short_size, 1024)``
    (this can upscale). ``0``/None keeps native size unless the short side
    exceeds 1024.
    """
    process_img = pil_img.copy()
    if short_size and int(short_size) > 0:
        process_img, _ = resize_image_short_side(process_img, min(int(short_size), 1024))
    else:
        if min(process_img.size) > 1024:
            process_img, _ = resize_image_short_side(process_img, 1024)
    return process_img


def _evenly_spaced_indices(total, count):
    """Return up to ``count`` indices spread evenly over ``range(total)``, including both ends.

    A single sample uses the middle frame; the spacing formula would divide by zero there.
    """
    if total <= count:
        return list(range(total))
    if count == 1:
        return [total // 2]
    return [int(round(i * (total - 1) / (count - 1))) for i in range(count)]


def _read_video_frames(video_path):
    """Decode every frame of ``video_path``; return ``(frames_bgr, fps)``. The capture is always released."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        cap.release()
    return frames, fps


def _reserve_temp_path(suffix):
    """Create an empty temp file and return its path (a race-free replacement for ``tempfile.mktemp``)."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _transcode_to_h264(src_path, dst_path):
    """Re-encode ``src_path`` to browser-playable H.264 at ``dst_path`` using ffmpeg.

    Returns True on success. Returns False when ffmpeg is missing, exits
    non-zero or produces an empty file; the reason is printed.
    """
    import subprocess

    cmd = ["ffmpeg", "-y", "-i", src_path, "-c:v", "libx264", "-preset", "ultrafast", "-crf", "23",
           "-pix_fmt", "yuv420p", dst_path]
    try:
        proc = subprocess.run(cmd, capture_output=True)
    except FileNotFoundError:
        print("[VIDEO] ffmpeg not found; returning the mp4v encode instead.")
        return False
    if proc.returncode != 0:
        print(f"[VIDEO] ffmpeg failed ({proc.returncode}): {proc.stderr.decode('utf-8', 'replace')[-500:]}")
        return False
    return os.path.exists(dst_path) and os.path.getsize(dst_path) > 0


# ============================================================
# Spaces GPU Wrapper Decorators
# ============================================================

@spaces.GPU(duration=45)
def _run_image_inference(image_in, categories_list, category_str, model_mode, temp, top_p, top_k, short_size,
                         question_override):
    """Run one image through the model and return Gradio updates for (image, video, trace)."""
    if image_in is None:
        return gr.update(value=None, visible=True), gr.update(value=None, visible=False), _notice_html(
            "⚠️ Upload image.")
    if image_in.mode != "RGB":
        # RGBA / palette / grayscale uploads would break the RGB->BGR conversion below.
        image_in = image_in.convert("RGB")
    process_img = _prepare_image_for_model(image_in, short_size)
    output_text, token_sequence, out_info = GLOBAL_WORKER.generate(
        process_img, categories_list, model_mode,
        temp=temp, top_p=top_p, top_k=top_k, question_override=question_override,
    )
    detections = parse_mixed_results(output_text, category_str)
    # Coordinates are normalised, so drawing on the full-resolution input is fine.
    frame_bgr = cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR)
    out_img_bgr = draw_on_frame(frame_bgr, detections, draw_label=True)
    output_image = Image.fromarray(cv2.cvtColor(out_img_bgr, cv2.COLOR_BGR2RGB))
    _log_to_dataset("image", ", ".join(categories_list), model_mode, question_override or category_str,
                    output_text, image_in, output_image)
    return (gr.update(value=output_image, visible=True), gr.update(value=None, visible=False),
            generate_dynamic_html(token_sequence, out_info, output_text))


@spaces.GPU(duration=180)
def _run_video_inference(video_in, categories_list, category_str, model_mode, temp, top_p, top_k, short_size,
                         question_override, max_video_frames):
    """Annotate up to ``max_video_frames`` evenly sampled frames and return them as a short video.

    Returns Gradio updates for (image, video, trace).
    """
    if video_in is None:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True), _notice_html(
            "⚠️ Upload video.")

    all_frames, fps = _read_video_frames(video_in)
    total = len(all_frames)
    if total == 0:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True), _notice_html(
            "⚠️ Could not read any frames from the video.")
    # Size the writer from the decoded frames: container metadata can be 0 or
    # disagree with the real frame size, and mismatched frames get dropped.
    vid_h, vid_w = all_frames[0].shape[:2]

    max_frames = max(1, int(max_video_frames)) if max_video_frames else 4
    sampled_frames = [all_frames[i] for i in _evenly_spaced_indices(total, max_frames)]
    # Choose the output fps so the annotated clip lasts about as long as the source.
    duration = total / fps if fps > 0 else 0.0
    out_fps = max(1.0, len(sampled_frames) / duration) if duration > 0 else 5.0
    del all_frames
    gc.collect()

    inference_results = []
    for frame in sampled_frames:
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        process_img = _prepare_image_for_model(pil_img, short_size)
        output_text, _, _ = GLOBAL_WORKER.generate(process_img, categories_list, model_mode, temp=temp,
                                                   top_p=top_p, top_k=top_k, question_override=question_override)
        inference_results.append(output_text)

    tmp_raw = _reserve_temp_path(".raw.mp4")
    out_video_path = _reserve_temp_path(".mp4")
    writer = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h))
    try:
        for frame, output_text in zip(sampled_frames, inference_results):
            detections = parse_mixed_results(output_text, category_str)
            valid_results = _postprocess_detections(detections, vid_w, vid_h)
            writer.write(draw_on_frame(frame, valid_results, draw_label=True))
    finally:
        writer.release()

    if _transcode_to_h264(tmp_raw, out_video_path):
        os.remove(tmp_raw)
    else:
        # Still return the mp4v encode (some browsers may not play it inline).
        os.replace(tmp_raw, out_video_path)

    combined_raw_text = "\n\n".join(f"--- Frame {i + 1} ---\n{t}" for i, t in enumerate(inference_results))
    return (gr.update(value=None, visible=False), gr.update(value=out_video_path, visible=True),
            generate_dynamic_html([], "Processed Loop Successful", combined_raw_text))


def run_inference(input_type, image_in, video_in, task_type, category_str, model_mode, temp, top_p, top_k,
                  short_side, question_override, max_video_frames):
    """Dispatch the UI request to the image or video GPU worker.

    ``task_type`` is not used here: the task-specific prompt already arrives
    as ``question_override`` (the generated prompt box).
    """
    category_str = category_str or ""
    categories_list = [c.strip() for c in category_str.split(",") if c.strip()] or ["object"]
    final_override = question_override.strip() if (question_override and question_override.strip()) else None
    if input_type == "Image":
        return _run_image_inference(image_in, categories_list, category_str, model_mode, temp, top_p, top_k,
                                    short_side, final_override)
    return _run_video_inference(video_in, categories_list, category_str, model_mode, temp, top_p, top_k,
                                short_side, final_override, max_video_frames)


# ============================================================
# GRADIO INTERFACE LAYOUT BUILD
# ============================================================

def build_ui():
    """Build the Gradio Blocks app: inputs and parameters on the left, annotated result and trace on the right."""
    default_task = "Detection"
    default_categories = "car, pedestrian"
    with gr.Blocks(title="LocateAnything Grounding Suite") as demo:
        gr.Markdown("# 🔍 LocateAnything Grounding Studio\nInfer target regions, visual boxes, and point indicators.")
        with gr.Row():
            with gr.Column(scale=1):
                input_type = gr.Radio(["Image", "Video"], value="Image", label="Input Format")
                image_input = gr.Image(type="pil", label="Source Image", visible=True)
                video_input = gr.Video(label="Source Video", visible=False)
                task_dropdown = gr.Dropdown(["Detection", "Grounding", "OCR", "GUI", "Pointing"],
                                            value=default_task, label="Goal Context Task")
                category_input = gr.Textbox(label="Categories / Label Targets (comma separated)",
                                            value=default_categories)
                # Derived from the defaults so it always matches what generate_raw_prompt emits.
                raw_prompt_box = gr.Textbox(label="Generated Execution Prompt (Read Only)",
                                            value=generate_raw_prompt(default_task, default_categories),
                                            interactive=False)
                with gr.Accordion("Advanced Parameters", open=False):
                    model_dropdown = gr.Dropdown(["hybrid", "fast", "slow"], value="hybrid",
                                                 label="Decoding Engine Mode")
                    temp_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
                    top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
                    top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top K")
                    short_size_input = gr.Slider(0, 1024, value=1024, step=64,
                                                 label="Max Downscaling Res Constraint (0 for Native)")
                    max_video_frames_slider = gr.Slider(1, 16, value=4, step=1, label="Video Sample Extraction Cap")
                run_btn = gr.Button("Run Inference", variant="primary")
            with gr.Column(scale=1):
                output_image = gr.Image(label="Annotated Image Result", visible=True)
                output_video = gr.Video(label="Annotated Video Result", visible=False)
                raw_output_box = gr.HTML(label="Visual Trace Dashboard")

        input_type.change(
            fn=lambda c: (gr.update(visible=(c == "Image")), gr.update(visible=(c == "Video"))),
            inputs=input_type,
            outputs=[image_input, video_input],
        )
        for comp in [task_dropdown, category_input]:
            comp.change(fn=generate_raw_prompt, inputs=[task_dropdown, category_input], outputs=raw_prompt_box)

        # Disable the button while running; `.then` re-enables it even if inference raised.
        run_btn.click(
            fn=lambda: gr.update(interactive=False, value="Processing Tensors..."),
            outputs=[run_btn],
        ).then(
            fn=run_inference,
            inputs=[
                input_type, image_input, video_input, task_dropdown, category_input, model_dropdown,
                temp_slider, top_p_slider, top_k_slider, short_size_input, raw_prompt_box,
                max_video_frames_slider,
            ],
            outputs=[output_image, output_video, raw_output_box],
        ).then(
            fn=lambda: gr.update(interactive=True, value="Run Inference"),
            outputs=[run_btn],
        )
    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.queue().launch()