Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import tempfile | |
| import re | |
| import time | |
| import base64 | |
| import gc | |
| import io | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoProcessor, AutoModel, AutoTokenizer | |
| from huggingface_hub import CommitScheduler | |
| import spaces | |
| _FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "LXGWWenKai-Bold.ttf") | |
| def _get_first_env(*names): | |
| for name in names: | |
| value = os.environ.get(name) | |
| if value and value.strip(): | |
| return value.strip() | |
| return None | |
def _configure_hf_auth():
    """Resolve Hugging Face tokens for model download and dataset logging.

    Both tokens are looked up across the same environment variables, but with
    different priority:
    - the model token prefers ``MODEL_HF_TOKEN``;
    - the log token prefers ``LOG_HF_TOKEN``.

    When either is found, it is written to ``HF_TOKEN``,
    ``HUGGINGFACE_HUB_TOKEN`` and ``HUGGINGFACEHUB_API_TOKEN``. The model token
    wins if both exist. Presumably this is so hub libraries that read those
    variables pick it up.

    Returns:
        tuple: ``(model_token, log_token)``; either may be ``None``.
    """
    hub_vars = ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN")
    model_token = _get_first_env("MODEL_HF_TOKEN", "LOG_HF_TOKEN", *hub_vars)
    log_token = _get_first_env("LOG_HF_TOKEN", "MODEL_HF_TOKEN", *hub_vars)
    exported = model_token or log_token
    if exported:
        os.environ.update(dict.fromkeys(hub_vars, exported))
    return model_token, log_token
# Resolve the model-download and logging tokens once at import time. This also
# exports the chosen token to the standard hub environment variables.
MODEL_HF_TOKEN, LOG_HF_TOKEN = _configure_hf_auth()
def _load_font(size=20):
    """Return a font of ``size`` points for drawing overlay labels.

    Candidates are tried in order:
    1. the bundled font at ``_FONT_PATH``, if that file exists;
    2. ``DejaVuSans-Bold.ttf`` (a bare filename, so PIL may search its usual
       font directories);
    3. PIL's built-in default font.

    Load errors are swallowed, so a missing font never breaks drawing.
    """
    candidates = [_FONT_PATH] if os.path.exists(_FONT_PATH) else []
    candidates.append("DejaVuSans-Bold.ttf")
    for font_file in candidates:
        try:
            return ImageFont.truetype(font_file, size)
        except Exception:
            continue  # best-effort: try the next candidate
    return ImageFont.load_default()
| # ============================================================ | |
| # Color / Parsing / Rendering Operations | |
| # ============================================================ | |
def get_color_for_label(label):
    """Map ``label`` to a stable RGB color from a fixed six-color palette.

    The palette index is the sum of the label's code points modulo the palette
    size. The same label therefore always gets the same color.
    """
    palette = (
        (8, 145, 178),
        (220, 38, 38),
        (22, 163, 74),
        (37, 99, 235),
        (217, 119, 6),
        (147, 51, 234),
    )
    checksum = sum(map(ord, label))
    return palette[checksum % len(palette)]
def parse_mixed_results(text, category_str=""):
    """Extract boxes and points from the model's raw output text.

    Two formats are recognized.

    1. Structured: ``<ref>label</ref>`` tokens interleaved with
       ``<box>...</box>`` tokens. Each box takes the most recent non-empty
       ``<ref>`` label, or a fallback label if no ``<ref>`` has been seen.
    2. Fallback, used only when the structured pass yields no coordinates:
       bare ``<box>`` tokens, each labelled with the first expected category
       that appears in the text preceding it.

    Inside a box, numbers are written as ``<n>`` tokens. Four numbers form a box
    ``[x1, y1, x2, y2]``; two form a point ``[x, y]``; other counts are ignored.
    Coordinates are returned exactly as parsed (``draw_on_frame`` treats them as
    0-1000 normalized).

    Args:
        text: Raw decoded model output. ``None`` is treated as empty.
        category_str: Expected categories, separated by ``</c>`` (the model
            prompt format) and/or commas (the UI input format). Used only for
            fallback labels, which are lower-cased. ``None`` is treated as empty.

    Returns:
        list[dict]: Items with keys ``type`` (``"box"`` or ``"point"``),
        ``coords`` (list of float) and ``label``.
    """
    text = text or ""
    # Accept both separators so fallback labels are individual categories
    # rather than the whole comma-separated UI string.
    expected_cats = [
        c.strip().lower()
        for c in re.split(r"</c>|,", category_str or "")
        if c.strip()
    ]
    default_label = expected_cats[0] if expected_cats else "object"
    num_pattern = r"<\s*([0-9]+(?:\.[0-9]+)?)\s*>"

    def _make_item(coords, label):
        # 4 numbers -> box, 2 numbers -> point, anything else is discarded.
        if len(coords) == 4:
            return {"type": "box", "coords": coords, "label": label}
        if len(coords) == 2:
            return {"type": "point", "coords": coords, "label": label}
        return None

    # Pass 1: structured <ref>/<box> stream.
    results = []
    current_label = None
    found_structured = False
    ref_box_pattern = r"(<ref>.*?</ref>)|(<box>.*?</box>)"
    for m in re.finditer(ref_box_pattern, text, flags=re.IGNORECASE | re.DOTALL):
        token = m.group(0)
        if token.lower().startswith("<ref>"):
            label_raw = re.sub(r"</?ref>", "", token, flags=re.IGNORECASE).strip()
            if label_raw:
                current_label = label_raw
            continue
        content = re.sub(r"</?box>", "", token, flags=re.IGNORECASE)
        coords = [float(n) for n in re.findall(num_pattern, content)]
        if not coords:
            continue
        label = current_label if current_label is not None else default_label
        item = _make_item(coords, label)
        if item is not None:
            results.append(item)
        # Any box with coordinates (even an unusable count) suppresses the fallback.
        found_structured = True
    if found_structured:
        return results

    # Pass 2: fallback. Split on bare <box> tokens and label each box from the
    # text that precedes it.
    parts = re.split(r"<box>(.*?)</box>", text)
    for i in range(1, len(parts), 2):
        preceding_text = parts[i - 1].lower()
        label = next((cat for cat in expected_cats if cat in preceding_text), default_label)
        coords = [float(n) for n in re.findall(num_pattern, parts[i])]
        item = _make_item(coords, label)
        if item is not None:
            results.append(item)
    return results
def resize_image_short_side(image, short_side_size):
    """Resize ``image`` so that its shorter side equals ``short_side_size``.

    The aspect ratio is preserved, with the longer side truncated to an int.
    Bilinear resampling is used. This may upscale as well as downscale.

    Returns:
        tuple: ``(resized_image, scale_factor)``, where ``scale_factor`` is
        ``short_side_size / original_short_side``.
    """
    width, height = image.size
    portrait_or_square = width <= height
    scale = short_side_size / (width if portrait_or_square else height)
    if portrait_or_square:
        target = (short_side_size, int(height * scale))
    else:
        target = (int(width * scale), short_side_size)
    return image.resize(target, Image.BILINEAR), scale
def draw_on_frame(frame_bgr, results, draw_label=True):
    """Render detection boxes/points, and optionally labels, onto a BGR frame.

    Args:
        frame_bgr: Image as a BGR array (OpenCV channel order).
        results: Items as produced by ``parse_mixed_results`` or
            ``_postprocess_detections``.
            - Boxes: ``coords`` is 0-1000 normalized ``[x1, y1, x2, y2]``, or
              pixel ``[x, y, w, h]`` when the ``is_pixel`` key is present.
            - Points: ``coords`` is 0-1000 normalized ``[x, y]``.
            Boxes with fewer than 4 normalized coordinates are skipped.
        draw_label: When true, draw a filled tag with the item's label.

    Returns:
        A new BGR array with the annotations composited over the frame.
    """
    pil_img = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    img_draw = pil_img.convert("RGBA")
    # Draw on a transparent layer so box fills can be semi-transparent.
    overlay = Image.new("RGBA", img_draw.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)
    font = _load_font(20)
    w_img, h_img = pil_img.size

    def _clamp(value, upper):
        return max(0, min(upper, value))

    parsed = []
    for res in results:
        label = res.get("label", "object")
        color = get_color_for_label(label)
        if res.get("type") == "point":
            c = res["coords"]
            cx = _clamp(c[0] * w_img / 1000, w_img)
            cy = _clamp(c[1] * h_img / 1000, h_img)
            parsed.append(("point", label, color, cx, cy))
            continue
        if "is_pixel" in res:
            x1, y1, bw, bh = res["coords"]
            x2, y2 = x1 + bw, y1 + bh
        else:
            c = res["coords"]
            if len(c) < 4:
                continue
            x1, y1 = c[0] * w_img / 1000, c[1] * h_img / 1000
            x2, y2 = c[2] * w_img / 1000, c[3] * h_img / 1000
        # Order the corners first, then clamp every edge, so out-of-range
        # coordinates (e.g. > 1000) can never push an edge outside the frame.
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))
        x1, x2 = _clamp(x1, w_img), _clamp(x2, w_img)
        y1, y2 = _clamp(y1, h_img), _clamp(y2, h_img)
        parsed.append(("box", label, color, x1, y1, x2, y2))

    # Shapes first, labels second, so a tag is never covered by another box.
    for item in parsed:
        if item[0] == "box":
            _, _, color, x1, y1, x2, y2 = item
            draw.rectangle([x1, y1, x2, y2], fill=color + (65,), outline=color, width=4)
        else:
            _, _, color, cx, cy = item
            r = 10
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=color, outline="white", width=2)

    if draw_label:
        for item in parsed:
            label, color = item[1], item[2]
            if not label:
                continue
            t_box = draw.textbbox((0, 0), label, font=font)
            tw, th = t_box[2] - t_box[0], t_box[3] - t_box[1]
            if item[0] == "box":
                x1, y1, x2, y2 = item[3:]
                pad_x, pad_y = 7, 4
                tag_w, tag_h = tw + pad_x * 2, th + pad_y * 2
                # Prefer above the box, then below it, then just inside its top edge.
                tag_y = y1 - tag_h - 2
                if tag_y < 0:
                    tag_y = y2 + 2
                    if tag_y + tag_h > h_img:
                        tag_y = y1
                # Shift left if the tag would run past the right edge.
                tag_x = max(0, min(x1, w_img - tag_w))
                draw.rectangle([tag_x, tag_y, tag_x + tag_w, tag_y + tag_h], fill=color)
                draw.text((tag_x + pad_x, tag_y + pad_y), label, fill="white", font=font)
            else:
                cx, cy = item[3:]
                tx, ty = cx + 14, cy - th // 2
                draw.rectangle([tx - 2, ty - 2, tx + tw + 6, ty + th + 4], fill=color)
                draw.text((tx + 2, ty), label, fill="white", font=font)

    combined = Image.alpha_composite(img_draw, overlay).convert("RGB")
    return cv2.cvtColor(np.array(combined), cv2.COLOR_RGB2BGR)
| # ============================================================ | |
| # Model Runner Component | |
| # ============================================================ | |
class EagleWorker:
    """Loads the LocateAnything checkpoint and runs grounding generation.

    Tokenizer, processor and model are loaded with ``trust_remote_code=True``.
    ``generate`` relies on checkpoint-provided methods:
    - ``processor.py_apply_chat_template``;
    - ``processor.process_vision_info``;
    - a ``model.generate`` accepting ``generation_mode``, ``tokenizer`` and
      ``verbose``.
    """

    def __init__(self, model_path, device="cuda", generation_mode: str = "hybrid"):
        """Load tokenizer, processor and model from ``model_path``.

        Args:
            model_path: Hub repo id or local path of the checkpoint.
            device: Torch device that the model and inputs are moved to.
            generation_mode: Default decoding mode, used when ``generate`` is
                called without one (the UI offers "hybrid", "fast" and "slow").
        """
        self.model_id = model_path
        self.device = device
        self.dtype = torch.bfloat16
        self.generation_mode = generation_mode
        self.hf_token = MODEL_HF_TOKEN
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=self.hf_token,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            token=self.hf_token,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=self.dtype,
            _attn_implementation="sdpa",
            trust_remote_code=True,
            token=self.hf_token,
        ).to(device).eval()
        print("Model Engine Loaded Safely.")

    def build_messages(self, image, categories, question_override=None):
        """Build a single-turn chat message containing ``image`` and the instruction.

        If ``question_override`` is given, it is used verbatim. Otherwise the
        detection prompt is built by joining ``categories`` with ``</c>``.
        """
        if question_override is not None:
            user_text = question_override
        else:
            category_set_str = "</c>".join(categories)
            user_text = f"Locate all the instances that matches the following description: {category_set_str}."
        return [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_text},
        ]}]

    def generate(self, image, categories, generation_mode=None,
                 max_new_tokens=4096, temp=0.7, top_p=0.9, top_k=50,
                 question_override=None):
        """Run one generation for ``image`` and return text plus decoding trace.

        Args:
            image: Model input image (passed through to the processor).
            categories: Category names used when ``question_override`` is None.
            generation_mode: Decoding mode; ``None`` means ``self.generation_mode``.
            max_new_tokens: Generation length cap.
            temp: Sampling temperature. A value ``<= 0`` switches to greedy
                decoding (``do_sample=False``), since sampling needs a positive
                temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling cutoff (forwarded to ``model.generate``).
            question_override: Exact instruction text, if any.

        Returns:
            tuple: ``(output_text, token_sequence, out_info)``.
            - If the model returns a tuple of at least 3 items, they are used
              as-is, except that ``token_sequence`` is converted to a list.
            - Otherwise ``token_sequence`` is ``[]`` and ``out_info`` is ``""``.
        """
        mode = generation_mode if generation_mode is not None else self.generation_mode
        messages = self.build_messages(image, categories, question_override=question_override)
        text = self.processor.py_apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(text=[text], images=images, videos=videos, return_tensors="pt").to(self.device)
        do_sample = temp is not None and temp > 0
        with torch.inference_mode():
            result = self.model.generate(
                pixel_values=inputs["pixel_values"].to(self.dtype),
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                image_grid_hws=inputs.get("image_grid_hws", None),
                tokenizer=self.tokenizer,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                generation_mode=mode,
                temperature=temp if do_sample else 1.0,
                do_sample=do_sample,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=1.1,
                verbose=True,
            )
        if isinstance(result, tuple) and len(result) >= 3:
            output_text, token_sequence, out_info = result[:3]
            token_sequence = list(token_sequence) if token_sequence else []
            # In slow mode the last trace entry is forced to "ar". Presumably the
            # model tags it otherwise — TODO confirm against the checkpoint.
            # Uses the effective mode so the default "slow" is handled too.
            if mode == "slow" and token_sequence:
                token_sequence[-1] = ("ar", token_sequence[-1][1])
        else:
            output_text, token_sequence, out_info = result, [], ""
        return output_text, token_sequence, out_info
| # ============================================================ | |
| # Post-Processing UI Helpers | |
| # ============================================================ | |
| def _postprocess_detections(detections, w, h): | |
| valid = [] | |
| for det in detections: | |
| if det["type"] == "box": | |
| c = det["coords"] | |
| rx1 = max(0, min(w - 1, int(c[0] * w / 1000))) | |
| ry1 = max(0, min(h - 1, int(c[1] * h / 1000))) | |
| rx2 = max(0, min(w - 1, int(c[2] * w / 1000))) | |
| ry2 = max(0, min(h - 1, int(c[3] * h / 1000))) | |
| box_w, box_h = rx2 - rx1, ry2 - ry1 | |
| if box_w <= 0 or box_h <= 0: | |
| continue | |
| valid.append({"type": "box", "coords": [rx1, ry1, box_w, box_h], | |
| "is_pixel": True, "label": det["label"]}) | |
| elif det["type"] == "point": | |
| valid.append(det) | |
| return valid | |
| def _parse_out_info_dict(out_info: str) -> dict: | |
| stats = {} | |
| if not out_info: | |
| return stats | |
| cleaned = re.sub(r"^[Ss]tast?ic\s*[Ii]nfo\s*,?\s*", "", out_info.strip()) | |
| for part in cleaned.split(";"): | |
| part = part.strip() | |
| if "=" in part: | |
| k, v = part.split("=", 1) | |
| stats[k.strip()] = v.strip() | |
| return stats | |
def generate_dynamic_html(token_sequence, out_info, raw_text):
    """Build the animated HTML decoding trace shown in the results panel.

    Args:
        token_sequence: Iterable of ``(decode_type, text)`` pairs.
            ``decode_type`` "ar" (case-insensitive) is styled as AR; anything
            else is styled as MTP. Items that are not a list/tuple of length
            >= 2 are skipped. ``None`` or empty renders no tokens.
        out_info: Statistics string from the model. When non-empty it is
            summarized via ``_parse_out_info_dict`` into a stats chip.
        raw_text: Full decoded text, shown in a collapsible section when
            non-empty.

    Returns:
        str: Self-contained HTML (inline ``<style>`` plus markup). All
        model-derived text is HTML-escaped, so tags such as ``<box>`` render
        literally instead of being swallowed by the browser.
    """
    from html import escape as _escape

    # Timestamp-based suffix for class/keyframe names. Presumably this gives each
    # render fresh names so the entry animation replays.
    uid = f"a{int(time.time() * 1000)}"
    css = f"""
    <style>
    .dc-root {{
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        border: 1px solid #cce875; border-radius: 10px; background: #ffffff; overflow: hidden;
    }}
    .dc-header {{
        display: flex; align-items: center; justify-content: space-between; padding: 12px 18px;
        background: linear-gradient(135deg, #76b900 0%, #649d00 100%); border-bottom: 1px solid #527f00;
    }}
    .dc-header-title {{ font-weight: 700; font-size: 0.95em; color: #ffffff !important; }}
    .dc-legend {{ display: flex; gap: 16px; align-items: center; }}
    .dc-legend-item {{ display: flex; align-items: center; gap: 5px; font-size: 0.78em; color: rgba(255,255,255,0.92); }}
    .dc-legend-dot {{ width: 10px; height: 10px; border-radius: 3px; display: inline-block; }}
    .dc-row {{ display: flex; gap: 10px; padding: 14px 18px; border-bottom: 1px solid #eef7d1; }}
    .dc-row:last-child {{ border-bottom: none; }}
    .dc-val {{ flex: 1; line-height: 2.3; word-wrap: break-word; color: #4b5563; font-size: 0.92em; }}
    @keyframes tk-{uid} {{
        0% {{ opacity: 0; transform: translateY(8px); }}
        100% {{ opacity: 1; transform: translateY(0); }}
    }}
    .tk-mtp-{uid}, .tk-ar-{uid} {{
        opacity: 0; animation: tk-{uid} 0.35s ease-out forwards; border-radius: 5px; padding: 2px 7px; margin: 2px 1px; display: inline-block;
        font-size: 0.80em; font-weight: 600; font-family: monospace; white-space: nowrap;
    }}
    .tk-mtp-{uid} {{ background: #e8f5e9; border: 2px solid #76b900; color: #000000; }}
    .tk-ar-{uid} {{ background: #fff3e0; border: 2px solid #e65100; color: #000000; }}
    .tk-stat-{uid} {{
        opacity: 0; animation: tk-{uid} 0.4s ease-out forwards; background: #f0f9e2; border: 1px solid #a4d422; border-radius: 6px;
        padding: 5px 14px; display: inline-block; font-size: 0.82em; color: #3f6200; font-weight: 600;
    }}
    .dc-raw {{ padding: 0 18px 14px; }}
    .dc-raw summary {{ cursor: pointer; color: #9ca3af; font-size: 0.82em; }}
    .dc-raw-pre {{
        background: #f7fbe8; border: 1px solid #ddf0a3; border-radius: 6px; padding: 12px; margin-top: 8px;
        font-family: monospace; font-size: 0.78em; color: #374151; white-space: pre-wrap; max-height: 200px; overflow-y: auto;
    }}
    </style>
    """
    header = (
        '<div class="dc-header">'
        '<span class="dc-header-title">LocateAnything Decoding Trace</span>'
        '<div class="dc-legend">'
        '<div class="dc-legend-item"><span class="dc-legend-dot" style="background:#76b900;"></span>MTP</div>'
        '<div class="dc-legend-item"><span class="dc-legend-dot" style="background:#e65100;"></span>AR</div>'
        '</div></div>'
    )
    parts = [css, '<div class="dc-root">', header, '<div class="dc-row"><div class="dc-val">']

    tok_idx = 0
    for item in token_sequence or ():
        if not isinstance(item, (list, tuple)) or len(item) < 2:
            continue
        decode_type = str(item[0]).lower()
        safe = _escape(str(item[1]))
        # Stagger each token's fade-in by 40 ms.
        delay = f"{tok_idx * 0.04:.2f}s"
        cls = f"tk-ar-{uid}" if decode_type == "ar" else f"tk-mtp-{uid}"
        parts.append(f'<span class="{cls}" style="animation-delay:{delay}">{safe}</span> ')
        tok_idx += 1
    parts.append('</div></div>')

    if out_info:
        stats = _parse_out_info_dict(out_info)
        wanted = [("forward_step", "steps"), ("num_tokens", "tokens"), ("num_boxes", "boxes"),
                  ("ar_step", "AR steps"), ("tps", "tok/s")]
        bits = [f"{stats[key]} {name}" for key, name in wanted if key in stats]
        summary = " · ".join(bits) if bits else out_info.strip()
        # Show the stats chip shortly after the last token has appeared.
        stat_delay = f"{tok_idx * 0.04 + 0.2:.2f}s"
        parts.append(
            '<div class="dc-row" style="justify-content:flex-end;padding-top:4px;padding-bottom:10px;border-bottom:none;">'
            f'<span class="tk-stat-{uid}" style="animation-delay:{stat_delay}">⚡ {_escape(summary)}</span></div>'
        )

    if raw_text:
        parts.append(
            '<div class="dc-raw"><details><summary>📄 Show Raw Response</summary>'
            f'<div class="dc-raw-pre">{_escape(str(raw_text))}</div></details></div>'
        )
    parts.append('</div>')
    return "".join(parts)
def generate_raw_prompt(task_type, category):
    """Build the exact instruction text sent to the model for ``task_type``.

    ``category`` is a comma-separated list. Entries are stripped, blank entries
    are dropped, and the rest are joined with the model's ``</c>`` separator.
    An empty or ``None`` category becomes ``"objects"``. Unknown task types use
    the Detection wording.

    NOTE: the template wording (including "matches" vs "match") is kept
    verbatim. Presumably it mirrors the model's training prompts, so do not
    "fix" it.
    """
    pieces = (part.strip() for part in (category or "objects").split(","))
    joined = "</c>".join(filter(None, pieces))
    templates = {
        "Detection": "Locate all the instances that matches the following description: {}.",
        "Grounding": "Locate all the instances that match the following description: {}.",
        "OCR": "Detect all the text in box format.",
        "GUI": "Locate the region that matches the following description: {}.",
        "Pointing": "Point to: {}.",
    }
    return templates.get(task_type, templates["Detection"]).format(joined)
| # ============================================================ | |
# Model Loading and Optional Remote-Logging Setup
| # ============================================================ | |
# Checkpoint to load; can be overridden with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "nvidia/LocateAnything-3B")
print(f"Loading Base Weight Layer Model Matrix via: {MODEL_PATH}")
# Loaded eagerly at import time on EagleWorker's default device ("cuda"). This
# single instance is used by both the image and video inference handlers.
GLOBAL_WORKER = EagleWorker(MODEL_PATH)
# Optional remote logging. It is enabled only when LOG_DATASET_REPO and a log
# token are both set. Records staged under _LOG_DIR are pushed to the dataset
# repo's "data/" folder by a huggingface_hub CommitScheduler (every=5; per the
# huggingface_hub docs this interval is in minutes).
LOG_DATASET_REPO = os.environ.get("LOG_DATASET_REPO")
# Fresh local staging directory created for this process.
_LOG_DIR = Path(tempfile.mkdtemp(prefix="hf_log_"))
# Short random id that prefixes this process's log file names.
_SESSION_ID = uuid.uuid4().hex[:8]
_log_scheduler = None  # stays None when logging is disabled or setup fails
if LOG_DATASET_REPO and LOG_HF_TOKEN:
    try:
        _log_scheduler = CommitScheduler(
            repo_id=LOG_DATASET_REPO,
            repo_type="dataset",
            folder_path=str(_LOG_DIR),
            path_in_repo="data",
            every=5,
            token=LOG_HF_TOKEN,
            squash_history=False,
        )
        print(f"[LOG] System Scheduler initialized successfully context workspace mapping tracking.")
    except Exception as e:
        # Best-effort: any setup error disables logging instead of aborting startup.
        print(f"[LOG] Remote logging skipped or unauthorized setup boundary: {e}")
| def _pil_to_b64(pil_img): | |
| buf = io.BytesIO() | |
| pil_img.save(buf, "PNG") | |
| return base64.b64encode(buf.getvalue()).decode("ascii") | |
def _log_to_dataset(input_type, category, model_mode, raw_prompt, output_text="", input_image=None, output_image=None):
    """Stage one inference record as a JSONL file for the CommitScheduler to upload.

    Does nothing when remote logging is disabled (``_log_scheduler is None``).
    Images are embedded as base64 PNG strings, or ``None`` when absent. Each
    call writes its own file inside a per-UTC-day folder under ``_LOG_DIR``.
    Any error is printed and swallowed, so logging can never fail a request.
    """
    if _log_scheduler is None:
        return
    try:
        entry_id = f"{int(time.time())}_{uuid.uuid4().hex[:6]}"
        record = dict(
            id=entry_id,
            session_id=_SESSION_ID,
            timestamp=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            input_type=input_type,
            category=category,
            model_mode=model_mode,
            raw_prompt=raw_prompt,
            output_text=output_text,
            input_image_b64=_pil_to_b64(input_image) if input_image else None,
            output_image_b64=_pil_to_b64(output_image) if output_image else None,
        )
        day_dir = _LOG_DIR / time.strftime("%Y-%m-%d", time.gmtime())
        day_dir.mkdir(parents=True, exist_ok=True)
        target = day_dir / f"{_SESSION_ID}__{entry_id}.jsonl"
        # Write under the scheduler's lock; huggingface_hub documents this lock for
        # folder writes so a commit does not pick up a half-written file.
        with _log_scheduler.lock, open(target, "w", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception as e:
        print(f"[LOG] Write failure: {e}")
def _prepare_image_for_model(pil_img, short_size):
    """Return a resized copy of ``pil_img`` to feed to the model.

    - A positive ``short_size`` resizes the shorter side to
      ``min(short_size, 1024)``, which may upscale.
    - Otherwise the image keeps its native size, unless its shorter side exceeds
      1024, in which case it is downscaled to 1024.

    The caller's image is never modified (work happens on a copy).
    """
    working = pil_img.copy()
    requested = int(short_size) if short_size and int(short_size) > 0 else 0
    if requested:
        target = min(requested, 1024)
    elif min(working.size) > 1024:
        target = 1024
    else:
        return working
    working, _ = resize_image_short_side(working, target)
    return working
| # ============================================================ | |
# Inference Entry Points (image, video, and dispatcher)
| # ============================================================ | |
def _run_image_inference(image_in, categories_list, category_str, model_mode, temp, top_p, top_k, short_size, question_override):
    """Run the model on one image and return updates for the three UI outputs.

    Returns:
        tuple: ``(image output update, video output update, trace HTML)``. The
        image output stays visible and the video output hidden. A missing image
        yields a warning instead of running the model.
    """
    if image_in is None:
        return gr.update(value=None, visible=True), gr.update(value=None, visible=False), "<p>⚠️ Upload image.</p>"

    model_input = _prepare_image_for_model(image_in, short_size)
    output_text, token_sequence, out_info = GLOBAL_WORKER.generate(
        model_input,
        categories_list,
        model_mode,
        temp=temp,
        top_p=top_p,
        top_k=top_k,
        question_override=question_override,
    )
    detections = parse_mixed_results(output_text, category_str)
    # draw_on_frame scales 0-1000 coordinates by the frame size, so annotate the
    # original upload rather than the resized model input.
    annotated_bgr = draw_on_frame(cv2.cvtColor(np.array(image_in), cv2.COLOR_RGB2BGR), detections, draw_label=True)
    output_image = Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))
    _log_to_dataset(
        "image",
        ", ".join(categories_list),
        model_mode,
        question_override or category_str,
        output_text,
        image_in,
        output_image,
    )
    trace_html = generate_dynamic_html(token_sequence, out_info, output_text)
    return gr.update(value=output_image, visible=True), gr.update(value=None, visible=False), trace_html
def _sample_frame_indices(total, max_frames):
    """Return up to ``max_frames`` evenly spaced frame indices in ``[0, total)``."""
    if total <= 0:
        return []
    if total <= max_frames:
        return list(range(total))
    if max_frames <= 1:
        # The even-spacing formula below divides by (max_frames - 1).
        return [(total - 1) // 2]
    return [int(round(i * (total - 1) / (max_frames - 1))) for i in range(max_frames)]


def _read_sampled_frames(video_path, max_frames):
    """Decode only the sampled frames of ``video_path``.

    A first pass counts frames with ``grab()``. A second pass retrieves only the
    sampled indices, so memory scales with ``max_frames`` rather than with the
    video length.

    Returns:
        tuple: ``(frames, total, fps)``:
        - ``frames``: the sampled BGR frames, in order;
        - ``total``: the frame count from the first pass;
        - ``fps``: the container's reported FPS.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total = 0
        while cap.grab():
            total += 1
    finally:
        cap.release()

    wanted = set(_sample_frame_indices(total, max_frames))
    frames = []
    if not wanted:
        return frames, total, fps
    last_wanted = max(wanted)
    cap = cv2.VideoCapture(video_path)
    try:
        index = 0
        while index <= last_wanted and cap.grab():
            if index in wanted:
                ok, frame = cap.retrieve()
                if ok:
                    frames.append(frame)
            index += 1
    finally:
        cap.release()
    return frames, total, fps


def _transcode_for_browser(raw_path, out_path):
    """Re-encode ``raw_path`` to H.264/yuv420p at ``out_path`` using ffmpeg.

    Returns True on success. Returns False if ffmpeg is missing, exits non-zero,
    or leaves an empty output file.
    """
    import subprocess as _sp
    try:
        proc = _sp.run(
            ["ffmpeg", "-y", "-i", raw_path, "-c:v", "libx264", "-preset", "ultrafast",
             "-crf", "23", "-pix_fmt", "yuv420p", out_path],
            capture_output=True,
        )
    except OSError:
        return False
    return proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > 0


def _run_video_inference(video_in, categories_list, category_str, model_mode, temp, top_p, top_k, short_size, question_override, max_video_frames):
    """Run the model on evenly sampled video frames and render an annotated clip.

    At most ``max_video_frames`` frames are sampled (default 4). Each one goes
    through the model independently, and the annotated frames are written at an
    FPS that keeps the source duration (5 FPS if the source FPS is unknown).

    Returns:
        tuple: ``(image output update, video output update, trace HTML)``. The
        video output stays visible and the image output hidden.
    """
    if video_in is None:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True), "<p>⚠️ Upload video.</p>"

    max_frames = max(1, int(max_video_frames)) if max_video_frames else 4
    sampled_frames, total, fps = _read_sampled_frames(video_in, max_frames)
    if not sampled_frames:
        return (gr.update(value=None, visible=False), gr.update(value=None, visible=True),
                "<p>⚠️ Could not read any frames from the video.</p>")

    # Spread the sampled frames over the source duration (total > 0 here).
    out_fps = max(1.0, len(sampled_frames) / (total / fps)) if fps > 0 else 5.0
    # Size the writer from the decoded frames rather than container metadata.
    # OpenCV may silently skip frames whose size differs from the writer's.
    vid_h, vid_w = sampled_frames[0].shape[:2]

    inference_results = []
    for frame in sampled_frames:
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        process_img = _prepare_image_for_model(pil_img, short_size)
        output_text, _, _ = GLOBAL_WORKER.generate(
            process_img, categories_list, model_mode,
            temp=temp, top_p=top_p, top_k=top_k, question_override=question_override,
        )
        inference_results.append(output_text)

    # mkstemp creates the files atomically (tempfile.mktemp is race-prone).
    raw_fd, tmp_raw = tempfile.mkstemp(suffix=".raw.mp4")
    os.close(raw_fd)
    out_fd, out_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(out_fd)

    writer = cv2.VideoWriter(tmp_raw, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (vid_w, vid_h))
    try:
        for frame, output_text in zip(sampled_frames, inference_results):
            detections = parse_mixed_results(output_text, category_str)
            valid_results = _postprocess_detections(detections, vid_w, vid_h)
            writer.write(draw_on_frame(frame, valid_results, draw_label=True))
    finally:
        writer.release()

    if _transcode_for_browser(tmp_raw, out_video_path):
        os.remove(tmp_raw)
        final_path = out_video_path
    else:
        # Fall back to the mp4v file so the user still receives a result.
        if os.path.exists(out_video_path):
            os.remove(out_video_path)
        final_path = tmp_raw

    combined_raw_text = "\n\n".join(f"--- Frame {i + 1} ---\n{t}" for i, t in enumerate(inference_results))
    return (gr.update(value=None, visible=False), gr.update(value=final_path, visible=True),
            generate_dynamic_html([], "Processed Loop Successful", combined_raw_text))
def run_inference(input_type, image_in, video_in, task_type, category_str, model_mode, temp, top_p, top_k, short_side, question_override, max_video_frames):
    """Gradio click handler: normalize inputs and dispatch to image or video inference.

    - ``category_str`` is split on commas; an empty result becomes
      ``["object"]``.
    - A blank ``question_override`` becomes ``None``.
    - ``task_type`` is accepted for the UI wiring but unused here; the
      task-specific prompt arrives via ``question_override``.
    """
    categories_list = [name.strip() for name in category_str.split(",") if name.strip()]
    if not categories_list:
        categories_list = ["object"]
    stripped = question_override.strip() if question_override else ""
    final_override = stripped or None
    shared = (categories_list, category_str, model_mode, temp, top_p, top_k, short_side, final_override)
    if input_type == "Image":
        return _run_image_inference(image_in, *shared)
    return _run_video_inference(video_in, *shared, max_video_frames)
| # ============================================================ | |
| # GRADIO INTERFACE LAYOUT BUILD | |
| # ============================================================ | |
def build_ui():
    """Assemble the Gradio Blocks app: inputs, advanced settings, outputs and events.

    Returns:
        The ``gr.Blocks`` app, not yet queued or launched.
    """
    with gr.Blocks(title="LocateAnything Grounding Suite") as demo:
        gr.Markdown("# 🔍 LocateAnything Grounding Studio\nInfer target regions, visual boxes, and point indicators.")
        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column(scale=1):
                input_type = gr.Radio(["Image", "Video"], value="Image", label="Input Format")
                image_input = gr.Image(type="pil", label="Source Image", visible=True)
                video_input = gr.Video(label="Source Video", visible=False)
                task_dropdown = gr.Dropdown(["Detection", "Grounding", "OCR", "GUI", "Pointing"], value="Detection", label="Goal Context Task")
                category_input = gr.Textbox(label="Categories / Label Targets (comma separated)", value="car, pedestrian")
                # Read-only preview of the prompt. Its text is also passed to
                # run_inference as question_override (see the click chain below).
                raw_prompt_box = gr.Textbox(label="Generated Execution Prompt (Read Only)", value="Locate all the instances that matches the following description: car</c>pedestrian.", interactive=False)
                with gr.Accordion("Advanced Parameters", open=False):
                    model_dropdown = gr.Dropdown(["hybrid", "fast", "slow"], value="hybrid", label="Decoding Engine Mode")
                    temp_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
                    top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
                    top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top K")
                    short_size_input = gr.Slider(0, 1024, value=1024, step=64, label="Max Downscaling Res Constraint (0 for Native)")
                    max_video_frames_slider = gr.Slider(1, 16, value=4, step=1, label="Video Sample Extraction Cap")
                run_btn = gr.Button("Run Inference", variant="primary")
            # Right column: outputs. The inference handlers toggle visibility of
            # the image vs. video result.
            with gr.Column(scale=1):
                output_image = gr.Image(label="Annotated Image Result", visible=True)
                output_video = gr.Video(label="Annotated Video Result", visible=False)
                raw_output_box = gr.HTML(label="Visual Trace Dashboard")
        # Show exactly one of the image / video inputs.
        input_type.change(
            fn=lambda c: (gr.update(visible=(c == "Image")), gr.update(visible=(c == "Video"))),
            inputs=input_type, outputs=[image_input, video_input],
        )
        # Rebuild the prompt preview whenever the task or categories change.
        for comp in [task_dropdown, category_input]:
            comp.change(fn=generate_raw_prompt, inputs=[task_dropdown, category_input], outputs=raw_prompt_box)
        # Disable the button, run inference, then re-enable it. Per the Gradio
        # docs, .then() runs after the previous step even if that step failed —
        # confirm for the installed Gradio version.
        run_btn.click(
            fn=lambda: gr.update(interactive=False, value="Processing Tensors..."),
            outputs=[run_btn],
        ).then(
            fn=run_inference,
            inputs=[
                input_type, image_input, video_input,
                task_dropdown, category_input, model_dropdown,
                temp_slider, top_p_slider, top_k_slider,
                short_size_input, raw_prompt_box, max_video_frames_slider,
            ],
            outputs=[output_image, output_video, raw_output_box],
        ).then(
            fn=lambda: gr.update(interactive=True, value="Run Inference"),
            outputs=[run_btn],
        )
    return demo
if __name__ == "__main__":
    demo = build_ui()
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()