"""Gradio demo: run a VectorLLM Hugging Face export on bbox crops of an image.

The user supplies an image and one or more bounding boxes. Each box is
expanded, cropped (with padding where it leaves the image), passed to the
model, and the predicted polygon is mapped back onto the full image.
"""

import argparse
import base64
import importlib.util
import inspect
import io
import json
import math
import os
import re
import sys
from pathlib import Path


def _disable_invalid_socks_proxy():
    """Remove SOCKS proxy environment variables when ``socksio`` is missing.

    Does nothing if the ``socksio`` package can be found. Otherwise every
    lower/upper-case ``*_proxy`` variable whose value starts with ``socks``
    is removed from ``os.environ``.
    """
    if importlib.util.find_spec("socksio") is not None:
        return
    for key in ("http_proxy", "https_proxy", "all_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"):
        value = os.environ.get(key)
        if value and value.lower().startswith("socks"):
            os.environ.pop(key, None)


# Must run before the third-party imports below.
# NOTE(review): presumably the HTTP client pulled in by gradio fails on SOCKS
# proxies when socksio is not installed - confirm.
_disable_invalid_socks_proxy()

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image, ImageDraw  # noqa: E402
from starlette.templating import Jinja2Templates  # noqa: E402
from transformers import (  # noqa: E402
    AutoModel,
    AutoProcessor,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)


def _patch_starlette_template_response():
    """Let ``TemplateResponse(name, context)`` calls work with a request-first signature.

    The patch is applied only when the second parameter of
    ``Jinja2Templates.TemplateResponse`` is named ``request`` and the method
    has not been patched already.
    """
    template_response = Jinja2Templates.TemplateResponse
    params = tuple(inspect.signature(template_response).parameters.keys())
    if len(params) < 3 or params[1] != "request":
        return
    if getattr(template_response, "_vectorllm_compat", False):
        return

    def _compat_template_response(self, *args, **kwargs):
        # A leading string means the caller used the (name, context) order.
        if args and isinstance(args[0], str):
            name = args[0]
            context = args[1] if len(args) > 1 else kwargs.pop("context", None)
            if context is None:
                context = {}
            if not isinstance(context, dict):
                raise TypeError("TemplateResponse context must be a dict.")
            request = kwargs.pop("request", None) or context.get("request")
            if request is None:
                raise TypeError("TemplateResponse request is required.")
            return template_response(
                self,
                request,
                name,
                context,
                *args[2:],
                **kwargs,
            )
        return template_response(self, *args, **kwargs)

    # Marker checked above so the wrapper is never wrapped twice.
    _compat_template_response._vectorllm_compat = True
    Jinja2Templates.TemplateResponse = _compat_template_response


_patch_starlette_template_response()

SCRIPT_DIR = Path(__file__).resolve().parent
# NOTE(review): "VecorLLM" may be a typo for "VectorLLM"; kept as-is because it
# must match the real directory name - confirm.
REPO_ROOT = next((parent for parent in SCRIPT_DIR.parents if parent.name == "VecorLLM"), SCRIPT_DIR)
# Use the script's own folder when it already looks like an export (has config.json).
DEFAULT_EXPORT_DIR = (
    SCRIPT_DIR
    if (SCRIPT_DIR / "config.json").exists()
    else (REPO_ROOT.parent / "hf_model" / "vectorllm_hf_0407")
)
TORCH_DTYPE_MAP = {
    "auto": "auto",
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}
# Matches coordinate tokens such as <x12> / <y34>; group 1 is the axis, group 2 the value.
COORD_PATTERN = re.compile(r"<([xy])(\d+)>")
DEFAULT_PAD_COLOR = (109, 104, 75)
# NOTE(review): empty in this view; the token text may have been lost - confirm.
PIXEL_TOKEN = ""
BUILDING_RAW_PROMPT = (
    "<|im_start|>user\n\nPlease extract the regular vector contour of the central building in the image, "
    "start from the left top corner and in clockwise.<|im_end|>\n<|im_start|>assistant\n"
)
OBJECT_RAW_PROMPT = (
    "<|im_start|>user\n\nPlease extract the contour of the central object in the image, "
    "start from the left top corner and in clockwise.<|im_end|>\n<|im_start|>assistant\n"
)
# NOTE(review): the two literals below contain no markup/script in this view,
# yet build_demo() relies on window.vectorllmCanvasAnnotator and on the hidden
# textboxes being filled by the page - confirm against the deployed file.
CANVAS_ANNOTATOR_HTML = """
Drag on the image to add one or more bounding boxes.
Upload an image to start drawing.
"""
CANVAS_ANNOTATOR_HEAD = """ """

# Populated once by main(); read by the inference helpers.
HF_MODEL = None
HF_PROCESSOR = None
HF_TOKENIZER = None
HF_GENERATION_CONFIG = None


class StopWordStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the decoded text ends with a stop word."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return True when the first sequence, minus CR/LF, ends with the stop word."""
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace("\r", "").replace("\n", "")
        return cur_text[-self.length:] == self.stop_word


def get_stop_criteria(tokenizer, stop_words=None):
    """Build a ``StoppingCriteriaList`` with one criterion per stop word."""
    stop_criteria = StoppingCriteriaList()
    stop_criteria.extend(StopWordStoppingCriteria(tokenizer, word) for word in (stop_words or []))
    return stop_criteria


def parse_args():
    """Parse the command-line options of the demo."""
    parser = argparse.ArgumentParser(description="VectorLLM HF Gradio demo with full-image bbox cropping.")
    parser.add_argument(
        "--model-path",
        default=str(DEFAULT_EXPORT_DIR),
        help="Local HF export directory. If this script is copied into the export folder, the folder itself is used.",
    )
    parser.add_argument(
        "--dtype",
        choices=sorted(TORCH_DTYPE_MAP.keys()),
        default="auto",
        help="Model dtype on CUDA. CPU uses fp32 automatically.",
    )
    parser.add_argument("--max-new-tokens", type=int, default=640)
    parser.add_argument("--server-name", default="0.0.0.0")
    parser.add_argument("--server-port", type=int, default=7861)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def bootstrap_local_registry(model_path):
    """Import the export directory as a Python package.

    The directory's parent is prepended to ``sys.path`` (if absent) and the
    directory name is imported as a module.
    """
    model_path = Path(model_path).expanduser().resolve()
    parent = str(model_path.parent)
    package_name = model_path.name
    if parent not in sys.path:
        sys.path.insert(0, parent)
    # NOTE(review): presumably the package registers its custom classes with the
    # Auto* factories on import, which is why trust_remote_code stays False - confirm.
    importlib.import_module(package_name)


def build_generation_config(model_path, tokenizer, max_new_tokens):
    """Return a greedy-decoding ``GenerationConfig`` for the demo.

    Starts from the config stored in ``model_path`` when it can be loaded,
    otherwise from a default ``GenerationConfig``.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception:
        # Best effort: an export without a usable generation config still works.
        generation_config = GenerationConfig()
    generation_config.max_new_tokens = max_new_tokens
    generation_config.use_cache = True
    # Greedy decoding; the sampling knobs are cleared alongside do_sample=False.
    generation_config.do_sample = False
    generation_config.temperature = None
    generation_config.top_k = None
    generation_config.top_p = None
    generation_config.eos_token_id = tokenizer.eos_token_id
    # Compare against None explicitly: a pad token id of 0 is valid but falsy.
    pad_token_id = tokenizer.pad_token_id
    generation_config.pad_token_id = tokenizer.eos_token_id if pad_token_id is None else pad_token_id
    return generation_config


def init_model(model_path, dtype_name, max_new_tokens):
    """Load model, processor, tokenizer and generation config from ``model_path``.

    On CUDA the model uses the dtype selected by ``dtype_name`` and is moved
    to the GPU; without CUDA it is loaded in fp32.
    """
    bootstrap_local_registry(model_path)
    use_cuda = torch.cuda.is_available()
    torch_dtype = TORCH_DTYPE_MAP[dtype_name] if use_cuda else torch.float32
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=False,
        dtype=torch_dtype,
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=False)
    tokenizer = processor.tokenizer
    if use_cuda:
        model = model.cuda()
    model.eval()
    generation_config = build_generation_config(model_path, tokenizer, max_new_tokens)
    return model, processor, tokenizer, generation_config


def load_image_source(image_source):
    """Return ``image_source`` as an RGB ``PIL.Image``.

    Accepts a PIL image, a ``data:image...`` URL with a base64 payload, or
    any other non-blank string, which is opened as a path.

    Raises:
        ValueError: if the input is missing, blank, of an unsupported type,
            or cannot be decoded/opened as an image.
    """
    if image_source is None:
        raise ValueError("Please upload an image first.")
    if isinstance(image_source, Image.Image):
        return image_source.convert("RGB")
    if not isinstance(image_source, str):
        raise ValueError("Unsupported image input.")
    if not image_source.strip():
        raise ValueError("Please upload an image first.")
    try:
        if image_source.startswith("data:image"):
            _, encoded = image_source.split(",", 1)
            source = io.BytesIO(base64.b64decode(encoded))
        else:
            # NOTE(review): a plain string is opened as a filesystem path; if the
            # value can originate from the browser, consider rejecting paths.
            source = image_source
        with Image.open(source) as opened:
            return opened.convert("RGB")
    except (OSError, ValueError) as exc:
        # Surface unreadable files and malformed payloads as ValueError so the
        # UI handler, which only catches ValueError, can show a message.
        raise ValueError(f"Could not load the image: {exc}") from exc


def normalize_bbox(bbox):
    """Return ``[x1, y1, x2, y2]`` as floats with ``x1 <= x2`` and ``y1 <= y2``."""
    x1, y1, x2, y2 = bbox
    x1, x2 = sorted((float(x1), float(x2)))
    y1, y2 = sorted((float(y1), float(y2)))
    return [x1, y1, x2, y2]


def is_valid_bbox(bbox, min_size=2.0):
    """Return True if the box is finite and at least ``min_size`` wide and tall."""
    x1, y1, x2, y2 = normalize_bbox(bbox)
    # inf would pass the size test and later overflow int() in expand_bbox.
    if not all(math.isfinite(value) for value in (x1, y1, x2, y2)):
        return False
    return (x2 - x1) >= min_size and (y2 - y1) >= min_size


def parse_bbox_text(raw_text):
    """Parse ``x1,y1,x2,y2`` entries separated by ``;`` or newlines.

    Returns a list of normalized boxes; blank input yields an empty list.

    Raises:
        ValueError: if any entry is malformed or fails ``is_valid_bbox``.
    """
    if raw_text is None or not raw_text.strip():
        return []

    bbox_entries = []
    invalid_entries = []
    for chunk in re.split(r"[;\n]+", raw_text):
        entry = chunk.strip()
        if not entry:
            continue
        parts = [part.strip() for part in entry.split(",")]
        if len(parts) != 4:
            invalid_entries.append(entry)
            continue
        try:
            bbox = [float(part) for part in parts]
        except ValueError:
            invalid_entries.append(entry)
            continue
        bbox = normalize_bbox(bbox)
        if is_valid_bbox(bbox):
            bbox_entries.append(bbox)
        else:
            invalid_entries.append(entry)

    if invalid_entries:
        raise ValueError(
            "Invalid bbox entries: "
            + "; ".join(invalid_entries)
            + ". Use x1,y1,x2,y2 with width/height >= 2."
        )
    return bbox_entries


def get_grid_size():
    """Return the image processor's ``resized_size`` as int, defaulting to 128."""
    image_processor = getattr(HF_PROCESSOR, "image_processor", None)
    if image_processor is None:
        return 128
    resized_size = getattr(image_processor, "resized_size", 128)
    return int(resized_size)


def get_pad_color():
    """Return the RGB padding colour derived from the processor's ``image_mean``.

    Falls back to ``DEFAULT_PAD_COLOR`` when no processor or mean with at
    least three entries is available.
    """
    image_processor = getattr(HF_PROCESSOR, "image_processor", None)
    if image_processor is None:
        return DEFAULT_PAD_COLOR
    image_mean = getattr(image_processor, "image_mean", None)
    if image_mean is None or len(image_mean) < 3:
        return DEFAULT_PAD_COLOR
    pad_color = []
    for value in image_mean[:3]:
        value = float(value)
        # Values up to 1.0 are treated as normalized and scaled to 0..255.
        if value <= 1.0:
            value = value * 255.0
        pad_color.append(int(round(min(max(value, 0.0), 255.0))))
    return tuple(pad_color)


def get_raw_prompt(subject):
    """Return the object prompt for ``"object"``, otherwise the building prompt."""
    if subject == "object":
        return OBJECT_RAW_PROMPT
    return BUILDING_RAW_PROMPT


def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the generated text, with or without the prompt prefix removed.

    Both the whole first sequence and the part after the prompt length are
    decoded. The sliced text is returned when it is non-empty and contains at
    least as many coordinate tokens as the full text; otherwise the full text.
    """
    prompt_length = model_inputs["input_ids"].shape[1]
    sequences = output.sequences if hasattr(output, "sequences") else output

    def _clean(text):
        return (
            text.replace("<|im_end|>", "")
            .replace("<|endoftext|>", "")
            # NOTE(review): replacing "" is a no-op; the token to strip may have
            # been lost from this literal - confirm.
            .replace("", "")
            .strip()
        )

    full_text = _clean(tokenizer.decode(sequences[0], skip_special_tokens=False))
    sliced_text = ""
    if sequences.shape[1] > prompt_length:
        sliced_text = _clean(tokenizer.decode(sequences[0][prompt_length:], skip_special_tokens=False))

    full_score = len(COORD_PATTERN.findall(full_text))
    sliced_score = len(COORD_PATTERN.findall(sliced_text))
    if sliced_score >= full_score and sliced_text:
        return sliced_text
    return full_text


def parse_polygon(text):
    """Extract ``(x, y)`` integer points from ``<xN><yM>`` tokens in ``text``.

    A ``y`` token is paired with the most recent unpaired ``x`` token; ``y``
    tokens with no pending ``x`` are ignored.
    """
    points = []
    pending_x = None
    for axis, raw_value in COORD_PATTERN.findall(text):
        value = int(raw_value)
        if axis == "x":
            pending_x = value
        elif pending_x is not None:
            points.append((pending_x, value))
            pending_x = None
    return points


def expand_bbox(bbox, expand_ratio):
    """Scale the box about its centre and return integer ``[x1, y1, x2, y2]``.

    The lower corner is floored and the upper corner ceiled; the result is
    forced to be at least one pixel wide and tall.
    """
    x1, y1, x2, y2 = normalize_bbox(bbox)
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    width = (x2 - x1) * float(expand_ratio)
    height = (y2 - y1) * float(expand_ratio)
    expanded = [
        int(math.floor(cx - width / 2.0)),
        int(math.floor(cy - height / 2.0)),
        int(math.ceil(cx + width / 2.0)),
        int(math.ceil(cy + height / 2.0)),
    ]
    if expanded[2] <= expanded[0]:
        expanded[2] = expanded[0] + 1
    if expanded[3] <= expanded[1]:
        expanded[3] = expanded[1] + 1
    return expanded


def crop_image_by_bbox(image, bbox, expand_ratio):
    """Crop the expanded box from ``image``, padding parts outside the image.

    Returns ``(crop_image, expanded_bbox)``. The crop always has the size of
    the expanded box; areas outside the image keep the pad colour.
    """
    expanded_bbox = expand_bbox(bbox, expand_ratio)
    crop_width = max(1, expanded_bbox[2] - expanded_bbox[0])
    crop_height = max(1, expanded_bbox[3] - expanded_bbox[1])
    crop_image = Image.new("RGB", (crop_width, crop_height), get_pad_color())

    # Part of the expanded box that actually overlaps the image.
    src_x1 = max(0, expanded_bbox[0])
    src_y1 = max(0, expanded_bbox[1])
    src_x2 = min(image.size[0], expanded_bbox[2])
    src_y2 = min(image.size[1], expanded_bbox[3])

    if src_x2 > src_x1 and src_y2 > src_y1:
        region = image.crop((src_x1, src_y1, src_x2, src_y2))
        crop_image.paste(region, (src_x1 - expanded_bbox[0], src_y1 - expanded_bbox[1]))
    return crop_image, expanded_bbox


def recover_polygon(points, image_size, grid_size, offset_x=0.0, offset_y=0.0):
    """Map grid-cell coordinates to pixel coordinates, then add the offsets.

    Each grid index is mapped to the centre of its cell
    (``(index + 0.5) / grid_size * extent``).
    """
    image_w, image_h = image_size
    return [
        (
            (float(x_coord) + 0.5) / grid_size * image_w + offset_x,
            (float(y_coord) + 0.5) / grid_size * image_h + offset_y,
        )
        for x_coord, y_coord in points
    ]


def clamp_polygon(polygon, image_size):
    """Clamp every point to ``[0, width - 1] x [0, height - 1]``."""
    image_w, image_h = image_size
    return [
        (
            min(max(float(x_coord), 0.0), image_w - 1.0),
            min(max(float(y_coord), 0.0), image_h - 1.0),
        )
        for x_coord, y_coord in polygon
    ]


def _draw_polygon_with_vertices(drawer, polygon, fill):
    """Draw the polygon (if it has >= 3 points) and a dot on each vertex.

    Returns the integer-rounded vertex list.
    """
    polygon_points = [(int(round(x)), int(round(y))) for x, y in polygon]
    if len(polygon_points) >= 3:
        drawer.polygon(
            polygon_points,
            outline=(255, 0, 255, 255),
            fill=fill,
            width=2,
        )
    for x_coord, y_coord in polygon_points:
        drawer.ellipse((x_coord - 2, y_coord - 2, x_coord + 2, y_coord + 2), fill=(255, 165, 0, 255))
    return polygon_points


def draw_crop_polygon(image, polygon):
    """Return an RGB copy of ``image`` with the polygon drawn over it."""
    rendered = image.convert("RGBA")
    overlay = Image.new("RGBA", rendered.size, (0, 0, 0, 0))
    drawer = ImageDraw.Draw(overlay)
    _draw_polygon_with_vertices(drawer, polygon, fill=(0, 255, 255, 90))
    return Image.alpha_composite(rendered, overlay).convert("RGB")


def draw_full_overlay(image, results):
    """Return an RGB copy of ``image`` annotated with every result.

    For each result the expanded box, the original box, the polygon with its
    vertices and the result index are drawn.
    """
    rendered = image.convert("RGBA")
    overlay = Image.new("RGBA", rendered.size, (0, 0, 0, 0))
    drawer = ImageDraw.Draw(overlay)

    for result in results:
        bbox = result["bbox"]
        expanded_bbox = result["expanded_bbox"]
        polygon = result["polygon"]
        index = result["index"]

        drawer.rectangle(
            [tuple(expanded_bbox[:2]), tuple(expanded_bbox[2:])],
            outline=(255, 191, 0, 255),
            width=2,
        )
        drawer.rectangle(
            [tuple(bbox[:2]), tuple(bbox[2:])],
            outline=(0, 255, 127, 255),
            width=2,
        )

        polygon_points = _draw_polygon_with_vertices(drawer, polygon, fill=(0, 255, 255, 72))

        # Label at the first vertex, or at the box corner when there is no polygon.
        anchor_x, anchor_y = polygon_points[0] if polygon_points else (int(round(bbox[0])), int(round(bbox[1])))
        drawer.text((anchor_x + 4, anchor_y + 4), str(index), fill=(255, 255, 255, 255))

    return Image.alpha_composite(rendered, overlay).convert("RGB")


def format_text_outputs(results):
    """Join the raw model text of all results into one display string."""
    if not results:
        return "No model output."
    return "\n\n".join(f"[BBox {result['index']}]\n{result['text']}" for result in results)


def build_report(image, results, subject, expand_ratio):
    """Build the JSON-serializable summary shown in the structured output."""
    return {
        "image_size": list(image.size),
        "subject": subject,
        "expand_ratio": float(expand_ratio),
        "grid_size": get_grid_size(),
        "results": [
            {
                "index": result["index"],
                "bbox": result["bbox"],
                "expanded_bbox": result["expanded_bbox"],
                "crop_size": list(result["crop_image"].size),
                "text": result["text"],
                "grid_polygon": result["grid_polygon"],
                "crop_polygon": result["crop_polygon"],
                "polygon": result["polygon"],
            }
            for result in results
        ],
    }


def predict_single_bbox(image, bbox, expand_ratio, subject):
    """Run the model on one expanded crop and return its result dict.

    The dict holds the box, expanded box, crop image, raw text and the
    polygon in grid, crop and full-image coordinates. The full-image polygon
    is clamped to the image bounds. The ``index`` key is added by the caller.
    """
    crop_image, expanded_bbox = crop_image_by_bbox(image, bbox, expand_ratio)
    prompt = get_raw_prompt(subject)

    model_inputs = HF_PROCESSOR(text=[prompt], images=[crop_image], return_tensors="pt")
    model_inputs = {
        key: value.to(HF_MODEL.device) if torch.is_tensor(value) else value
        for key, value in model_inputs.items()
    }

    stop_criteria = get_stop_criteria(HF_TOKENIZER, ["<|im_end|>", "<|endoftext|>"])
    with torch.inference_mode():
        output = HF_MODEL.generate(
            **model_inputs,
            generation_config=HF_GENERATION_CONFIG,
            bos_token_id=HF_TOKENIZER.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True,
            use_cache=True,
        )

    text = decode_generated_text(output, model_inputs, HF_TOKENIZER)
    grid_polygon = parse_polygon(text)
    grid_size = get_grid_size()
    crop_polygon = recover_polygon(grid_polygon, crop_image.size, grid_size)
    full_polygon = recover_polygon(
        grid_polygon,
        crop_image.size,
        grid_size,
        offset_x=float(expanded_bbox[0]),
        offset_y=float(expanded_bbox[1]),
    )
    full_polygon = clamp_polygon(full_polygon, image.size)

    return {
        "bbox": [float(v) for v in bbox],
        "expanded_bbox": [int(v) for v in expanded_bbox],
        "crop_image": crop_image,
        "text": text,
        "grid_polygon": [[int(x), int(y)] for x, y in grid_polygon],
        "crop_polygon": [[float(x), float(y)] for x, y in crop_polygon],
        "polygon": [[float(x), float(y)] for x, y in full_polygon],
    }


def run_inference(image, bboxes, expand_ratio, subject):
    """Predict every box and return ``(overlay, crop_gallery, text, report)``."""
    results = []
    crop_gallery = []
    for index, bbox in enumerate(bboxes, start=1):
        result = predict_single_bbox(image, bbox, expand_ratio, subject)
        result["index"] = index
        results.append(result)
        crop_overlay = draw_crop_polygon(result["crop_image"], result["crop_polygon"])
        crop_gallery.append((crop_overlay, f"BBox {index} | expand={expand_ratio:.2f}"))

    overlay = draw_full_overlay(image, results)
    report = build_report(image, results, subject, expand_ratio)
    return overlay, crop_gallery, format_text_outputs(results), report


def inference_canvas(image_data, bbox_text, expand_ratio, subject):
    """Gradio handler: validate the inputs, then run inference.

    Input errors are returned as the text output instead of being raised.
    """
    try:
        image = load_image_source(image_data)
    except ValueError as exc:
        return None, [], str(exc), None

    try:
        bboxes = parse_bbox_text(bbox_text)
    except ValueError as exc:
        return image, [], str(exc), None

    if not bboxes:
        return image, [], "Please drag at least one valid bbox on the image.", None
    return run_inference(image, bboxes, expand_ratio, subject)


def clear_outputs():
    """Return empty values for the four output components."""
    return None, [], "", None


def build_demo():
    """Construct the Gradio Blocks UI and wire up the Run/Clear buttons."""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="VectorLLM HF Full-Image BBox Demo",
        head=CANVAS_ANNOTATOR_HEAD,
    ) as demo:
        gr.Markdown("# VectorLLM HF Full-Image BBox Demo")
        gr.Markdown(
            "Upload a full image, draw one or more bboxes, choose an expand ratio between 1.0 and 1.3, "
            "then run VectorLLM on the cropped regions and project the predicted polygon back to the full image."
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML(CANVAS_ANNOTATOR_HTML)
                # Hidden fields that carry the image and boxes to the handler;
                # located from the page through their elem_id values.
                hidden_image_data = gr.Textbox(visible=False, elem_id="vectorllm-hidden-image-data")
                hidden_bbox_text = gr.Textbox(visible=False, elem_id="vectorllm-hidden-bboxes")
                subject = gr.Radio(
                    choices=[("Building", "building"), ("Object", "object")],
                    value="building",
                    label="Prompt Target",
                )
                expand_ratio = gr.Slider(
                    minimum=1.0,
                    maximum=1.3,
                    value=1.15,
                    step=0.01,
                    label="BBox Expand Ratio",
                )
                with gr.Row():
                    run_button = gr.Button("Run", variant="primary")
                    clear_button = gr.Button("Clear")

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Full-Image Overlay", height=520)
                crop_gallery = gr.Gallery(
                    label="Expanded Crop Preview",
                    columns=2,
                    height=240,
                    object_fit="contain",
                )
                output_text = gr.Textbox(label="Model Text Output", lines=12)
                output_json = gr.JSON(label="Structured Result")

        run_button.click(
            inference_canvas,
            inputs=[hidden_image_data, hidden_bbox_text, expand_ratio, subject],
            outputs=[output_image, crop_gallery, output_text, output_json],
            show_api=False,
        )
        clear_button.click(
            clear_outputs,
            outputs=[output_image, crop_gallery, output_text, output_json],
            show_api=False,
            js=""" () => { if (window.vectorllmCanvasAnnotator) { window.vectorllmCanvasAnnotator.reset(); } return []; } """,
        )
    return demo


def main():
    """Load the model named on the command line and launch the demo server.

    Raises:
        FileNotFoundError: if the resolved model path does not exist.
    """
    global HF_MODEL, HF_PROCESSOR, HF_TOKENIZER, HF_GENERATION_CONFIG

    args = parse_args()
    model_path = Path(args.model_path).expanduser().resolve()
    if not model_path.exists():
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    HF_MODEL, HF_PROCESSOR, HF_TOKENIZER, HF_GENERATION_CONFIG = init_model(
        str(model_path),
        args.dtype,
        args.max_new_tokens,
    )

    demo = build_demo()
    demo.queue()
    demo.launch(
        share=args.share,
        server_name=args.server_name,
        server_port=args.server_port,
    )


if __name__ == "__main__":
    main()