lkhphuc commited on
Commit
3699368
·
1 Parent(s): 13e4160

Update detection-mode code

Browse files
Files changed (2) hide show
  1. README.md +53 -34
  2. modeling_falcon_perception.py +76 -43
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- pipeline_tag: mask-generation
3
  library_name: transformers
4
  tags:
5
  - falcon
@@ -11,33 +11,35 @@ license: apache-2.0
11
 
12
  <img src="main_fig.jpg" width="480" alt="Falcon Perception"/>
13
 
14
- > ![NOTE] This is the smaller version (300M parameters) and only support detection task.
 
15
 
16
- ## Falcon Perception
17
 
18
- Falcon Perception is a 0.3B parameter early-fusion vision-language model for open-vocabulary grounding detection. Given an image and a natural language query, it returns zero, one, or many matching instances with accurate bounding boxes.
19
 
20
- The model is built around a simple interface. Image patches and text tokens are processed together in a single Transformer using a hybrid attention mask: image tokens build bidirectional visual context, while text and task tokens decode causally conditioned on the image. For each instance, the model generates a short structured sequence of task tokens in a fixed order, `<|coord|>` then `<|size|>` then `<|seg|>`. The `<|seg|>` token acts as a mask query whose hidden state is projected and dotted with upsampled image features, producing a full-resolution binary mask without autoregressive mask generation.
21
 
22
 
23
  ### Links
24
 
25
- - Code and inference engine: `https://github.com/tiiuae/Falcon-Perception`
 
26
  - Tech report: arXiv link coming soon
27
  - PBench dataset: `tiiuae/PBench`
28
- - OCR model: `tiiuae/Falcon-OCR`
29
 
30
  ## Quickstart
31
 
32
  ### Installation
33
 
34
  ```bash
35
- pip install "torch>=2.5" transformers pillow einops pycocotools
36
  ```
37
 
38
  This model requires PyTorch 2.5 or newer for FlexAttention. The first call can be slower because `torch.compile` may build optimized kernels.
39
 
40
- ### Run open-vocabulary segmentation
41
 
42
  ```python
43
  import torch
@@ -45,7 +47,7 @@ from PIL import Image
45
  from transformers import AutoModelForCausalLM
46
 
47
  model = AutoModelForCausalLM.from_pretrained(
48
- "tiiuae/falcon-perception-300m",
49
  trust_remote_code=True,
50
  device_map={"": "cuda:0"},
51
  )
@@ -57,18 +59,31 @@ for p in preds:
57
  print(p["xy"], p["hw"])
58
  ```
59
 
60
- ### Decode masks
61
 
62
  ```python
63
- import numpy as np
64
- from pycocotools import mask as mask_utils
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  for p in preds:
67
- rle = p["mask_rle"]
68
- # pycocotools expects bytes for counts
69
- m = {"size": rle["size"], "counts": rle["counts"].encode("utf-8")}
70
- mask = mask_utils.decode(m).astype(bool) # H x W
71
- print(mask.shape, mask.sum())
 
 
72
  ```
73
 
74
  ## API
@@ -79,6 +94,7 @@ for p in preds:
79
  |---|---|---|---|
80
  | `images` | `PIL.Image` or `list` | required | Single image or list of images |
81
  | `queries` | `str` or `list[str]` | required | Query string(s), one per image |
 
82
  | `max_new_tokens` | `int` | `2048` | Maximum decoding steps |
83
  | `min_dimension` | `int` | `256` | Minimum image side after resize |
84
  | `max_dimension` | `int` | `1024` | Maximum image side after resize |
@@ -86,23 +102,26 @@ for p in preds:
86
 
87
  **Returns:** `list[list[dict]]`, one list per image.
88
 
89
- Each prediction dict contains:
90
 
91
  ```python
92
  {
93
- "xy": {"x": float, "y": float}, # center in normalized coordinates (0 to 1)
94
- "hw": {"h": float, "w": float}, # size in normalized coordinates (0 to 1)
95
- "mask_rle": {"counts": str, "size": [H, W]}, # COCO RLE at original resolution
96
  }
97
  ```
98
 
 
 
 
99
  ## What the model is for
100
 
101
- Falcon Perception is designed for dense grounding regimes where the main difficulty is localization under open vocabulary. That includes:
102
 
103
  - Natural language driven object selection in images
104
- - Promptable instance segmentation for downstream pipelines
105
  - Crowded scenes where the number of instances is large and variable
 
106
 
107
  It is not intended as a general-purpose vision-language assistant for open-ended reasoning, long-form generation, or multi-step VQA.
108
 
@@ -112,24 +131,24 @@ The architecture follows a single-stack early-fusion recipe:
112
 
113
  - One dense Transformer backbone processes image patches and text tokens in a shared space from the first layer
114
  - Hybrid attention masking: bidirectional among image tokens, causal for text and task tokens conditioned on the image
115
- - Chain-of-Perception decoding: `<|coord|>` then `<|size|>` then `<|seg|>` per instance
116
  - Specialized heads for coordinates and size, with geometry conditioning via Fourier features
117
- - Parallel mask decoding: each `<|seg|>` token becomes a mask query and produces a full-resolution mask via dot product with upsampled image features
118
-
119
- ## Evaluation summary
120
-
121
- From the technical report:
122
 
123
- - SA-Co (open-vocabulary segmentation): 68.0 Macro F1 compared to 62.3 for SAM 3, with the main remaining gap being presence calibration (Average MCC 0.64 compared to 0.82 for SAM 3)
124
- - PBench: a diagnostic benchmark that breaks down performance by capability (attributes, OCR-guided disambiguation, spatial constraints, relations) and includes a dense long-context crowded split
125
 
126
- Full tables, setup details, and ablations are in the report.
 
 
 
 
 
127
 
128
  ## Limitations
129
 
130
- - Presence calibration remains a key limitation for autoregressive dense interfaces. False positives are more likely on hard negatives than in DETR like segmentation models.
131
  - OCR-driven prompts depend on text size and image resolution. Small text and degraded scans are challenging.
132
  - Dense scenes benefit strongly from high resolution inputs. Low resolution can be sufficient to recognize that a concept is present, but insufficient to localize each instance precisely.
 
133
 
134
  ## Citation
135
 
 
1
  ---
2
+ pipeline_tag: object-detection
3
  library_name: transformers
4
  tags:
5
  - falcon
 
11
 
12
  <img src="main_fig.jpg" width="480" alt="Falcon Perception"/>
13
 
14
+ > [!NOTE]
15
+ > This is the **300M parameter** variant of Falcon Perception. It supports **detection only** (bounding boxes). For the full model with segmentation masks, see [`tiiuae/Falcon-Perception`](https://huggingface.co/tiiuae/Falcon-Perception).
16
 
17
+ ## Falcon Perception 300M
18
 
19
+ Falcon Perception 300M is a 0.3B parameter early-fusion vision-language model for open-vocabulary grounding detection. Given an image and a natural language query, it returns zero, one, or many matching instances with accurate bounding boxes.
20
 
21
+ The model is built around a simple interface. Image patches and text tokens are processed together in a single Transformer using a hybrid attention mask: image tokens build bidirectional visual context, while text and task tokens decode causally conditioned on the image. For each detected instance, the model generates a short structured sequence of task tokens: `<|coord|>` then `<|size|>`, producing a center point and bounding box size in normalized coordinates.
22
 
23
 
24
  ### Links
25
 
26
+ - Full model (with segmentation): [`tiiuae/Falcon-Perception`](https://huggingface.co/tiiuae/Falcon-Perception)
27
+ - Code and inference engine: [`github.com/tiiuae/Falcon-Perception`](https://github.com/tiiuae/Falcon-Perception)
28
  - Tech report: arXiv link coming soon
29
  - PBench dataset: `tiiuae/PBench`
30
+ - OCR model: [`tiiuae/Falcon-OCR`](https://huggingface.co/tiiuae/Falcon-OCR)
31
 
32
  ## Quickstart
33
 
34
  ### Installation
35
 
36
  ```bash
37
+ pip install "torch>=2.5" transformers pillow einops
38
  ```
39
 
40
  This model requires PyTorch 2.5 or newer for FlexAttention. The first call can be slower because `torch.compile` may build optimized kernels.
41
 
42
+ ### Run open-vocabulary detection
43
 
44
  ```python
45
  import torch
 
47
  from transformers import AutoModelForCausalLM
48
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
+ "tiiuae/Falcon-Perception-300M",
51
  trust_remote_code=True,
52
  device_map={"": "cuda:0"},
53
  )
 
59
  print(p["xy"], p["hw"])
60
  ```
61
 
62
+ Each prediction is a dict with normalized bounding box coordinates:
63
 
64
  ```python
65
+ {
66
+ "xy": {"x": float, "y": float}, # center in normalized coordinates (0 to 1)
67
+ "hw": {"h": float, "w": float}, # size in normalized coordinates (0 to 1)
68
+ }
69
+ ```
70
+
71
+ ### Visualize detections
72
+
73
+ ```python
74
+ from PIL import ImageDraw
75
+
76
+ draw = ImageDraw.Draw(image)
77
+ W, H = image.size
78
 
79
  for p in preds:
80
+ cx, cy = p["xy"]["x"] * W, p["xy"]["y"] * H
81
+ bw, bh = p["hw"]["w"] * W, p["hw"]["h"] * H
82
+ x0, y0 = cx - bw / 2, cy - bh / 2
83
+ x1, y1 = cx + bw / 2, cy + bh / 2
84
+ draw.rectangle([x0, y0, x1, y1], outline="lime", width=2)
85
+
86
+ image.save("output.jpg")
87
  ```
88
 
89
  ## API
 
94
  |---|---|---|---|
95
  | `images` | `PIL.Image` or `list` | required | Single image or list of images |
96
  | `queries` | `str` or `list[str]` | required | Query string(s), one per image |
97
+ | `task` | `str` | `"detection"` | Task type. Only `"detection"` is supported by this model. |
98
  | `max_new_tokens` | `int` | `2048` | Maximum decoding steps |
99
  | `min_dimension` | `int` | `256` | Minimum image side after resize |
100
  | `max_dimension` | `int` | `1024` | Maximum image side after resize |
 
102
 
103
  **Returns:** `list[list[dict]]`, one list per image.
104
 
105
+ Each detection dict contains:
106
 
107
  ```python
108
  {
109
+ "xy": {"x": float, "y": float}, # center in normalized coordinates (0 to 1)
110
+ "hw": {"h": float, "w": float}, # size in normalized coordinates (0 to 1)
 
111
  }
112
  ```
113
 
114
+ > [!NOTE]
115
+ > Requesting `task="segmentation"` on this model will raise a `ValueError`. Use the full [`tiiuae/Falcon-Perception`](https://huggingface.co/tiiuae/Falcon-Perception) model for segmentation masks.
116
+
117
  ## What the model is for
118
 
119
+ Falcon Perception 300M is designed for open-vocabulary object detection where the main difficulty is localization under free-form text queries. Use cases include:
120
 
121
  - Natural language driven object selection in images
122
+ - Lightweight bounding-box detection for downstream pipelines
123
  - Crowded scenes where the number of instances is large and variable
124
+ - Edge or resource-constrained deployments where the full model is too large
125
 
126
  It is not intended as a general-purpose vision-language assistant for open-ended reasoning, long-form generation, or multi-step VQA.
127
 
 
131
 
132
  - One dense Transformer backbone processes image patches and text tokens in a shared space from the first layer
133
  - Hybrid attention masking: bidirectional among image tokens, causal for text and task tokens conditioned on the image
134
+ - Chain-of-Perception decoding: `<|coord|>` then `<|size|>` per instance
135
  - Specialized heads for coordinates and size, with geometry conditioning via Fourier features
 
 
 
 
 
136
 
137
+ ## Comparison with the full model
 
138
 
139
+ | | **Falcon-Perception** | **Falcon-Perception-300M** |
140
+ |---|---|---|
141
+ | Parameters | ~7B | ~0.3B |
142
+ | Tasks | Detection + Segmentation | Detection only |
143
+ | Output | Bounding boxes + pixel masks | Bounding boxes |
144
+ | Token sequence | `<\|coord\|>` `<\|size\|>` `<\|seg\|>` | `<\|coord\|>` `<\|size\|>` |
145
 
146
  ## Limitations
147
 
148
+ - Presence calibration remains a key limitation for autoregressive dense interfaces. False positives are more likely on hard negatives than in DETR-like detection models.
149
  - OCR-driven prompts depend on text size and image resolution. Small text and degraded scans are challenging.
150
  - Dense scenes benefit strongly from high resolution inputs. Low resolution can be sufficient to recognize that a concept is present, but insufficient to localize each instance precisely.
151
+ - This variant does **not** produce segmentation masks. Use the full model if pixel-level masks are needed.
152
 
153
  ## Citation
154
 
modeling_falcon_perception.py CHANGED
@@ -620,6 +620,7 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
620
  self,
621
  images,
622
  queries,
 
623
  max_new_tokens: int = 2048,
624
  temperature: float = 0.0,
625
  top_k: int | None = None,
@@ -630,11 +631,13 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
630
  segm_threshold: float = 0.5,
631
  ) -> list[list[dict]]:
632
  """
633
- Segment objects in images matching the given queries.
634
 
635
  Args:
636
  images: Single PIL Image (or path/URL) or list of them.
637
  queries: Single query string or list of query strings (one per image).
 
 
638
  max_new_tokens: Maximum generation steps.
639
  temperature: Sampling temperature (0.0 = greedy).
640
  top_k: Top-k sampling (None = disabled).
@@ -645,14 +648,25 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
645
  segm_threshold: Sigmoid threshold for binary mask.
646
 
647
  Returns:
648
- List (per image) of lists (per detection) of dicts::
 
649
 
650
- {
651
- "xy": {"x": float, "y": float},
652
- "hw": {"h": float, "w": float},
653
- "mask_rle": {"counts": str, "size": [H, W]},
654
- }
 
655
  """
 
 
 
 
 
 
 
 
 
656
  self._ensure_device_buffers()
657
  if compile:
658
  self.compile_model()
@@ -716,9 +730,12 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
716
  coord_xy=coord_xy, size_hw=size_hw_t,
717
  )
718
 
719
- hr_img_features = self.upsample_img_features(
720
- h_BSD, tokens, batch_inputs["pixel_values"], batch_inputs["pixel_mask"],
721
- )
 
 
 
722
 
723
  aux_output_B = [[] for _ in range(B)]
724
  stop_ids = torch.tensor(stop_token_ids).to(device)
@@ -774,13 +791,14 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
774
  for i, b in enumerate(sample_w_size.tolist()):
775
  aux_output_B[b].append(size_preds[i])
776
 
777
- # Decode segmentation
778
- sample_w_segm = torch.where(tokens_B1 == self.config.seg_token_id)[0]
779
- segm_tokens = h_BSD[sample_w_segm, -1, :]
780
- segm_tokens = self.proj_segm(segm_tokens)
781
- segm_masks = torch.einsum("kdhw,kd->khw", hr_img_features[sample_w_segm], segm_tokens)
782
- for i, b in enumerate(sample_w_segm):
783
- aux_output_B[b].append(segm_masks[i])
 
784
 
785
  # Next step
786
  logits_BSV, h_BSD = self.forward(
@@ -791,12 +809,13 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
791
  hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
792
  should_stop_B = should_stop_B.logical_or(hit_stop_B)
793
 
794
- # Post-process: convert aux outputs to structured results with RLE masks
795
  pixel_mask_batch = batch_inputs["pixel_mask"][:, 0] # (B, H, W)
796
  results = []
797
  for b in range(B):
798
  dets = self._postprocess_aux(
799
  aux_output_B[b], pixel_mask_batch[b], original_sizes[b], segm_threshold,
 
800
  )
801
  results.append(dets)
802
 
@@ -875,11 +894,29 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
875
  orig_hw: tuple[int, int],
876
  threshold: float,
877
  nms_iou_threshold: float = 0.6,
 
878
  ) -> list[dict]:
879
- """Convert raw aux outputs into structured detections with RLE masks."""
 
 
 
 
880
  orig_h, orig_w = orig_hw
881
 
882
- # Find active image region from pixel mask
 
 
 
 
 
 
 
 
 
 
 
 
 
883
  nonzero = torch.nonzero(pixel_mask_hw, as_tuple=False)
884
  if len(nonzero) > 0:
885
  min_h, min_w = nonzero.min(dim=0)[0]
@@ -890,30 +927,26 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
890
  min_h = min_w = 0
891
  act_h = act_w = None
892
 
893
- # Group into triplets: coord, size, mask — build binary masks first
894
  candidates = []
895
- step = 3 # coord, size, mask
896
- for i in range(0, len(aux_list), step):
897
- if i + 2 >= len(aux_list):
898
- break
899
- xy = aux_list[i]
900
- hw = aux_list[i + 1]
901
- mask_logits = aux_list[i + 2]
902
- if not isinstance(mask_logits, torch.Tensor):
903
- continue
904
-
905
- # Crop to active region
906
- if act_h is not None and act_w is not None:
907
- mask_logits = mask_logits[min_h:min_h + act_h, min_w:min_w + act_w]
908
-
909
- # Resize to original image size
910
- mask_logits = mask_logits.unsqueeze(0).unsqueeze(0).float()
911
- mask_logits = F.interpolate(mask_logits, size=(orig_h, orig_w), mode="bilinear", align_corners=False)
912
- mask_logits = mask_logits.squeeze(0).squeeze(0)
913
-
914
- # Threshold
915
- binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
916
- candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
917
 
918
  if not candidates:
919
  return []
 
620
  self,
621
  images,
622
  queries,
623
+ task: str | None = None,
624
  max_new_tokens: int = 2048,
625
  temperature: float = 0.0,
626
  top_k: int | None = None,
 
631
  segm_threshold: float = 0.5,
632
  ) -> list[list[dict]]:
633
  """
634
+ Detect (and optionally segment) objects in images matching the given queries.
635
 
636
  Args:
637
  images: Single PIL Image (or path/URL) or list of them.
638
  queries: Single query string or list of query strings (one per image).
639
+ task: ``"segmentation"`` or ``"detection"``. Defaults to ``"segmentation"``
640
+ when the model supports it, ``"detection"`` otherwise.
641
  max_new_tokens: Maximum generation steps.
642
  temperature: Sampling temperature (0.0 = greedy).
643
  top_k: Top-k sampling (None = disabled).
 
648
  segm_threshold: Sigmoid threshold for binary mask.
649
 
650
  Returns:
651
+ List (per image) of lists (per detection) of dicts.
652
+ For segmentation::
653
 
654
+ {"xy": {"x": float, "y": float}, "hw": {"h": float, "w": float},
655
+ "mask_rle": {"counts": str, "size": [H, W]}}
656
+
657
+ For detection::
658
+
659
+ {"xy": {"x": float, "y": float}, "hw": {"h": float, "w": float}}
660
  """
661
+ if task is None:
662
+ task = "segmentation" if self.config.do_segmentation else "detection"
663
+ if task == "segmentation" and not self.config.do_segmentation:
664
+ raise ValueError(
665
+ "Task 'segmentation' requires a model with segmentation heads, "
666
+ "but this model was exported with do_segmentation=False. "
667
+ "Use task='detection' instead."
668
+ )
669
+ do_segm = task == "segmentation"
670
  self._ensure_device_buffers()
671
  if compile:
672
  self.compile_model()
 
730
  coord_xy=coord_xy, size_hw=size_hw_t,
731
  )
732
 
733
+ if do_segm:
734
+ hr_img_features = self.upsample_img_features(
735
+ h_BSD, tokens, batch_inputs["pixel_values"], batch_inputs["pixel_mask"],
736
+ )
737
+ else:
738
+ hr_img_features = None
739
 
740
  aux_output_B = [[] for _ in range(B)]
741
  stop_ids = torch.tensor(stop_token_ids).to(device)
 
791
  for i, b in enumerate(sample_w_size.tolist()):
792
  aux_output_B[b].append(size_preds[i])
793
 
794
+ # Decode segmentation (only when model has segmentation heads)
795
+ if do_segm:
796
+ sample_w_segm = torch.where(tokens_B1 == self.config.seg_token_id)[0]
797
+ segm_tokens = h_BSD[sample_w_segm, -1, :]
798
+ segm_tokens = self.proj_segm(segm_tokens)
799
+ segm_masks = torch.einsum("kdhw,kd->khw", hr_img_features[sample_w_segm], segm_tokens)
800
+ for i, b in enumerate(sample_w_segm):
801
+ aux_output_B[b].append(segm_masks[i])
802
 
803
  # Next step
804
  logits_BSV, h_BSD = self.forward(
 
809
  hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
810
  should_stop_B = should_stop_B.logical_or(hit_stop_B)
811
 
812
+ # Post-process: convert aux outputs to structured results
813
  pixel_mask_batch = batch_inputs["pixel_mask"][:, 0] # (B, H, W)
814
  results = []
815
  for b in range(B):
816
  dets = self._postprocess_aux(
817
  aux_output_B[b], pixel_mask_batch[b], original_sizes[b], segm_threshold,
818
+ task=task,
819
  )
820
  results.append(dets)
821
 
 
894
  orig_hw: tuple[int, int],
895
  threshold: float,
896
  nms_iou_threshold: float = 0.6,
897
+ task: str = "segmentation",
898
  ) -> list[dict]:
899
+ """Convert raw aux outputs into structured detections.
900
+
901
+ For segmentation, returns dicts with ``xy``, ``hw``, and ``mask_rle``.
902
+ For detection, returns dicts with ``xy`` and ``hw`` only.
903
+ """
904
  orig_h, orig_w = orig_hw
905
 
906
+ if task == "detection":
907
+ # Detection-only: aux_list is interleaved coord/size dicts
908
+ detections = []
909
+ xy = None
910
+ for item in aux_list:
911
+ if isinstance(item, dict):
912
+ if "x" in item or "y" in item:
913
+ xy = item
914
+ elif ("h" in item or "w" in item) and xy is not None:
915
+ detections.append({"xy": xy, "hw": item})
916
+ xy = None
917
+ return detections
918
+
919
+ # Segmentation: find active image region from pixel mask
920
  nonzero = torch.nonzero(pixel_mask_hw, as_tuple=False)
921
  if len(nonzero) > 0:
922
  min_h, min_w = nonzero.min(dim=0)[0]
 
927
  min_h = min_w = 0
928
  act_h = act_w = None
929
 
930
+ # Group into triplets: coord, size, mask
931
  candidates = []
932
+ xy = hw = None
933
+ for item in aux_list:
934
+ if isinstance(item, dict):
935
+ if "x" in item or "y" in item:
936
+ xy = item
937
+ hw = None
938
+ elif "h" in item or "w" in item:
939
+ hw = item
940
+ elif isinstance(item, torch.Tensor) and xy is not None and hw is not None:
941
+ mask_logits = item
942
+ if act_h is not None and act_w is not None:
943
+ mask_logits = mask_logits[min_h:min_h + act_h, min_w:min_w + act_w]
944
+ mask_logits = mask_logits.unsqueeze(0).unsqueeze(0).float()
945
+ mask_logits = F.interpolate(mask_logits, size=(orig_h, orig_w), mode="bilinear", align_corners=False)
946
+ mask_logits = mask_logits.squeeze(0).squeeze(0)
947
+ binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
948
+ candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
949
+ xy = hw = None
 
 
 
 
950
 
951
  if not candidates:
952
  return []