File size: 6,170 Bytes
c6535db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | import json
from typing import Any, List, Tuple
import torch
class AILab_Florence2ToCoordinates:
    """Turn Florence2 bounding boxes into center points, box lists and a mask.

    ``convert`` returns a 3-tuple matching ``RETURN_NAMES``:

    * ``CENTER_COORDINATES`` - JSON string ``[{"x": int, "y": int}, ...]`` with
      the truncated-integer center of every selected box, or the placeholder
      ``[{"x": 0, "y": 0}]`` when no box is selected.
    * ``BBOXES`` - the selected boxes exactly as found in the input, unpacked
      as ``min_x, min_y, max_x, max_y``.
    * ``MASK`` - float32 tensor ``(frames, height, width)`` that is 1.0 inside
      every selected box.
    """

    CATEGORY = "🧪AILab/🧽RMBG"
    RETURN_TYPES = ("STRING", "BBOX", "MASK")
    RETURN_NAMES = ("CENTER_COORDINATES", "BBOXES", "MASK")
    FUNCTION = "convert"

    # CENTER_COORDINATES value returned whenever no box ends up selected.
    _EMPTY_CENTERS_JSON = json.dumps([{"x": 0, "y": 0}])

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "data": ("JSON", {"tooltip": "Florence2 JSON output (list per image)."}),
                "index": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": (
                            "Comma-separated indexes; blank = use all boxes "
                            "(from the first item, or from every item when batch is on)."
                        ),
                    },
                ),
                "batch": ("BOOLEAN", {"default": False, "tooltip": "If true, gather boxes across the batch."}),
            },
            "optional": {
                "image": ("IMAGE",),
            },
        }

    @staticmethod
    def _decode_text(text: str) -> Any:
        """Decode a JSON or Python-literal string; blank text decodes to None.

        Raises:
            ValueError: if the text is neither JSON, a Python literal, nor
                single-quoted JSON.
        """
        import ast  # local import: only needed for the Python-repr fallback

        cleaned = text.strip()
        if not cleaned:
            return None
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            pass
        # Some upstream nodes stringify Python objects (single quotes, None,
        # apostrophes inside labels). literal_eval parses such reprs safely
        # without executing code, unlike a blind quote substitution.
        try:
            return ast.literal_eval(cleaned)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            pass
        # Last resort kept for backward compatibility: single-quoted JSON such
        # as "{'ok': true}" that is neither valid JSON nor a Python literal.
        try:
            return json.loads(cleaned.replace("'", '"'))
        except json.JSONDecodeError as exc:
            raise ValueError("Invalid JSON payload for Florence2 data") from exc

    @staticmethod
    def _parse_payload(payload: Any) -> List[Any]:
        """Normalize the node's ``data`` input into a list of per-image records.

        Strings are decoded first. A list is returned as-is, a tuple as a list,
        and a single dict (one un-wrapped Florence2 result) becomes a
        one-element list. Anything else, including None and blank strings,
        yields ``[]``.

        Raises:
            ValueError: if a string payload cannot be decoded.
        """
        if isinstance(payload, str):
            payload = AILab_Florence2ToCoordinates._decode_text(payload)
        if isinstance(payload, list):
            return payload
        if isinstance(payload, tuple):
            return list(payload)
        if isinstance(payload, dict):
            return [payload]
        return []

    @staticmethod
    def _get_bboxes(entry: Any) -> List[List[float]]:
        """Return the box list of one record.

        A record is either a dict with a ``"bboxes"`` key (a null value means
        "no boxes") or itself a list/tuple of boxes.

        Raises:
            ValueError: if a dict has no ``"bboxes"`` key, or the record has an
                unsupported type.
        """
        if isinstance(entry, dict):
            if "bboxes" not in entry:
                raise ValueError("Entry does not contain 'bboxes'.")
            bboxes = entry["bboxes"]
            return [] if bboxes is None else bboxes
        if isinstance(entry, list):
            return entry
        if isinstance(entry, tuple):
            return list(entry)
        raise ValueError("Unsupported entry type; expected dict with 'bboxes' or list of boxes.")

    @staticmethod
    def _parse_indexes(index_str: str, default_count: int) -> List[int]:
        """Parse ``"0, 2"`` into ``[0, 2]``; blank (or None) means ``range(default_count)``.

        Raises:
            ValueError: if any non-empty comma-separated part is not an integer.
        """
        text = (index_str or "").strip()
        if not text:
            return list(range(default_count))
        try:
            # int() tolerates surrounding whitespace; empty parts are skipped.
            return [int(part) for part in text.split(",") if part.strip()]
        except ValueError as exc:
            raise ValueError("Index must be comma-separated integers.") from exc

    def convert(self, data, index: str, batch: bool = False, image=None):
        """Select boxes from Florence2 output and derive centers and a mask.

        Args:
            data: Florence2 output - a list with one record per image (a dict
                with a ``"bboxes"`` key, or a bare list of boxes), a single
                such dict, or a JSON / Python-literal string of either.
            index: Comma-separated box indexes; blank selects every box.
            batch: If False, only the first record is used and an
                out-of-range index is an error. If True, every record is used
                (record ``i`` maps to mask frame ``i``), and indexes a record
                does not have are skipped.
            image: Optional image whose first three dimensions size the mask.
                Without it, the mask is just large enough to hold the boxes.

        Returns:
            ``(centers_json, selected_boxes, mask)`` as described on the class.

        Raises:
            ValueError: on an undecodable payload, a malformed record or index
                string, or (non-batch mode) an out-of-range index.
        """
        records = self._parse_payload(data)

        # (mask frame index, box) for every selected box, in output order.
        selections: List[Tuple[int, List[float]]] = []
        if batch:
            for batch_idx, record in enumerate(records):
                boxes = self._get_bboxes(record)
                # Parse per record so a blank index means "all boxes of this
                # record" rather than a box count borrowed from record 0.
                for idx in self._parse_indexes(index, len(boxes)):
                    if 0 <= idx < len(boxes):
                        selections.append((batch_idx, boxes[idx]))
        elif records:
            boxes = self._get_bboxes(records[0])
            if boxes:
                for idx in self._parse_indexes(index, len(boxes)):
                    if not 0 <= idx < len(boxes):
                        raise ValueError(f"Index {idx} is out of range for available boxes")
                    selections.append((0, boxes[idx]))

        if not selections:
            return (self._EMPTY_CENTERS_JSON, [], self._build_empty_mask(image))

        centers = []
        selected_boxes = []
        # Per-frame [width, height] needed to contain the boxes; only used to
        # size the mask when no image is supplied.
        max_dims = {}
        for batch_idx, box in selections:
            min_x, min_y, max_x, max_y = box
            centers.append({"x": int((min_x + max_x) / 2), "y": int((min_y + max_y) / 2)})
            selected_boxes.append(box)
            dims = max_dims.setdefault(batch_idx, [1, 1])
            dims[0] = max(dims[0], int(max_x) + 1)
            dims[1] = max(dims[1], int(max_y) + 1)

        mask_tensor = self._build_mask_tensor(image, selections, max_dims, batch, len(records))
        return (json.dumps(centers), selected_boxes, mask_tensor)

    @staticmethod
    def _image_canvas(image):
        """Return a zero float32 mask shaped like ``image``'s first three dims, or None.

        A 3-D image is treated as a single frame. Dims 1 and 2 are used as
        height and width (the trailing dim is presumably channels, per the
        IMAGE input - TODO confirm). The mask is created on the image's device.
        """
        if image is None:
            return None
        tensor = image.unsqueeze(0) if image.dim() == 3 else image
        return torch.zeros(
            (tensor.shape[0], tensor.shape[1], tensor.shape[2]),
            dtype=torch.float32,
            device=tensor.device,
        )

    @staticmethod
    def _build_empty_mask(image):
        """All-zero mask sized like ``image``, or a ``(1, 1, 1)`` mask without one."""
        canvas = AILab_Florence2ToCoordinates._image_canvas(image)
        if canvas is None:
            return torch.zeros((1, 1, 1), dtype=torch.float32)
        return canvas

    def _build_mask_tensor(self, image, selections, max_dims, batch_mode, record_count):
        """Rasterize the selected boxes into a float32 mask.

        With ``image``, the mask matches its frames, height and width. Without
        it, the mask has ``record_count`` frames in batch mode (otherwise one)
        and the largest width/height recorded in ``max_dims``. Each box fills
        ``[int(min), int(max))`` on both axes, clamped to the canvas and widened
        to at least one pixel. Boxes targeting a frame the canvas lacks, or
        lying wholly outside it, are skipped.
        """
        if not selections:
            return self._build_empty_mask(image)
        base = self._image_canvas(image)
        if base is None:
            batch_size = record_count if batch_mode else 1
            max_width = max((dims[0] for dims in max_dims.values()), default=1)
            max_height = max((dims[1] for dims in max_dims.values()), default=1)
            base = torch.zeros((batch_size, max_height, max_width), dtype=torch.float32)
        frames, height, width = base.shape[0], base.shape[1], base.shape[2]
        for batch_idx, box in selections:
            if batch_idx >= frames:
                # More Florence2 records than image frames: nothing to draw on.
                continue
            min_x, min_y, max_x, max_y = box
            if max_x < 0 or max_y < 0 or min_x >= width or min_y >= height:
                # Entirely off-canvas; clamping would otherwise paint a
                # spurious 1-pixel sliver along the border.
                continue
            x0 = max(0, min(int(min_x), width - 1))
            y0 = max(0, min(int(min_y), height - 1))
            x1 = max(x0 + 1, min(int(max_x), width))
            y1 = max(y0 + 1, min(int(max_y), height))
            base[batch_idx, y0:y1, x0:x1] = 1.0
        return base
# Registration tables keyed by node id: id -> implementing class, and
# id -> human-readable display name.
NODE_CLASS_MAPPINGS = dict(
    AILab_Florence2ToCoordinates=AILab_Florence2ToCoordinates,
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    AILab_Florence2ToCoordinates="Florence2 Box Coordinates (RMBG)",
)
|