# vectorllm_v1 / gradio_bbox_demo.py
# Uploaded by insomnia7 using huggingface_hub (commit bcc6605, verified; 34.1 kB).
import argparse
import base64
import importlib.util
import inspect
import io
import json
import math
import os
import re
import sys
from pathlib import Path
def _disable_invalid_socks_proxy():
    """Remove SOCKS proxy settings from the environment when ``socksio`` is unavailable.

    Nothing happens if ``socksio`` can be found. Otherwise each proxy variable in the list
    below whose value starts with "socks" (case-insensitive) is removed from ``os.environ``.
    Presumably this keeps HTTP clients used by the later imports from failing on a SOCKS
    proxy they cannot speak — NOTE(review): confirm which client requires socksio.
    """
    if importlib.util.find_spec("socksio") is not None:
        return
    proxy_keys = (
        "http_proxy",
        "https_proxy",
        "all_proxy",
        "HTTP_PROXY",
        "HTTPS_PROXY",
        "ALL_PROXY",
    )
    socks_keys = [
        key
        for key in proxy_keys
        if (os.environ.get(key) or "").lower().startswith("socks")
    ]
    for key in socks_keys:
        os.environ.pop(key, None)


# Runs at import time, before the third-party imports that follow.
_disable_invalid_socks_proxy()
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from starlette.templating import Jinja2Templates
from transformers import AutoModel, AutoProcessor, GenerationConfig, StoppingCriteria, StoppingCriteriaList
def _patch_starlette_template_response():
    """Let ``Jinja2Templates.TemplateResponse`` accept the legacy ``(name, context)`` call form.

    The patch applies only when the installed ``TemplateResponse`` declares ``request`` as
    its first parameter after ``self``. It then replaces the method with a wrapper. The
    wrapper rewrites calls whose first positional argument is a template name string into
    ``TemplateResponse(request, name, context, ...)``. ``request`` is taken from ``kwargs``
    or from ``context["request"]``. All other call forms are forwarded unchanged.

    NOTE(review): presumably needed because the installed Gradio still uses the legacy
    call form against a Starlette that only accepts the request-first one — confirm against
    the pinned versions.

    Raises (from the wrapper):
        TypeError: if a legacy-form context is not a dict, or no request can be found.
    """
    template_response = Jinja2Templates.TemplateResponse
    # Looked up on the class, so params[0] is "self" and params[1] is the first real argument.
    params = tuple(inspect.signature(template_response).parameters.keys())
    if len(params) < 3 or params[1] != "request":
        return
    # Skip if this wrapper is already installed.
    if getattr(template_response, "_vectorllm_compat", False):
        return
    def _compat_template_response(self, *args, **kwargs):
        # Legacy form: the first positional argument is the template name.
        if args and isinstance(args[0], str):
            name = args[0]
            context = args[1] if len(args) > 1 else kwargs.pop("context", None)
            if context is None:
                context = {}
            if not isinstance(context, dict):
                raise TypeError("TemplateResponse context must be a dict.")
            request = kwargs.pop("request", None) or context.get("request")
            if request is None:
                raise TypeError("TemplateResponse request is required.")
            # Remaining legacy positionals (args[2:]) are passed on positionally after
            # context; assumes both signatures continue in the same order — TODO confirm.
            return template_response(
                self,
                request,
                name,
                context,
                *args[2:],
                **kwargs,
            )
        return template_response(self, *args, **kwargs)
    _compat_template_response._vectorllm_compat = True
    Jinja2Templates.TemplateResponse = _compat_template_response


# Applied at import time, before any Gradio app is built.
_patch_starlette_template_response()
SCRIPT_DIR = Path(__file__).resolve().parent
# First ancestor directory literally named "VecorLLM"; falls back to the script's folder.
# NOTE(review): "VecorLLM" lacks the "t" of "VectorLLM" — confirm it matches the real
# checkout folder name (if it never matches, REPO_ROOT is simply SCRIPT_DIR).
REPO_ROOT = next(
    (ancestor for ancestor in SCRIPT_DIR.parents if ancestor.name == "VecorLLM"),
    SCRIPT_DIR,
)
# Use the script's own folder when it is itself an HF export (has config.json next to it);
# otherwise default to <parent of REPO_ROOT>/hf_model/vectorllm_hf_0407.
if (SCRIPT_DIR / "config.json").exists():
    DEFAULT_EXPORT_DIR = SCRIPT_DIR
else:
    DEFAULT_EXPORT_DIR = REPO_ROOT.parent / "hf_model" / "vectorllm_hf_0407"
# --dtype CLI choices -> value passed as ``dtype`` to from_pretrained when CUDA is available.
TORCH_DTYPE_MAP = dict(
    auto="auto",
    bf16=torch.bfloat16,
    fp16=torch.float16,
    fp32=torch.float32,
)
# Coordinate tokens such as "<x12>" / "<y40>" in the model output; groups are (axis, digits).
COORD_PATTERN = re.compile(r"<([xy])(\d+)>")
# RGB padding colour for crop areas outside the image when no usable image_mean exists.
DEFAULT_PAD_COLOR = (109, 104, 75)
# Image placeholder token; the prompts below embed the same literal text.
PIXEL_TOKEN = "<pixel>"
# Raw chat-formatted prompts (one user turn with the image placeholder, then the
# assistant header) for the "building" and "object" prompt targets.
BUILDING_RAW_PROMPT = (
    "<|im_start|>user\n"
    "<pixel>\n"
    "Please extract the regular vector contour of the central building in the image, "
    "start from the left top corner and in clockwise."
    "<|im_end|>\n"
    "<|im_start|>assistant\n"
)
OBJECT_RAW_PROMPT = (
    "<|im_start|>user\n"
    "<pixel>\n"
    "Please extract the contour of the central object in the image, "
    "start from the left top corner and in clockwise."
    "<|im_end|>\n"
    "<|im_start|>assistant\n"
)
# Markup for the client-side bbox annotator, rendered with gr.HTML in build_demo().
# The script in CANVAS_ANNOTATOR_HEAD looks these elements up by id
# (vectorllm-canvas-file, -stage, -surface, -empty, -undo, -reset, -status, -boxlist),
# so the ids must stay in sync with that script.
CANVAS_ANNOTATOR_HTML = """
<div id="vectorllm-canvas-annotator" class="vectorllm-canvas-annotator">
<div class="vectorllm-canvas-toolbar">
<label class="vectorllm-upload-button">
<input type="file" id="vectorllm-canvas-file" accept="image/*">
<span>Upload Image</span>
</label>
<button type="button" id="vectorllm-canvas-undo">Undo Last Box</button>
<button type="button" id="vectorllm-canvas-reset">Clear Boxes</button>
<span class="vectorllm-canvas-hint">Drag on the image to add one or more bounding boxes.</span>
</div>
<div id="vectorllm-canvas-stage" class="vectorllm-canvas-stage">
<canvas id="vectorllm-canvas-surface"></canvas>
<div id="vectorllm-canvas-empty" class="vectorllm-canvas-empty">
Upload an image to start drawing.
</div>
</div>
<div class="vectorllm-canvas-footer">
<div id="vectorllm-canvas-status">No image selected.</div>
<pre id="vectorllm-canvas-boxlist" class="vectorllm-canvas-boxlist">No boxes yet.</pre>
</div>
</div>
"""
# Injected into the page <head> via gr.Blocks(head=...). It provides styles for the
# annotator in CANVAS_ANNOTATOR_HTML and the canvas logic:
#   * an uploaded image is read as a data URL and drawn scaled to fit the stage
#     (at most 520 px tall);
#   * pointer drags create boxes, stored in ORIGINAL image pixels (display coordinates
#     divided by the render scale); boxes smaller than 2 px in either side are dropped;
#   * after each change, the image data URL and one "x1,y1,x2,y2" line per box are written
#     into the hidden Gradio textboxes with elem_id "vectorllm-hidden-image-data" and
#     "vectorllm-hidden-bboxes", followed by "input" and "change" events;
#   * window.vectorllmCanvasAnnotator exposes getImageData/getBoxesText/reset/resetBoxes
#     (reset() is called by the Clear button's js hook in build_demo()).
# initAnnotator is retried from a MutationObserver, presumably because Gradio mounts the
# HTML after this head script runs; the data-initialized flag prevents double setup.
# "\\n" is doubled so the JavaScript source receives a "\n" escape.
# NOTE(review): "pointercancel" is not handled, so a cancelled drag leaves a pending draft.
CANVAS_ANNOTATOR_HEAD = """
<style>
.vectorllm-canvas-annotator {
display: flex;
flex-direction: column;
gap: 12px;
}
.vectorllm-canvas-toolbar {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
}
.vectorllm-upload-button,
.vectorllm-canvas-toolbar button {
border: 1px solid #c7d2fe;
border-radius: 999px;
background: #eef2ff;
color: #1f2937;
cursor: pointer;
font: inherit;
font-size: 14px;
padding: 8px 14px;
}
.vectorllm-upload-button {
display: inline-flex;
align-items: center;
position: relative;
overflow: hidden;
}
.vectorllm-upload-button input {
cursor: pointer;
inset: 0;
opacity: 0;
position: absolute;
}
.vectorllm-canvas-hint {
color: #4b5563;
font-size: 13px;
}
.vectorllm-canvas-stage {
align-items: center;
background:
linear-gradient(135deg, rgba(148, 163, 184, 0.14), rgba(59, 130, 246, 0.08)),
#f8fafc;
border: 1px solid #dbe4f0;
border-radius: 16px;
display: flex;
height: 520px;
justify-content: center;
overflow: hidden;
position: relative;
width: 100%;
}
#vectorllm-canvas-surface {
cursor: crosshair;
display: none;
max-height: 100%;
max-width: 100%;
touch-action: none;
}
.vectorllm-canvas-empty {
color: #64748b;
font-size: 14px;
padding: 24px;
text-align: center;
}
.vectorllm-canvas-footer {
display: flex;
flex-direction: column;
gap: 8px;
}
#vectorllm-canvas-status {
color: #334155;
font-size: 14px;
}
.vectorllm-canvas-boxlist {
background: #f8fafc;
border: 1px solid #dbe4f0;
border-radius: 12px;
color: #0f172a;
font-family: "IBM Plex Mono", monospace;
font-size: 12px;
margin: 0;
max-height: 120px;
overflow: auto;
padding: 10px 12px;
white-space: pre-wrap;
}
</style>
<script>
(() => {
const rootId = "vectorllm-canvas-annotator";
const fileId = "vectorllm-canvas-file";
const stageId = "vectorllm-canvas-stage";
const canvasId = "vectorllm-canvas-surface";
const emptyId = "vectorllm-canvas-empty";
const undoId = "vectorllm-canvas-undo";
const resetId = "vectorllm-canvas-reset";
const statusId = "vectorllm-canvas-status";
const boxListId = "vectorllm-canvas-boxlist";
const maxHeight = 520;
function clamp(value, minValue, maxValue) {
return Math.min(Math.max(value, minValue), maxValue);
}
function normalizeBox(box) {
const x1 = Math.min(box.x1, box.x2);
const y1 = Math.min(box.y1, box.y2);
const x2 = Math.max(box.x1, box.x2);
const y2 = Math.max(box.y1, box.y2);
if ((x2 - x1) < 2 || (y2 - y1) < 2) {
return null;
}
return { x1, y1, x2, y2 };
}
function formatBoxes(boxes) {
return boxes.map((box) => {
return [
Math.round(box.x1),
Math.round(box.y1),
Math.round(box.x2),
Math.round(box.y2),
].join(",");
}).join("\\n");
}
function formatBoxList(boxes) {
if (!boxes.length) {
return "No boxes yet.";
}
return boxes.map((box, index) => {
return `${index + 1}: ${Math.round(box.x1)},${Math.round(box.y1)},${Math.round(box.x2)},${Math.round(box.y2)}`;
}).join("\\n");
}
function initAnnotator() {
const root = document.getElementById(rootId);
if (!root || root.dataset.initialized === "true") {
return;
}
root.dataset.initialized = "true";
const fileInput = document.getElementById(fileId);
const stage = document.getElementById(stageId);
const canvas = document.getElementById(canvasId);
const empty = document.getElementById(emptyId);
const undoButton = document.getElementById(undoId);
const resetButton = document.getElementById(resetId);
const status = document.getElementById(statusId);
const boxList = document.getElementById(boxListId);
if (!fileInput || !stage || !canvas || !empty || !undoButton || !resetButton || !status || !boxList) {
return;
}
const ctx = canvas.getContext("2d");
if (!ctx) {
return;
}
const state = {
imageData: "",
image: null,
boxes: [],
draft: null,
scale: 1,
displayWidth: 0,
displayHeight: 0,
};
function setHiddenValue(elemId, value) {
const container = document.getElementById(elemId);
if (!container) {
return;
}
const field = container.querySelector("textarea, input");
if (!field) {
return;
}
const prototype = field.tagName === "TEXTAREA"
? window.HTMLTextAreaElement.prototype
: window.HTMLInputElement.prototype;
const descriptor = Object.getOwnPropertyDescriptor(prototype, "value");
if (descriptor && descriptor.set) {
descriptor.set.call(field, value);
} else {
field.value = value;
}
field.dispatchEvent(new Event("input", { bubbles: true }));
field.dispatchEvent(new Event("change", { bubbles: true }));
}
function syncHiddenInputs() {
setHiddenValue("vectorllm-hidden-image-data", state.imageData || "");
setHiddenValue("vectorllm-hidden-bboxes", formatBoxes(state.boxes));
}
function updateStatus() {
if (!state.image) {
status.textContent = "No image selected.";
boxList.textContent = "No boxes yet.";
return;
}
if (!state.boxes.length) {
status.textContent = "Image loaded. Drag on the image to draw a bbox.";
boxList.textContent = "No boxes yet.";
return;
}
status.textContent = `${state.boxes.length} box(es) selected.`;
boxList.textContent = formatBoxList(state.boxes);
}
function render() {
if (!state.image) {
canvas.style.display = "none";
empty.style.display = "flex";
return;
}
const stageWidth = Math.max(stage.clientWidth - 24, 240);
const scale = Math.min(stageWidth / state.image.naturalWidth, maxHeight / state.image.naturalHeight);
state.scale = scale;
state.displayWidth = Math.max(1, Math.round(state.image.naturalWidth * scale));
state.displayHeight = Math.max(1, Math.round(state.image.naturalHeight * scale));
const dpr = window.devicePixelRatio || 1;
canvas.width = Math.round(state.displayWidth * dpr);
canvas.height = Math.round(state.displayHeight * dpr);
canvas.style.width = `${state.displayWidth}px`;
canvas.style.height = `${state.displayHeight}px`;
canvas.style.display = "block";
empty.style.display = "none";
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
ctx.clearRect(0, 0, state.displayWidth, state.displayHeight);
ctx.drawImage(state.image, 0, 0, state.displayWidth, state.displayHeight);
state.boxes.forEach((box, index) => {
const x = box.x1 * scale;
const y = box.y1 * scale;
const width = (box.x2 - box.x1) * scale;
const height = (box.y2 - box.y1) * scale;
ctx.fillStyle = "rgba(34, 197, 94, 0.14)";
ctx.strokeStyle = "rgba(15, 118, 110, 0.95)";
ctx.lineWidth = 2;
ctx.fillRect(x, y, width, height);
ctx.strokeRect(x, y, width, height);
ctx.fillStyle = "rgba(15, 23, 42, 0.92)";
ctx.font = "12px sans-serif";
ctx.fillText(String(index + 1), x + 6, y + 16);
});
if (state.draft) {
const draftBox = normalizeBox(state.draft);
if (draftBox) {
const x = draftBox.x1 * scale;
const y = draftBox.y1 * scale;
const width = (draftBox.x2 - draftBox.x1) * scale;
const height = (draftBox.y2 - draftBox.y1) * scale;
ctx.strokeStyle = "rgba(59, 130, 246, 0.95)";
ctx.fillStyle = "rgba(59, 130, 246, 0.12)";
ctx.setLineDash([6, 4]);
ctx.lineWidth = 2;
ctx.fillRect(x, y, width, height);
ctx.strokeRect(x, y, width, height);
ctx.setLineDash([]);
}
}
}
function getCanvasPoint(event) {
const rect = canvas.getBoundingClientRect();
const x = clamp(event.clientX - rect.left, 0, state.displayWidth);
const y = clamp(event.clientY - rect.top, 0, state.displayHeight);
return {
x: x / state.scale,
y: y / state.scale,
};
}
function commitDraft() {
if (!state.draft) {
return;
}
const draftBox = normalizeBox(state.draft);
state.draft = null;
if (draftBox) {
state.boxes.push(draftBox);
}
syncHiddenInputs();
render();
updateStatus();
}
function resetBoxes() {
state.boxes = [];
state.draft = null;
syncHiddenInputs();
render();
updateStatus();
}
function resetAll() {
state.imageData = "";
state.image = null;
state.boxes = [];
state.draft = null;
fileInput.value = "";
syncHiddenInputs();
render();
updateStatus();
}
function loadImage(dataUrl) {
const image = new Image();
image.onload = () => {
state.imageData = dataUrl;
state.image = image;
state.boxes = [];
state.draft = null;
syncHiddenInputs();
render();
updateStatus();
};
image.src = dataUrl;
}
fileInput.addEventListener("change", (event) => {
const file = event.target.files && event.target.files[0];
if (!file) {
return;
}
const reader = new FileReader();
reader.onload = () => {
if (typeof reader.result === "string") {
loadImage(reader.result);
}
};
reader.readAsDataURL(file);
});
canvas.addEventListener("pointerdown", (event) => {
if (!state.image) {
return;
}
const point = getCanvasPoint(event);
state.draft = { x1: point.x, y1: point.y, x2: point.x, y2: point.y };
if (canvas.setPointerCapture) {
canvas.setPointerCapture(event.pointerId);
}
render();
event.preventDefault();
});
canvas.addEventListener("pointermove", (event) => {
if (!state.draft) {
return;
}
const point = getCanvasPoint(event);
state.draft.x2 = point.x;
state.draft.y2 = point.y;
render();
});
canvas.addEventListener("pointerup", (event) => {
if (canvas.releasePointerCapture) {
try {
canvas.releasePointerCapture(event.pointerId);
} catch (error) {
}
}
commitDraft();
});
canvas.addEventListener("pointerleave", () => {
if (state.draft) {
render();
}
});
undoButton.addEventListener("click", () => {
if (state.boxes.length) {
state.boxes.pop();
syncHiddenInputs();
render();
updateStatus();
}
});
resetButton.addEventListener("click", () => {
resetBoxes();
});
window.addEventListener("resize", () => {
if (state.image) {
window.requestAnimationFrame(render);
}
});
window.vectorllmCanvasAnnotator = {
getImageData: () => state.imageData || "",
getBoxesText: () => formatBoxes(state.boxes),
reset: resetAll,
resetBoxes: resetBoxes,
};
render();
updateStatus();
}
window.initVectorLLMCanvasAnnotator = initAnnotator;
if (document.readyState === "loading") {
document.addEventListener("DOMContentLoaded", initAnnotator, { once: true });
} else {
window.setTimeout(initAnnotator, 0);
}
const observer = new MutationObserver(() => initAnnotator());
observer.observe(document.documentElement, { childList: true, subtree: true });
})();
</script>
"""
# Model, processor, tokenizer and generation config, set by main() from init_model()
# and read by the inference helpers as module globals.
HF_MODEL = HF_PROCESSOR = HF_TOKENIZER = HF_GENERATION_CONFIG = None
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded first sequence ends with ``stop_word``.

    "\\r" and "\\n" are removed from the decoded text before the suffix comparison.
    Only ``input_ids[0]`` is inspected, so this assumes batch size 1 (as in this demo).
    """

    # Extra tokens decoded beyond len(stop_word). This absorbs tokens that contribute no
    # compared characters (e.g. newline-only tokens) and boundary effects at the start of
    # the decoded window.
    _WINDOW_SLACK = 16

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(stop_word)
        # A stop word of N characters is normally produced by at most N tokens.
        self._window = self.length + self._WINDOW_SLACK

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        # Decode only the tail. Re-decoding the whole prompt + generation at every step
        # made the decode work grow quadratically with sequence length.
        tail_ids = input_ids[0][-self._window:]
        cur_text = self.tokenizer.decode(tail_ids)
        cur_text = cur_text.replace("\r", "").replace("\n", "")
        return cur_text[-self.length:] == self.stop_word
def get_stop_criteria(tokenizer, stop_words=None):
    """Build a ``StoppingCriteriaList`` with one ``StopWordStoppingCriteria`` per stop word.

    ``stop_words`` may be None or empty, in which case the returned list is empty.
    """
    criteria = StoppingCriteriaList()
    for stop_word in stop_words or ():
        criteria.append(StopWordStoppingCriteria(tokenizer, stop_word))
    return criteria
def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Options: ``--model-path`` (local HF export dir), ``--dtype`` (CUDA load dtype, one of
    TORCH_DTYPE_MAP's keys), ``--max-new-tokens``, ``--server-name``, ``--server-port`` and
    ``--share``.
    """
    cli = argparse.ArgumentParser(
        description="VectorLLM HF Gradio demo with full-image bbox cropping."
    )
    cli.add_argument(
        "--model-path",
        default=str(DEFAULT_EXPORT_DIR),
        help=(
            "Local HF export directory. "
            "If this script is copied into the export folder, the folder itself is used."
        ),
    )
    cli.add_argument(
        "--dtype",
        choices=sorted(TORCH_DTYPE_MAP),
        default="auto",
        help="Model dtype on CUDA. CPU uses fp32 automatically.",
    )
    simple_options = (
        ("--max-new-tokens", {"type": int, "default": 640}),
        ("--server-name", {"default": "0.0.0.0"}),
        ("--server-port", {"type": int, "default": 7861}),
        ("--share", {"action": "store_true"}),
    )
    for flag, options in simple_options:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def bootstrap_local_registry(model_path):
    """Import the model export directory as a top-level Python package.

    The directory's parent is prepended to ``sys.path`` (once), then the package named after
    the directory is imported. Presumably this lets the export's modules register their
    custom classes before ``AutoModel``/``AutoProcessor`` load it — TODO confirm.
    """
    export_dir = Path(model_path).expanduser().resolve()
    search_root = str(export_dir.parent)
    if search_root not in sys.path:
        sys.path.insert(0, search_root)
    importlib.import_module(export_dir.name)
def build_generation_config(model_path, tokenizer, max_new_tokens):
    """Create a greedy-decoding ``GenerationConfig`` for the demo.

    Starts from the export's saved generation config when ``GenerationConfig.from_pretrained``
    succeeds, otherwise from library defaults. It then sets ``max_new_tokens``, enables the
    KV cache, disables sampling and takes EOS/PAD ids from ``tokenizer``.

    Args:
        model_path: local export directory (or anything ``from_pretrained`` accepts).
        tokenizer: tokenizer providing ``eos_token_id`` and ``pad_token_id``.
        max_new_tokens: generation length limit.

    Returns:
        The configured ``GenerationConfig``.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception:
        # Best effort: a missing or unreadable saved config falls back to library defaults.
        generation_config = GenerationConfig()
    generation_config.max_new_tokens = max_new_tokens
    generation_config.use_cache = True
    # Greedy decoding; sampling parameters are cleared since they are unused.
    generation_config.do_sample = False
    generation_config.temperature = None
    generation_config.top_k = None
    generation_config.top_p = None
    generation_config.eos_token_id = tokenizer.eos_token_id
    # A pad id of 0 is valid, so only fall back to EOS when the tokenizer has no pad id
    # (the previous ``or`` also replaced a legitimate 0).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = pad_token_id
    return generation_config
def init_model(model_path, dtype_name, max_new_tokens):
    """Load the exported model and its processor.

    Args:
        model_path: local export directory.
        dtype_name: key of TORCH_DTYPE_MAP; only honoured on CUDA (CPU always loads fp32).
        max_new_tokens: forwarded to build_generation_config().

    Returns:
        ``(model, processor, tokenizer, generation_config)``. The model is in eval mode and
        moved to CUDA when available.
    """
    bootstrap_local_registry(model_path)
    on_gpu = torch.cuda.is_available()
    load_dtype = TORCH_DTYPE_MAP[dtype_name] if on_gpu else torch.float32
    model = AutoModel.from_pretrained(model_path, trust_remote_code=False, dtype=load_dtype)
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=False)
    tokenizer = processor.tokenizer
    if on_gpu:
        model = model.cuda()
    model.eval()
    return (
        model,
        processor,
        tokenizer,
        build_generation_config(model_path, tokenizer, max_new_tokens),
    )
def load_image_source(image_source):
    """Load an RGB image from a PIL image, a ``data:image/...;base64,`` URL, or a file path.

    Args:
        image_source: ``PIL.Image.Image``, a data-URL string, or a filesystem path string.

    Returns:
        The image converted to RGB.

    Raises:
        ValueError: when no image is given, the input type is unsupported, or a data URL is
            malformed or does not contain a decodable image. Path inputs keep Pillow's own
            errors (e.g. FileNotFoundError).
    """
    if image_source is None:
        raise ValueError("Please upload an image first.")
    if isinstance(image_source, Image.Image):
        return image_source.convert("RGB")
    if isinstance(image_source, str):
        if not image_source.strip():
            raise ValueError("Please upload an image first.")
        if image_source.startswith("data:image"):
            _, separator, encoded = image_source.partition(",")
            if not separator:
                raise ValueError("Malformed image data URL: missing ',' separator.")
            try:
                image_bytes = base64.b64decode(encoded)
            except ValueError as exc:  # binascii.Error subclasses ValueError
                raise ValueError("Image data URL is not valid base64.") from exc
            try:
                with Image.open(io.BytesIO(image_bytes)) as decoded:
                    return decoded.convert("RGB")
            except OSError as exc:  # includes PIL.UnidentifiedImageError and truncated data
                # Surface decode failures as ValueError so callers that report
                # ValueError messages (e.g. inference_canvas) can show them.
                raise ValueError("Could not decode the uploaded image.") from exc
        return Image.open(image_source).convert("RGB")
    raise ValueError("Unsupported image input.")
def normalize_bbox(bbox):
    """Return ``[x1, y1, x2, y2]`` as floats with each axis ordered min-then-max.

    ``bbox`` must unpack into exactly four values convertible with ``float()``.
    """
    first_x, first_y, second_x, second_y = bbox
    left, right = sorted((float(first_x), float(second_x)))
    top, bottom = sorted((float(first_y), float(second_y)))
    return [left, top, right, bottom]
def is_valid_bbox(bbox, min_size=2.0):
    """Return True when the normalized ``bbox`` is at least ``min_size`` wide and tall."""
    left, top, right, bottom = normalize_bbox(bbox)
    wide_enough = (right - left) >= min_size
    return wide_enough and (bottom - top) >= min_size
def parse_bbox_text(raw_text):
    """Parse ``x1,y1,x2,y2`` boxes separated by newlines or semicolons.

    Args:
        raw_text: text such as the canvas script writes (one box per line); None or blank
            text yields no boxes.

    Returns:
        A list of normalized ``[x1, y1, x2, y2]`` float boxes.

    Raises:
        ValueError: listing every entry that does not have four numeric parts, contains a
            non-finite value ("inf"/"nan"), or is smaller than 2 px in width or height.
    """
    if raw_text is None or not raw_text.strip():
        return []
    bbox_entries = []
    invalid_entries = []
    for chunk in re.split(r"[;\n]+", raw_text):
        entry = chunk.strip()
        if not entry:
            continue
        parts = [part.strip() for part in entry.split(",")]
        if len(parts) != 4:
            invalid_entries.append(entry)
            continue
        try:
            values = [float(part) for part in parts]
        except ValueError:
            invalid_entries.append(entry)
            continue
        # float() accepts "inf"/"nan". An infinite box would otherwise pass the size check
        # and crash later in expand_bbox (math.floor(inf) raises OverflowError).
        if not all(math.isfinite(value) for value in values):
            invalid_entries.append(entry)
            continue
        bbox = normalize_bbox(values)
        if is_valid_bbox(bbox):
            bbox_entries.append(bbox)
        else:
            invalid_entries.append(entry)
    if invalid_entries:
        raise ValueError(
            "Invalid bbox entries: "
            + "; ".join(invalid_entries)
            + ". Use x1,y1,x2,y2 with width/height >= 2."
        )
    return bbox_entries
def get_grid_size():
    """Return the grid size used to map coordinate tokens back to pixels.

    Read from ``HF_PROCESSOR.image_processor.resized_size``; 128 when the processor or
    attribute is missing.
    """
    image_processor = getattr(HF_PROCESSOR, "image_processor", None)
    return 128 if image_processor is None else int(getattr(image_processor, "resized_size", 128))
def get_pad_color():
    """Return an RGB tuple used to pad crop areas that fall outside the image.

    Derived from the processor's ``image_mean`` (first three channels). Values <= 1.0 are
    treated as 0-1 fractions and scaled by 255. Results are clamped to 0-255 and rounded.
    Falls back to DEFAULT_PAD_COLOR when the processor or a 3-channel mean is missing.
    """
    image_processor = getattr(HF_PROCESSOR, "image_processor", None)
    if image_processor is None:
        return DEFAULT_PAD_COLOR
    channel_means = getattr(image_processor, "image_mean", None)
    if channel_means is None or len(channel_means) < 3:
        return DEFAULT_PAD_COLOR

    def _to_byte(channel):
        level = float(channel)
        if level <= 1.0:
            level *= 255.0
        return int(round(min(max(level, 0.0), 255.0)))

    return tuple(_to_byte(channel) for channel in channel_means[:3])
def get_raw_prompt(subject):
    """Return OBJECT_RAW_PROMPT for ``"object"``, otherwise BUILDING_RAW_PROMPT."""
    return OBJECT_RAW_PROMPT if subject == "object" else BUILDING_RAW_PROMPT
def decode_generated_text(output, model_inputs, tokenizer):
    """Decode the first generated sequence into text with end markers removed.

    Two candidates are decoded: the whole sequence and the part after the prompt length.
    The prompt-stripped text is returned when it is non-empty and has at least as many
    ``<xN>``/``<yN>`` tokens as the full text; otherwise the full text is returned.

    Args:
        output: ``generate()`` result, either with a ``sequences`` attribute or a tensor.
        model_inputs: mapping whose ``"input_ids"`` gives the prompt length (``shape[1]``).
        tokenizer: tokenizer providing ``decode``.
    """
    prompt_length = model_inputs["input_ids"].shape[1]
    sequences = getattr(output, "sequences", output)
    end_markers = ("<|im_end|>", "<|endoftext|>", "</s>")

    def _strip_markers(decoded):
        for marker in end_markers:
            decoded = decoded.replace(marker, "")
        return decoded.strip()

    def _coord_count(candidate):
        return len(re.findall(r"<[xy]\d+>", candidate))

    first_sequence = sequences[0]
    full_text = _strip_markers(tokenizer.decode(first_sequence, skip_special_tokens=False))
    generated_text = ""
    if sequences.shape[1] > prompt_length:
        generated_text = _strip_markers(
            tokenizer.decode(first_sequence[prompt_length:], skip_special_tokens=False)
        )
    # Presumably guards against generate() returning only new tokens, where slicing off
    # "prompt_length" tokens would discard real output.
    if generated_text and _coord_count(generated_text) >= _coord_count(full_text):
        return generated_text
    return full_text
def parse_polygon(text):
    """Pair ``<xN>``/``<yM>`` tokens in ``text`` into integer ``(x, y)`` grid points.

    A ``<y>`` token closes the most recent unpaired ``<x>``. A repeated ``<x>`` replaces the
    pending one, and a ``<y>`` with no pending ``<x>`` is ignored.
    """
    polygon = []
    open_x = None
    for match in COORD_PATTERN.finditer(text):
        axis, digits = match.groups()
        coordinate = int(digits)
        if axis == "x":
            open_x = coordinate
            continue
        if open_x is None:
            continue
        polygon.append((open_x, coordinate))
        open_x = None
    return polygon
def expand_bbox(bbox, expand_ratio):
    """Scale ``bbox`` about its centre by ``expand_ratio`` and snap outward to integers.

    The min corner is floored and the max corner is ceiled. The result is at least 1 px
    wide and tall and may extend beyond the image.

    Returns:
        ``[x1, y1, x2, y2]`` as ints.
    """
    left, top, right, bottom = normalize_bbox(bbox)
    ratio = float(expand_ratio)
    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0
    half_w = (right - left) * ratio / 2.0
    half_h = (bottom - top) * ratio / 2.0
    new_left = int(math.floor(center_x - half_w))
    new_top = int(math.floor(center_y - half_h))
    new_right = max(int(math.ceil(center_x + half_w)), new_left + 1)
    new_bottom = max(int(math.ceil(center_y + half_h)), new_top + 1)
    return [new_left, new_top, new_right, new_bottom]
def crop_image_by_bbox(image, bbox, expand_ratio):
    """Crop the expanded ``bbox`` from ``image``, padding off-image areas.

    Parts of the expanded window outside ``image`` are filled with get_pad_color().

    Returns:
        ``(crop_image, expanded_bbox)``. ``expanded_bbox`` is the integer window from
        expand_bbox() in full-image coordinates.
    """
    expanded_bbox = expand_bbox(bbox, expand_ratio)
    win_left, win_top, win_right, win_bottom = expanded_bbox
    canvas = Image.new(
        "RGB",
        (max(1, win_right - win_left), max(1, win_bottom - win_top)),
        get_pad_color(),
    )
    image_w, image_h = image.size
    visible = (
        max(0, win_left),
        max(0, win_top),
        min(image_w, win_right),
        min(image_h, win_bottom),
    )
    if visible[2] > visible[0] and visible[3] > visible[1]:
        canvas.paste(image.crop(visible), (visible[0] - win_left, visible[1] - win_top))
    return canvas, expanded_bbox
def recover_polygon(points, image_size, grid_size, offset_x=0.0, offset_y=0.0):
    """Map grid-cell points to pixel coordinates at cell centres.

    Each point becomes ``((c + 0.5) / grid_size * size + offset)`` per axis, with ``size``
    taken from ``image_size = (width, height)``.

    Returns:
        A list of ``(x, y)`` float tuples.
    """
    width, height = image_size
    return [
        (
            (float(grid_x) + 0.5) / grid_size * width + offset_x,
            (float(grid_y) + 0.5) / grid_size * height + offset_y,
        )
        for grid_x, grid_y in points
    ]
def clamp_polygon(polygon, image_size):
    """Clamp each ``(x, y)`` into ``[0, width - 1] x [0, height - 1]`` as floats."""
    width, height = image_size

    def _clip(value, upper):
        return min(max(float(value), 0.0), upper)

    return [(_clip(px, width - 1.0), _clip(py, height - 1.0)) for px, py in polygon]
def draw_crop_polygon(image, polygon):
    """Return an RGB copy of ``image`` with ``polygon`` drawn on it.

    With three or more vertices, the polygon gets a magenta 2 px outline and a translucent
    cyan fill. Every vertex (rounded to ints) is marked with an orange dot.
    """
    base = image.convert("RGBA")
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    vertices = [(int(round(px)), int(round(py))) for px, py in polygon]
    if len(vertices) >= 3:
        pen.polygon(vertices, outline=(255, 0, 255, 255), fill=(0, 255, 255, 90), width=2)
    dot_fill = (255, 165, 0, 255)
    for vx, vy in vertices:
        pen.ellipse((vx - 2, vy - 2, vx + 2, vy + 2), fill=dot_fill)
    return Image.alpha_composite(base, layer).convert("RGB")
def draw_full_overlay(image, results):
    """Return an RGB copy of ``image`` with every result drawn on it.

    For each result the following are drawn:
      * the expanded crop window (amber outline);
      * the user bbox (green outline);
      * the predicted polygon (magenta outline, translucent cyan fill) when it has three or
        more vertices, with orange vertex dots;
      * the result index in white, near the first vertex (or the bbox's top-left corner
        when the polygon is empty).
    """
    base = image.convert("RGBA")
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    for entry in results:
        bbox = entry["bbox"]
        window = entry["expanded_bbox"]
        polygon = entry["polygon"]
        label = entry["index"]
        pen.rectangle([tuple(window[:2]), tuple(window[2:])], outline=(255, 191, 0, 255), width=2)
        pen.rectangle([tuple(bbox[:2]), tuple(bbox[2:])], outline=(0, 255, 127, 255), width=2)
        vertices = [(int(round(px)), int(round(py))) for px, py in polygon]
        if len(vertices) >= 3:
            pen.polygon(vertices, outline=(255, 0, 255, 255), fill=(0, 255, 255, 72), width=2)
        for vx, vy in vertices:
            pen.ellipse((vx - 2, vy - 2, vx + 2, vy + 2), fill=(255, 165, 0, 255))
        if vertices:
            anchor_x, anchor_y = vertices[0]
        else:
            anchor_x, anchor_y = int(round(bbox[0])), int(round(bbox[1]))
        pen.text((anchor_x + 4, anchor_y + 4), str(label), fill=(255, 255, 255, 255))
    return Image.alpha_composite(base, layer).convert("RGB")
def format_text_outputs(results):
    """Join each result's raw model text under a ``[BBox N]`` header, blank-line separated."""
    if not results:
        return "No model output."
    sections = (f"[BBox {item['index']}]\n{item['text']}" for item in results)
    return "\n\n".join(sections)
def build_report(image, results, subject, expand_ratio):
    """Build the dict shown in the "Structured Result" JSON panel.

    It contains the image size, prompt target, expand ratio and current grid size, plus
    one entry per result. Each entry holds the bboxes, crop size, raw text and the
    polygons in grid, crop and full-image coordinates. PIL crops are reduced to their size.
    """

    def _summarize(entry):
        return {
            "index": entry["index"],
            "bbox": entry["bbox"],
            "expanded_bbox": entry["expanded_bbox"],
            "crop_size": list(entry["crop_image"].size),
            "text": entry["text"],
            "grid_polygon": entry["grid_polygon"],
            "crop_polygon": entry["crop_polygon"],
            "polygon": entry["polygon"],
        }

    report = {}
    report["image_size"] = list(image.size)
    report["subject"] = subject
    report["expand_ratio"] = float(expand_ratio)
    report["grid_size"] = get_grid_size()
    report["results"] = [_summarize(entry) for entry in results]
    return report
def predict_single_bbox(image, bbox, expand_ratio, subject):
    """Run the model on one expanded bbox crop and map its polygon back to the full image.

    Returns:
        A dict with ``bbox`` (floats), ``expanded_bbox`` (ints), ``crop_image`` (PIL),
        ``text`` (decoded model output), and ``grid_polygon`` / ``crop_polygon`` /
        ``polygon``. The last three are the predicted points in grid, crop-pixel and
        clamped full-image pixel coordinates.
    """
    crop_image, expanded_bbox = crop_image_by_bbox(image, bbox, expand_ratio)
    encoded = HF_PROCESSOR(text=[get_raw_prompt(subject)], images=[crop_image], return_tensors="pt")
    model_inputs = {}
    for name, item in encoded.items():
        model_inputs[name] = item.to(HF_MODEL.device) if torch.is_tensor(item) else item
    stopper = get_stop_criteria(HF_TOKENIZER, ["<|im_end|>", "<|endoftext|>"])
    with torch.inference_mode():
        generated = HF_MODEL.generate(
            **model_inputs,
            generation_config=HF_GENERATION_CONFIG,
            bos_token_id=HF_TOKENIZER.bos_token_id,
            stopping_criteria=stopper,
            output_hidden_states=False,
            return_dict_in_generate=True,
            use_cache=True,
        )
    text = decode_generated_text(generated, model_inputs, HF_TOKENIZER)
    grid_polygon = parse_polygon(text)
    grid_size = get_grid_size()
    crop_polygon = recover_polygon(grid_polygon, crop_image.size, grid_size)
    # Same mapping, shifted by the crop window's origin into full-image coordinates.
    shifted = recover_polygon(
        grid_polygon,
        crop_image.size,
        grid_size,
        offset_x=float(expanded_bbox[0]),
        offset_y=float(expanded_bbox[1]),
    )
    full_polygon = clamp_polygon(shifted, image.size)
    return {
        "bbox": [float(v) for v in bbox],
        "expanded_bbox": [int(v) for v in expanded_bbox],
        "crop_image": crop_image,
        "text": text,
        "grid_polygon": [[int(px), int(py)] for px, py in grid_polygon],
        "crop_polygon": [[float(px), float(py)] for px, py in crop_polygon],
        "polygon": [[float(px), float(py)] for px, py in full_polygon],
    }
def run_inference(image, bboxes, expand_ratio, subject):
    """Predict every bbox and build the four UI outputs.

    Returns:
        ``(full_overlay, crop_gallery, text_output, report)``. ``crop_gallery`` is a list
        of ``(crop_with_polygon, caption)`` pairs; results are numbered from 1.
    """
    results, crop_gallery = [], []
    for position, bbox in enumerate(bboxes, 1):
        entry = predict_single_bbox(image, bbox, expand_ratio, subject)
        entry["index"] = position
        results.append(entry)
        preview = draw_crop_polygon(entry["crop_image"], entry["crop_polygon"])
        crop_gallery.append((preview, f"BBox {position} | expand={expand_ratio:.2f}"))
    full_overlay = draw_full_overlay(image, results)
    report = build_report(image, results, subject, expand_ratio)
    return full_overlay, crop_gallery, format_text_outputs(results), report
def inference_canvas(image_data, bbox_text, expand_ratio, subject):
    """Gradio handler for the Run button.

    Args:
        image_data: hidden-textbox value; the canvas script writes a ``data:image/...`` URL
            (or an empty string when no image is loaded).
        bbox_text: hidden-textbox value with one ``x1,y1,x2,y2`` box per line.
        expand_ratio: bbox expansion factor from the slider.
        subject: prompt target, ``"building"`` or ``"object"``.

    Returns:
        ``(overlay_image, crop_gallery, text, report)``. On invalid input the error message
        is returned in the text slot, with the parsed image (if any) as the overlay.
    """
    # The hidden textbox is client-controlled, and load_image_source() opens any
    # non-data-URL string as a server-side file path. Reject such strings so a request
    # cannot open (and get echoed back) arbitrary image files on the server.
    if isinstance(image_data, str) and image_data.strip() and not image_data.startswith("data:image"):
        return None, [], "Unsupported image data. Please upload an image with the annotator.", None
    try:
        image = load_image_source(image_data)
    except ValueError as exc:
        return None, [], str(exc), None
    try:
        bboxes = parse_bbox_text(bbox_text)
    except ValueError as exc:
        return image, [], str(exc), None
    if not bboxes:
        return image, [], "Please drag at least one valid bbox on the image.", None
    return run_inference(image, bboxes, expand_ratio, subject)
def clear_outputs():
    """Return empty values for (overlay image, crop gallery, text output, JSON report)."""
    empty_outputs = (None, [], "", None)
    return empty_outputs
def build_demo():
    """Assemble the Gradio Blocks UI and wire its Run / Clear events.

    The left column holds the canvas annotator (CANVAS_ANNOTATOR_HTML, scripted by
    CANVAS_ANNOTATOR_HEAD), two hidden textboxes the annotator writes into, the prompt
    target radio, the expand-ratio slider and the buttons. The right column shows the
    full-image overlay, crop gallery, raw model text and JSON report.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="VectorLLM HF Full-Image BBox Demo",
        head=CANVAS_ANNOTATOR_HEAD,
    ) as demo:
        gr.Markdown("# VectorLLM HF Full-Image BBox Demo")
        gr.Markdown(
            "Upload a full image, draw one or more bboxes, choose an expand ratio between 1.0 and 1.3, "
            "then run VectorLLM on the cropped regions and project the predicted polygon back to the full image."
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML(CANVAS_ANNOTATOR_HTML)
                # Filled by the annotator script's setHiddenValue() via these elem_ids:
                # the image as a data URL and the boxes as "x1,y1,x2,y2" lines.
                hidden_image_data = gr.Textbox(visible=False, elem_id="vectorllm-hidden-image-data")
                hidden_bbox_text = gr.Textbox(visible=False, elem_id="vectorllm-hidden-bboxes")
                subject = gr.Radio(
                    choices=[("Building", "building"), ("Object", "object")],
                    value="building",
                    label="Prompt Target",
                )
                expand_ratio = gr.Slider(
                    minimum=1.0,
                    maximum=1.3,
                    value=1.15,
                    step=0.01,
                    label="BBox Expand Ratio",
                )
                with gr.Row():
                    run_button = gr.Button("Run", variant="primary")
                    clear_button = gr.Button("Clear")
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Full-Image Overlay", height=520)
                crop_gallery = gr.Gallery(
                    label="Expanded Crop Preview",
                    columns=2,
                    height=240,
                    object_fit="contain",
                )
                output_text = gr.Textbox(label="Model Text Output", lines=12)
                output_json = gr.JSON(label="Structured Result")
        # inference_canvas returns (overlay, gallery, text, report) in this output order.
        run_button.click(
            inference_canvas,
            inputs=[hidden_image_data, hidden_bbox_text, expand_ratio, subject],
            outputs=[output_image, crop_gallery, output_text, output_json],
            show_api=False,
        )
        # The js callback resets the client-side annotator (image and boxes);
        # clear_outputs empties the four result components.
        clear_button.click(
            clear_outputs,
            outputs=[output_image, crop_gallery, output_text, output_json],
            show_api=False,
            js="""
() => {
if (window.vectorllmCanvasAnnotator) {
window.vectorllmCanvasAnnotator.reset();
}
return [];
}
""",
        )
    return demo
def main():
    """Parse CLI options, load the model into the module globals, and serve the demo.

    Raises:
        FileNotFoundError: if ``--model-path`` does not exist.
    """
    global HF_MODEL, HF_PROCESSOR, HF_TOKENIZER, HF_GENERATION_CONFIG
    cli = parse_args()
    resolved_path = Path(cli.model_path).expanduser().resolve()
    if not resolved_path.exists():
        raise FileNotFoundError(f"Model path does not exist: {resolved_path}")
    loaded = init_model(str(resolved_path), cli.dtype, cli.max_new_tokens)
    HF_MODEL, HF_PROCESSOR, HF_TOKENIZER, HF_GENERATION_CONFIG = loaded
    app = build_demo()
    app.queue()
    app.launch(
        share=cli.share,
        server_name=cli.server_name,
        server_port=cli.server_port,
    )


if __name__ == "__main__":
    main()