File size: 6,170 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import ast
import json
from typing import Any, List, Tuple

import torch


class AILab_Florence2ToCoordinates:
    """Node that converts Florence2 bounding-box output into centers, boxes and a mask.

    ``convert`` returns, in order:

    * ``CENTER_COORDINATES`` - JSON string of ``{"x": int, "y": int}`` box centers,
    * ``BBOXES`` - the selected boxes exactly as they appeared in the input,
    * ``MASK`` - float32 tensor indexed ``[batch, y, x]`` holding 1.0 inside every
      selected box and 0.0 elsewhere.
    """

    CATEGORY = "🧪AILab/🧽RMBG"
    RETURN_TYPES = ("STRING", "BBOX", "MASK")
    RETURN_NAMES = ("CENTER_COORDINATES", "BBOXES", "MASK")
    FUNCTION = "convert"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: Florence2 ``data``, ``index`` selection, ``batch`` flag, optional ``image``."""
        return {
            "required": {
                "data": ("JSON", {"tooltip": "Florence2 JSON output (list per image)."}),
                "index": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Comma-separated indexes; blank = use all boxes from first item "
                        "(every item when batch is on).",
                    },
                ),
                "batch": ("BOOLEAN", {"default": False, "tooltip": "If true, gather boxes across the batch."}),
            },
            "optional": {
                "image": ("IMAGE",)
            },
        }

    @staticmethod
    def _decode_text(text: str) -> Any:
        """Decode a non-blank serialized payload.

        Tries strict JSON first, then JSON after swapping single quotes for double
        quotes, and finally a Python literal.

        Raises:
            ValueError: if none of the decoders accept ``text``.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass
        # Fallback: some upstream nodes stringify Python lists with single quotes.
        try:
            return json.loads(text.replace("'", '"'))
        except json.JSONDecodeError as exc:
            json_error = exc
        # Last resort: a real Python literal. This covers apostrophes inside quoted
        # labels and True/False/None, which the quote swap above cannot handle.
        # literal_eval only evaluates literals, so it is safe on untrusted text.
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            raise ValueError("Invalid JSON payload for Florence2 data") from json_error

    @staticmethod
    def _parse_payload(payload: Any) -> List[Any]:
        """Normalize the ``data`` input into a list of per-image records.

        Accepts a serialized string, a list/tuple of records, or one record dict,
        which is wrapped in a list. ``None``, blank strings and any other type
        yield ``[]``.

        Raises:
            ValueError: if a non-blank string cannot be decoded.
        """
        if payload is None:
            return []
        if isinstance(payload, str):
            cleaned = payload.strip()
            if not cleaned:
                return []
            payload = AILab_Florence2ToCoordinates._decode_text(cleaned)
        # A single record (dict) previously slipped through and crashed on
        # ``records[0]``; treat it as a one-image batch instead.
        if isinstance(payload, dict):
            return [payload]
        if isinstance(payload, (list, tuple)):
            return list(payload)
        return []

    @staticmethod
    def _get_bboxes(entry: Any) -> List[List[float]]:
        """Return the box list of one record.

        A record is a dict with a ``"bboxes"`` key, where ``None`` counts as no
        boxes, or a bare list/tuple of boxes.

        Raises:
            ValueError: for a dict without ``"bboxes"`` or an unsupported type.
        """
        if isinstance(entry, dict):
            if "bboxes" not in entry:
                raise ValueError("Entry does not contain 'bboxes'.")
            boxes = entry["bboxes"]
            return [] if boxes is None else boxes
        if isinstance(entry, (list, tuple)):
            return list(entry)
        raise ValueError("Unsupported entry type; expected dict with 'bboxes' or list of boxes.")

    @staticmethod
    def _parse_indexes(index_str: str, default_count: int) -> List[int]:
        """Parse comma-separated integer indexes; blank text yields ``range(default_count)``.

        Raises:
            ValueError: if a non-empty part is not an integer.
        """
        text = index_str.strip()
        if not text:
            return list(range(default_count))
        try:
            return [int(part.strip()) for part in text.split(",") if part.strip()]
        except ValueError as exc:
            raise ValueError("Index must be comma-separated integers.") from exc

    def _empty_result(self, image):
        """Return the placeholder output used when the payload holds no boxes."""
        return (json.dumps([{"x": 0, "y": 0}]), [], self._build_empty_mask(image))

    def convert(self, data, index: str, batch: bool = False, image=None):
        """Select boxes, then compute their integer centers and a rectangular mask.

        Args:
            data: Florence2 output. Either a serialized string, a list of
                per-image records, or a single record dict. Each box is unpacked
                as ``min_x, min_y, max_x, max_y``.
            index: comma-separated box indexes. Blank selects every box of the
                first record, or of every record in batch mode.
            batch: if true, apply the selection to every record and tag each box
                with its record position; out-of-range indexes are skipped. If
                false, only the first record is used and an out-of-range index
                raises.
            image: optional image tensor. When given, the mask takes its batch
                and spatial size instead of being sized to fit the boxes.

        Returns:
            ``(centers_json, selected_boxes, mask)``. When the payload has no
            boxes at all, the centers are a single ``{"x": 0, "y": 0}``
            placeholder.

        Raises:
            ValueError: on undecodable ``data`` or ``index``, malformed records,
                or an out-of-range index in non-batch mode.
        """
        records = self._parse_payload(data)
        if not records:
            return self._empty_result(image)

        # Non-batch mode only ever reads (and validates) the first record.
        sources = records if batch else records[:1]
        boxes_per_record = [self._get_bboxes(record) for record in sources]
        # Only bail out when *no* record has boxes: an empty first image must not
        # hide boxes found in later images of the batch.
        if not any(boxes_per_record):
            return self._empty_result(image)

        use_all = not index.strip()
        requested = [] if use_all else self._parse_indexes(index, 0)

        centers = []
        selected_boxes = []
        selections: List[Tuple[int, List[float]]] = []
        # Per record: [width, height] just large enough to contain its boxes;
        # used only when no image is supplied.
        max_dims = {}

        def append_box(batch_idx: int, box: List[float]) -> None:
            min_x, min_y, max_x, max_y = box
            centers.append({"x": int((min_x + max_x) / 2), "y": int((min_y + max_y) / 2)})
            selected_boxes.append(box)
            selections.append((batch_idx, box))
            dims = max_dims.setdefault(batch_idx, [1, 1])
            dims[0] = max(dims[0], int(max_x) + 1)
            dims[1] = max(dims[1], int(max_y) + 1)

        for batch_idx, boxes in enumerate(boxes_per_record):
            # A blank index means "all boxes of *this* record", not the first
            # record's box count applied to every record.
            picks = range(len(boxes)) if use_all else requested
            for idx in picks:
                if 0 <= idx < len(boxes):
                    append_box(batch_idx, boxes[idx])
                elif not batch:
                    raise ValueError(f"Index {idx} is out of range for available boxes")

        mask_tensor = self._build_mask_tensor(image, selections, max_dims, batch, len(records))

        return (json.dumps(centers), selected_boxes, mask_tensor)

    @staticmethod
    def _build_empty_mask(image):
        """Return an all-zero float32 mask.

        With ``image``, the mask has shape ``image.shape[:3]`` and lives on the
        image's device; a 3-D image is first given a leading batch dimension. This
        presumably assumes (batch, height, width, channels) or
        (height, width, channels) layout; confirm against callers. Without an
        image, the mask has shape (1, 1, 1).
        """
        if image is not None:
            tensor = image
            if tensor.dim() == 3:
                tensor = tensor.unsqueeze(0)
            return torch.zeros((tensor.shape[0], tensor.shape[1], tensor.shape[2]), dtype=torch.float32, device=tensor.device)
        return torch.zeros((1, 1, 1), dtype=torch.float32)

    def _build_mask_tensor(self, image, selections, max_dims, batch_mode, record_count):
        """Rasterize the selected boxes into a float32 mask indexed ``[batch, y, x]``.

        With ``image``, the mask matches the image's batch and spatial size.
        Otherwise it is sized to fit every box, with ``record_count`` slices in
        batch mode and one slice otherwise. Boxes are clipped to the mask, and a
        box that overlaps it always covers at least one pixel. Two kinds of box
        are skipped: boxes that lie entirely outside the mask, and boxes tagged
        with a batch position the mask does not have.
        """
        if not selections:
            return self._build_empty_mask(image)

        if image is not None:
            base = self._build_empty_mask(image)
        else:
            batch_size = record_count if batch_mode else 1
            max_width = max((dims[0] for dims in max_dims.values()), default=1)
            max_height = max((dims[1] for dims in max_dims.values()), default=1)
            base = torch.zeros((batch_size, max_height, max_width), dtype=torch.float32)

        mask_batch, height, width = base.shape
        for batch_idx, box in selections:
            if batch_idx >= mask_batch:
                # e.g. more Florence2 records than images in the supplied batch.
                continue
            min_x, min_y, max_x, max_y = box
            if min_x >= width or min_y >= height or max_x < 0 or max_y < 0:
                # No overlap with the mask; clamping would paint a border pixel.
                continue
            x0 = max(0, min(int(min_x), width - 1))
            y0 = max(0, min(int(min_y), height - 1))
            x1 = max(x0 + 1, min(int(max_x), width))
            y1 = max(y0 + 1, min(int(max_y), height))
            base[batch_idx, y0:y1, x0:x1] = 1.0

        return base


# Module-level registries, presumably read by the ComfyUI custom-node loader:
# node type id -> implementing class, and node type id -> human-readable title.
NODE_CLASS_MAPPINGS = dict(
    AILab_Florence2ToCoordinates=AILab_Florence2ToCoordinates,
)

NODE_DISPLAY_NAME_MAPPINGS = dict(
    AILab_Florence2ToCoordinates="Florence2 Box Coordinates (RMBG)",
)