Update detection-mode code
- README.md +53 -34
- modeling_falcon_perception.py +76 -43
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
pipeline_tag:
|
| 3 |
library_name: transformers
|
| 4 |
tags:
|
| 5 |
- falcon
|
|
@@ -11,33 +11,35 @@ license: apache-2.0
|
|
| 11 |
|
| 12 |
<img src="main_fig.jpg" width="480" alt="Falcon Perception"/>
|
| 13 |
|
| 14 |
-
>
|
|
|
|
| 15 |
|
| 16 |
-
## Falcon Perception
|
| 17 |
|
| 18 |
-
Falcon Perception is a 0.3B parameter early-fusion vision-language model for open-vocabulary grounding detection. Given an image and a natural language query, it returns zero, one, or many matching instances with accurate bounding boxes.
|
| 19 |
|
| 20 |
-
The model is built around a simple interface. Image patches and text tokens are processed together in a single Transformer using a hybrid attention mask: image tokens build bidirectional visual context, while text and task tokens decode causally conditioned on the image. For each instance, the model generates a short structured sequence of task tokens
|
| 21 |
|
| 22 |
|
| 23 |
### Links
|
| 24 |
|
| 25 |
-
-
|
|
|
|
| 26 |
- Tech report: arXiv link coming soon
|
| 27 |
- PBench dataset: `tiiuae/PBench`
|
| 28 |
-
- OCR model: `tiiuae/Falcon-OCR`
|
| 29 |
|
| 30 |
## Quickstart
|
| 31 |
|
| 32 |
### Installation
|
| 33 |
|
| 34 |
```bash
|
| 35 |
-
pip install "torch>=2.5" transformers pillow einops
|
| 36 |
```
|
| 37 |
|
| 38 |
This model requires PyTorch 2.5 or newer for FlexAttention. The first call can be slower because `torch.compile` may build optimized kernels.
|
| 39 |
|
| 40 |
-
### Run open-vocabulary
|
| 41 |
|
| 42 |
```python
|
| 43 |
import torch
|
|
@@ -45,7 +47,7 @@ from PIL import Image
|
|
| 45 |
from transformers import AutoModelForCausalLM
|
| 46 |
|
| 47 |
model = AutoModelForCausalLM.from_pretrained(
|
| 48 |
-
"tiiuae/
|
| 49 |
trust_remote_code=True,
|
| 50 |
device_map={"": "cuda:0"},
|
| 51 |
)
|
|
@@ -57,18 +59,31 @@ for p in preds:
|
|
| 57 |
print(p["xy"], p["hw"])
|
| 58 |
```
|
| 59 |
|
| 60 |
-
|
| 61 |
|
| 62 |
```python
|
| 63 |
-
|
| 64 |
-
|
| 65 |
|
| 66 |
for p in preds:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
```
|
| 73 |
|
| 74 |
## API
|
|
@@ -79,6 +94,7 @@ for p in preds:
|
|
| 79 |
|---|---|---|---|
|
| 80 |
| `images` | `PIL.Image` or `list` | required | Single image or list of images |
|
| 81 |
| `queries` | `str` or `list[str]` | required | Query string(s), one per image |
|
|
|
|
| 82 |
| `max_new_tokens` | `int` | `2048` | Maximum decoding steps |
|
| 83 |
| `min_dimension` | `int` | `256` | Minimum image side after resize |
|
| 84 |
| `max_dimension` | `int` | `1024` | Maximum image side after resize |
|
|
@@ -86,23 +102,26 @@ for p in preds:
|
|
| 86 |
|
| 87 |
**Returns:** `list[list[dict]]`, one list per image.
|
| 88 |
|
| 89 |
-
Each
|
| 90 |
|
| 91 |
```python
|
| 92 |
{
|
| 93 |
-
"xy": {"x": float, "y": float},
|
| 94 |
-
"hw": {"h": float, "w": float},
|
| 95 |
-
"mask_rle": {"counts": str, "size": [H, W]}, # COCO RLE at original resolution
|
| 96 |
}
|
| 97 |
```
|
| 98 |
|
| 99 |
## What the model is for
|
| 100 |
|
| 101 |
-
Falcon Perception is designed for
|
| 102 |
|
| 103 |
- Natural language driven object selection in images
|
| 104 |
-
-
|
| 105 |
- Crowded scenes where the number of instances is large and variable
|
|
|
|
| 106 |
|
| 107 |
It is not intended as a general-purpose vision-language assistant for open-ended reasoning, long-form generation, or multi-step VQA.
|
| 108 |
|
|
@@ -112,24 +131,24 @@ The architecture follows a single-stack early-fusion recipe:
|
|
| 112 |
|
| 113 |
- One dense Transformer backbone processes image patches and text tokens in a shared space from the first layer
|
| 114 |
- Hybrid attention masking: bidirectional among image tokens, causal for text and task tokens conditioned on the image
|
| 115 |
-
- Chain-of-Perception decoding: `<|coord|>` then `<|size|>`
|
| 116 |
- Specialized heads for coordinates and size, with geometry conditioning via Fourier features
|
| 117 |
-
- Parallel mask decoding: each `<|seg|>` token becomes a mask query and produces a full-resolution mask via dot product with upsampled image features
|
| 118 |
-
|
| 119 |
-
## Evaluation summary
|
| 120 |
-
|
| 121 |
-
From the technical report:
|
| 122 |
|
| 123 |
-
|
| 124 |
-
- PBench: a diagnostic benchmark that breaks down performance by capability (attributes, OCR-guided disambiguation, spatial constraints, relations) and includes a dense long-context crowded split
|
| 125 |
|
| 126 |
-
|
| 127 |
|
| 128 |
## Limitations
|
| 129 |
|
| 130 |
-
- Presence calibration remains a key limitation for autoregressive dense interfaces. False positives are more likely on hard negatives than in DETR
|
| 131 |
- OCR-driven prompts depend on text size and image resolution. Small text and degraded scans are challenging.
|
| 132 |
- Dense scenes benefit strongly from high resolution inputs. Low resolution can be sufficient to recognize that a concept is present, but insufficient to localize each instance precisely.
|
|
|
|
| 133 |
|
| 134 |
## Citation
|
| 135 |
|
|
|
|
| 1 |
---
|
| 2 |
+
pipeline_tag: object-detection
|
| 3 |
library_name: transformers
|
| 4 |
tags:
|
| 5 |
- falcon
|
|
|
|
| 11 |
|
| 12 |
<img src="main_fig.jpg" width="480" alt="Falcon Perception"/>
|
| 13 |
|
| 14 |
+
> [!NOTE]
|
| 15 |
+
> This is the **300M parameter** variant of Falcon Perception. It supports **detection only** (bounding boxes). For the full model with segmentation masks, see [`tiiuae/Falcon-Perception`](https://huggingface.co/tiiuae/Falcon-Perception).
|
| 16 |
|
| 17 |
+
## Falcon Perception 300M
|
| 18 |
|
| 19 |
+
Falcon Perception 300M is a 0.3B parameter early-fusion vision-language model for open-vocabulary grounding detection. Given an image and a natural language query, it returns zero, one, or many matching instances with accurate bounding boxes.
|
| 20 |
|
| 21 |
+
The model is built around a simple interface. Image patches and text tokens are processed together in a single Transformer using a hybrid attention mask: image tokens build bidirectional visual context, while text and task tokens decode causally conditioned on the image. For each detected instance, the model generates a short structured sequence of task tokens: `<|coord|>` then `<|size|>`, producing a center point and bounding box size in normalized coordinates.
|
| 22 |
|
| 23 |
|
| 24 |
### Links
|
| 25 |
|
| 26 |
+
- Full model (with segmentation): [`tiiuae/Falcon-Perception`](https://huggingface.co/tiiuae/Falcon-Perception)
|
| 27 |
+
- Code and inference engine: [`github.com/tiiuae/Falcon-Perception`](https://github.com/tiiuae/Falcon-Perception)
|
| 28 |
- Tech report: arXiv link coming soon
|
| 29 |
- PBench dataset: `tiiuae/PBench`
|
| 30 |
+
- OCR model: [`tiiuae/Falcon-OCR`](https://huggingface.co/tiiuae/Falcon-OCR)
|
| 31 |
|
| 32 |
## Quickstart
|
| 33 |
|
| 34 |
### Installation
|
| 35 |
|
| 36 |
```bash
|
| 37 |
+
pip install "torch>=2.5" transformers pillow einops
|
| 38 |
```
|
| 39 |
|
| 40 |
This model requires PyTorch 2.5 or newer for FlexAttention. The first call can be slower because `torch.compile` may build optimized kernels.
|
| 41 |
|
| 42 |
+
### Run open-vocabulary detection
|
| 43 |
|
| 44 |
```python
|
| 45 |
import torch
|
|
|
|
| 47 |
from transformers import AutoModelForCausalLM
|
| 48 |
|
| 49 |
model = AutoModelForCausalLM.from_pretrained(
|
| 50 |
+
"tiiuae/Falcon-Perception-300M",
|
| 51 |
trust_remote_code=True,
|
| 52 |
device_map={"": "cuda:0"},
|
| 53 |
)
|
|
|
|
| 59 |
print(p["xy"], p["hw"])
|
| 60 |
```
|
| 61 |
|
+Each prediction is a dict with normalized bounding box coordinates:
 
 ```python
+{
+    "xy": {"x": float, "y": float},  # center in normalized coordinates (0 to 1)
+    "hw": {"h": float, "w": float},  # size in normalized coordinates (0 to 1)
+}
+```
+
+### Visualize detections
+
+```python
+from PIL import ImageDraw
+
+draw = ImageDraw.Draw(image)
+W, H = image.size
 
 for p in preds:
+    cx, cy = p["xy"]["x"] * W, p["xy"]["y"] * H
+    bw, bh = p["hw"]["w"] * W, p["hw"]["h"] * H
+    x0, y0 = cx - bw / 2, cy - bh / 2
+    x1, y1 = cx + bw / 2, cy + bh / 2
+    draw.rectangle([x0, y0, x1, y1], outline="lime", width=2)
+
+image.save("output.jpg")
 ```
|
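The README's visualization snippet turns a normalized center and size into pixel corners inline. The sketch below wraps the same arithmetic in a reusable helper and clamps the box to the image bounds. The function name `to_pixel_box` and the clamping step are illustrative assumptions and are not part of the model's API.

```python
def to_pixel_box(pred, image_width, image_height):
    """Convert one prediction dict to a clamped (x0, y0, x1, y1) pixel box.

    `pred` is assumed to follow the README format:
    {"xy": {"x", "y"}, "hw": {"h", "w"}} with values normalized to 0..1.
    """
    # Scale the normalized center and size to pixels.
    cx = pred["xy"]["x"] * image_width
    cy = pred["xy"]["y"] * image_height
    bw = pred["hw"]["w"] * image_width
    bh = pred["hw"]["h"] * image_height

    # Center/size to corners, clamped so the box stays inside the image.
    x0 = max(0.0, cx - bw / 2)
    y0 = max(0.0, cy - bh / 2)
    x1 = min(float(image_width), cx + bw / 2)
    y1 = min(float(image_height), cy + bh / 2)
    return x0, y0, x1, y1


# Made-up prediction for illustration, not real model output.
example = {"xy": {"x": 0.5, "y": 0.5}, "hw": {"h": 0.5, "w": 0.25}}
box = to_pixel_box(example, image_width=640, image_height=480)
print(box)
```

Clamping is a design choice for drawing and cropping; leave it out if downstream code needs the raw predicted extent.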
| 88 |
|
| 89 |
## API
|
|
|
|
| 94 |
|---|---|---|---|
|
| 95 |
| `images` | `PIL.Image` or `list` | required | Single image or list of images |
|
| 96 |
| `queries` | `str` or `list[str]` | required | Query string(s), one per image |
|
| 97 |
+
| `task` | `str` | `"detection"` | Task type. Only `"detection"` is supported by this model. |
|
| 98 |
| `max_new_tokens` | `int` | `2048` | Maximum decoding steps |
|
| 99 |
| `min_dimension` | `int` | `256` | Minimum image side after resize |
|
| 100 |
| `max_dimension` | `int` | `1024` | Maximum image side after resize |
|
|
|
|
| 102 |
|
| 103 |
**Returns:** `list[list[dict]]`, one list per image.
|
| 104 |
|
| 105 |
+
Each detection dict contains:
|
| 106 |
|
| 107 |
```python
|
| 108 |
{
|
| 109 |
+
"xy": {"x": float, "y": float}, # center in normalized coordinates (0 to 1)
|
| 110 |
+
"hw": {"h": float, "w": float}, # size in normalized coordinates (0 to 1)
|
|
|
|
| 111 |
}
|
| 112 |
```
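Because the return value is nested (one list per image, one dict per detection), downstream code often wants a flat table. The sketch below shows one possible way to flatten it while keeping track of which image and query each box came from. The helper `flatten_results` and the sample `results` are hypothetical, and the sample values are hand-written rather than real model output.

```python
def flatten_results(results, queries):
    """Flatten list[list[dict]] into one row per detection.

    `results` is assumed to have the documented shape: one list per image,
    each holding dicts with "xy" and "hw" in normalized coordinates.
    """
    rows = []
    for image_index, (dets, query) in enumerate(zip(results, queries)):
        for det in dets:
            rows.append({
                "image_index": image_index,
                "query": query,
                "cx": det["xy"]["x"],
                "cy": det["xy"]["y"],
                "w": det["hw"]["w"],
                "h": det["hw"]["h"],
            })
    return rows


# Hand-written stand-in for a two-image batch; the second image has no match.
results = [
    [
        {"xy": {"x": 0.25, "y": 0.5}, "hw": {"h": 0.2, "w": 0.1}},
        {"xy": {"x": 0.75, "y": 0.5}, "hw": {"h": 0.3, "w": 0.2}},
    ],
    [],
]
rows = flatten_results(results, ["red car", "bicycle"])
for row in rows:
    print(row)
```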
|
| 113 |
|
| 114 |
+
> [!NOTE]
|
| 115 |
+
> Requesting `task="segmentation"` on this model will raise a `ValueError`. Use the full [`tiiuae/Falcon-Perception`](https://huggingface.co/tiiuae/Falcon-Perception) model for segmentation masks.
|
| 116 |
+
|
| 117 |
## What the model is for
|
| 118 |
|
| 119 |
+
Falcon Perception 300M is designed for open-vocabulary object detection where the main difficulty is localization under free-form text queries. Use cases include:
|
| 120 |
|
| 121 |
- Natural language driven object selection in images
|
| 122 |
+
- Lightweight bounding-box detection for downstream pipelines
|
| 123 |
- Crowded scenes where the number of instances is large and variable
|
| 124 |
+
- Edge or resource-constrained deployments where the full model is too large
|
| 125 |
|
| 126 |
It is not intended as a general-purpose vision-language assistant for open-ended reasoning, long-form generation, or multi-step VQA.
|
| 127 |
|
|
|
|
| 131 |
|
| 132 |
- One dense Transformer backbone processes image patches and text tokens in a shared space from the first layer
|
| 133 |
- Hybrid attention masking: bidirectional among image tokens, causal for text and task tokens conditioned on the image
|
| 134 |
+
- Chain-of-Perception decoding: `<|coord|>` then `<|size|>` per instance
|
| 135 |
- Specialized heads for coordinates and size, with geometry conditioning via Fourier features
|
| 136 |
|
| 137 |
+
## Comparison with the full model
|
|
|
|
| 138 |
|
| 139 |
+
| | **Falcon-Perception** | **Falcon-Perception-300M** |
|
| 140 |
+
|---|---|---|
|
| 141 |
+
| Parameters | ~7B | ~0.3B |
|
| 142 |
+
| Tasks | Detection + Segmentation | Detection only |
|
| 143 |
+
| Output | Bounding boxes + pixel masks | Bounding boxes |
|
| 144 |
+
| Token sequence | `<\|coord\|>` `<\|size\|>` `<\|seg\|>` | `<\|coord\|>` `<\|size\|>` |
|
| 145 |
|
| 146 |
## Limitations
|
| 147 |
|
| 148 |
+
- Presence calibration remains a key limitation for autoregressive dense interfaces. False positives are more likely on hard negatives than in DETR-like detection models.
|
| 149 |
- OCR-driven prompts depend on text size and image resolution. Small text and degraded scans are challenging.
|
| 150 |
- Dense scenes benefit strongly from high resolution inputs. Low resolution can be sufficient to recognize that a concept is present, but insufficient to localize each instance precisely.
|
| 151 |
+
- This variant does **not** produce segmentation masks. Use the full model if pixel-level masks are needed.
|
| 152 |
|
| 153 |
## Citation
|
| 154 |
|
modeling_falcon_perception.py
CHANGED
|
@@ -620,6 +620,7 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 620 |
self,
|
| 621 |
images,
|
| 622 |
queries,
|
|
|
|
| 623 |
max_new_tokens: int = 2048,
|
| 624 |
temperature: float = 0.0,
|
| 625 |
top_k: int | None = None,
|
|
@@ -630,11 +631,13 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 630 |
segm_threshold: float = 0.5,
|
| 631 |
) -> list[list[dict]]:
|
| 632 |
"""
|
| 633 |
-
|
| 634 |
|
| 635 |
Args:
|
| 636 |
images: Single PIL Image (or path/URL) or list of them.
|
| 637 |
queries: Single query string or list of query strings (one per image).
|
|
|
|
|
|
|
| 638 |
max_new_tokens: Maximum generation steps.
|
| 639 |
temperature: Sampling temperature (0.0 = greedy).
|
| 640 |
top_k: Top-k sampling (None = disabled).
|
|
@@ -645,14 +648,25 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 645 |
segm_threshold: Sigmoid threshold for binary mask.
|
| 646 |
|
| 647 |
Returns:
|
| 648 |
-
List (per image) of lists (per detection) of dicts
|
|
|
|
| 649 |
|
| 650 |
-
{
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
|
|
|
| 655 |
"""
|
| 656 |
self._ensure_device_buffers()
|
| 657 |
if compile:
|
| 658 |
self.compile_model()
|
|
@@ -716,9 +730,12 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 716 |
coord_xy=coord_xy, size_hw=size_hw_t,
|
| 717 |
)
|
| 718 |
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
|
|
|
|
|
|
|
|
|
| 722 |
|
| 723 |
aux_output_B = [[] for _ in range(B)]
|
| 724 |
stop_ids = torch.tensor(stop_token_ids).to(device)
|
|
@@ -774,13 +791,14 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 774 |
for i, b in enumerate(sample_w_size.tolist()):
|
| 775 |
aux_output_B[b].append(size_preds[i])
|
| 776 |
|
| 777 |
-
# Decode segmentation
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
|
|
|
| 784 |
|
| 785 |
# Next step
|
| 786 |
logits_BSV, h_BSD = self.forward(
|
|
@@ -791,12 +809,13 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 791 |
hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
|
| 792 |
should_stop_B = should_stop_B.logical_or(hit_stop_B)
|
| 793 |
|
| 794 |
-
# Post-process: convert aux outputs to structured results
|
| 795 |
pixel_mask_batch = batch_inputs["pixel_mask"][:, 0] # (B, H, W)
|
| 796 |
results = []
|
| 797 |
for b in range(B):
|
| 798 |
dets = self._postprocess_aux(
|
| 799 |
aux_output_B[b], pixel_mask_batch[b], original_sizes[b], segm_threshold,
|
|
|
|
| 800 |
)
|
| 801 |
results.append(dets)
|
| 802 |
|
|
@@ -875,11 +894,29 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 875 |
orig_hw: tuple[int, int],
|
| 876 |
threshold: float,
|
| 877 |
nms_iou_threshold: float = 0.6,
|
|
|
|
| 878 |
) -> list[dict]:
|
| 879 |
-
"""Convert raw aux outputs into structured detections
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
orig_h, orig_w = orig_hw
|
| 881 |
|
| 882 |
-
|
| 883 |
nonzero = torch.nonzero(pixel_mask_hw, as_tuple=False)
|
| 884 |
if len(nonzero) > 0:
|
| 885 |
min_h, min_w = nonzero.min(dim=0)[0]
|
|
@@ -890,30 +927,26 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 890 |
min_h = min_w = 0
|
| 891 |
act_h = act_w = None
|
| 892 |
|
| 893 |
-
# Group into triplets: coord, size, mask
|
| 894 |
candidates = []
|
| 895 |
-
|
| 896 |
-
for
|
| 897 |
-
if
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
mask_logits = mask_logits
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
# Threshold
|
| 915 |
-
binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
|
| 916 |
-
candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
|
| 917 |
|
| 918 |
if not candidates:
|
| 919 |
return []
|
|
|
|
| 620 |
self,
|
| 621 |
images,
|
| 622 |
queries,
|
| 623 |
+
task: str | None = None,
|
| 624 |
max_new_tokens: int = 2048,
|
| 625 |
temperature: float = 0.0,
|
| 626 |
top_k: int | None = None,
|
|
|
|
| 631 |
segm_threshold: float = 0.5,
|
| 632 |
) -> list[list[dict]]:
|
| 633 |
"""
|
| 634 |
+
Detect (and optionally segment) objects in images matching the given queries.
|
| 635 |
|
| 636 |
Args:
|
| 637 |
images: Single PIL Image (or path/URL) or list of them.
|
| 638 |
queries: Single query string or list of query strings (one per image).
|
| 639 |
+
task: ``"segmentation"`` or ``"detection"``. Defaults to ``"segmentation"``
|
| 640 |
+
when the model supports it, ``"detection"`` otherwise.
|
| 641 |
max_new_tokens: Maximum generation steps.
|
| 642 |
temperature: Sampling temperature (0.0 = greedy).
|
| 643 |
top_k: Top-k sampling (None = disabled).
|
|
|
|
             segm_threshold: Sigmoid threshold for binary mask.
 
         Returns:
+            List (per image) of lists (per detection) of dicts.
+            For segmentation::
 
+                {"xy": {"x": float, "y": float}, "hw": {"h": float, "w": float},
+                 "mask_rle": {"counts": str, "size": [H, W]}}
+
+            For detection::
+
+                {"xy": {"x": float, "y": float}, "hw": {"h": float, "w": float}}
         """
+        if task is None:
+            task = "segmentation" if self.config.do_segmentation else "detection"
+        if task == "segmentation" and not self.config.do_segmentation:
+            raise ValueError(
+                "Task 'segmentation' requires a model with segmentation heads, "
+                "but this model was exported with do_segmentation=False. "
+                "Use task='detection' instead."
+            )
+        do_segm = task == "segmentation"
         self._ensure_device_buffers()
         if compile:
             self.compile_model()
|
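The task handling added in this hunk can be read in isolation as a small pure function. The sketch below mirrors that logic outside the model class so it can be exercised without loading weights. The name `resolve_task` and the plain boolean `do_segmentation` argument are stand-ins for the method body and `self.config.do_segmentation`. Like the code in the diff, it does not reject unknown task strings.

```python
def resolve_task(task, do_segmentation):
    """Mirror of the task-defaulting logic shown in the diff (illustrative only)."""
    if task is None:
        # Default to the richest task the exported model supports.
        return "segmentation" if do_segmentation else "detection"
    if task == "segmentation" and not do_segmentation:
        raise ValueError(
            "Task 'segmentation' requires a model with segmentation heads, "
            "but this model was exported with do_segmentation=False. "
            "Use task='detection' instead."
        )
    return task


print(resolve_task(None, do_segmentation=False))         # detection-only export
print(resolve_task("detection", do_segmentation=True))   # explicit override

try:
    resolve_task("segmentation", do_segmentation=False)
except ValueError as err:
    print("rejected:", err)
```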
|
|
|
| 730 |
coord_xy=coord_xy, size_hw=size_hw_t,
|
| 731 |
)
|
| 732 |
|
| 733 |
+
if do_segm:
|
| 734 |
+
hr_img_features = self.upsample_img_features(
|
| 735 |
+
h_BSD, tokens, batch_inputs["pixel_values"], batch_inputs["pixel_mask"],
|
| 736 |
+
)
|
| 737 |
+
else:
|
| 738 |
+
hr_img_features = None
|
| 739 |
|
| 740 |
aux_output_B = [[] for _ in range(B)]
|
| 741 |
stop_ids = torch.tensor(stop_token_ids).to(device)
|
|
|
|
| 791 |
for i, b in enumerate(sample_w_size.tolist()):
|
| 792 |
aux_output_B[b].append(size_preds[i])
|
| 793 |
|
| 794 |
+
# Decode segmentation (only when model has segmentation heads)
|
| 795 |
+
if do_segm:
|
| 796 |
+
sample_w_segm = torch.where(tokens_B1 == self.config.seg_token_id)[0]
|
| 797 |
+
segm_tokens = h_BSD[sample_w_segm, -1, :]
|
| 798 |
+
segm_tokens = self.proj_segm(segm_tokens)
|
| 799 |
+
segm_masks = torch.einsum("kdhw,kd->khw", hr_img_features[sample_w_segm], segm_tokens)
|
| 800 |
+
for i, b in enumerate(sample_w_segm):
|
| 801 |
+
aux_output_B[b].append(segm_masks[i])
|
| 802 |
|
| 803 |
# Next step
|
| 804 |
logits_BSV, h_BSD = self.forward(
|
|
|
|
| 809 |
hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
|
| 810 |
should_stop_B = should_stop_B.logical_or(hit_stop_B)
|
| 811 |
|
| 812 |
+
# Post-process: convert aux outputs to structured results
|
| 813 |
pixel_mask_batch = batch_inputs["pixel_mask"][:, 0] # (B, H, W)
|
| 814 |
results = []
|
| 815 |
for b in range(B):
|
| 816 |
dets = self._postprocess_aux(
|
| 817 |
aux_output_B[b], pixel_mask_batch[b], original_sizes[b], segm_threshold,
|
| 818 |
+
task=task,
|
| 819 |
)
|
| 820 |
results.append(dets)
|
| 821 |
|
|
|
|
         orig_hw: tuple[int, int],
         threshold: float,
         nms_iou_threshold: float = 0.6,
+        task: str = "segmentation",
     ) -> list[dict]:
+        """Convert raw aux outputs into structured detections.
+
+        For segmentation, returns dicts with ``xy``, ``hw``, and ``mask_rle``.
+        For detection, returns dicts with ``xy`` and ``hw`` only.
+        """
         orig_h, orig_w = orig_hw
 
+        if task == "detection":
+            # Detection-only: aux_list is interleaved coord/size dicts
+            detections = []
+            xy = None
+            for item in aux_list:
+                if isinstance(item, dict):
+                    if "x" in item or "y" in item:
+                        xy = item
+                    elif ("h" in item or "w" in item) and xy is not None:
+                        detections.append({"xy": xy, "hw": item})
+                        xy = None
+            return detections
+
+        # Segmentation: find active image region from pixel mask
         nonzero = torch.nonzero(pixel_mask_hw, as_tuple=False)
         if len(nonzero) > 0:
             min_h, min_w = nonzero.min(dim=0)[0]
|
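The detection branch added to `_postprocess_aux` is a small state machine: it remembers the most recent coordinate dict and emits a detection only when a size dict follows it. The sketch below reproduces that pairing as a standalone function to make the edge cases visible. The name `pair_detections` and the sample `aux` list are illustrative and hand-written, not taken from the repository.

```python
def pair_detections(aux_list):
    """Pair interleaved coord/size dicts into detections (illustrative copy)."""
    detections = []
    xy = None  # most recent unpaired coordinate dict
    for item in aux_list:
        if not isinstance(item, dict):
            continue  # non-dict items are ignored in detection mode
        if "x" in item or "y" in item:
            xy = item  # a new coord replaces any unpaired one
        elif ("h" in item or "w" in item) and xy is not None:
            detections.append({"xy": xy, "hw": item})
            xy = None  # consumed; a second size without a coord is skipped
    return detections


aux = [
    {"x": 0.2, "y": 0.3},
    {"h": 0.1, "w": 0.1},   # pairs with the coord above
    {"h": 0.5, "w": 0.5},   # size with no pending coord: skipped
    {"x": 0.7, "y": 0.7},   # coord never followed by a size: dropped
]
paired = pair_detections(aux)
print(paired)
```

One consequence of this design is that a truncated generation (for example, hitting `max_new_tokens` between `<|coord|>` and `<|size|>`) silently drops the last incomplete instance instead of raising.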
|
|
|
             min_h = min_w = 0
             act_h = act_w = None
 
+        # Group into triplets: coord, size, mask
         candidates = []
+        xy = hw = None
+        for item in aux_list:
+            if isinstance(item, dict):
+                if "x" in item or "y" in item:
+                    xy = item
+                    hw = None
+                elif "h" in item or "w" in item:
+                    hw = item
+            elif isinstance(item, torch.Tensor) and xy is not None and hw is not None:
+                mask_logits = item
+                if act_h is not None and act_w is not None:
+                    mask_logits = mask_logits[min_h:min_h + act_h, min_w:min_w + act_w]
+                mask_logits = mask_logits.unsqueeze(0).unsqueeze(0).float()
+                mask_logits = F.interpolate(mask_logits, size=(orig_h, orig_w), mode="bilinear", align_corners=False)
+                mask_logits = mask_logits.squeeze(0).squeeze(0)
+                binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
+                candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
+                xy = hw = None
 
         if not candidates:
             return []
|