Add files using upload-large-folder tool
Browse files- .gitattributes +6 -0
- LICENSE +21 -0
- PET_Finetuned.safetensors +3 -0
- README.md +151 -3
- TechnicalReport.pdf +3 -0
- images/pexels-558331748-30295833.jpg +3 -0
- images/pexels-ilyasajpg-7038431.jpg +3 -0
- images/pexels-peter-almario-388108-19472286.jpg +3 -0
- images/pexels-rafeeque-kodungookaran-374579689-18755903.jpg +3 -0
- images/pexels-wendywei-4945353.jpg +3 -0
- test.py +406 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
TechnicalReport.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
images/pexels-558331748-30295833.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
images/pexels-ilyasajpg-7038431.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
images/pexels-peter-almario-388108-19472286.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
images/pexels-rafeeque-kodungookaran-374579689-18755903.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
images/pexels-wendywei-4945353.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Awiros
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
PET_Finetuned.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab940304e869fff4afe92dd8c2ebff798603ce18f4548e0435aa923bf4f15f39
|
| 3 |
+
size 224692940
|
README.md
CHANGED
|
@@ -1,3 +1,151 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: pytorch
|
| 6 |
+
tags: [crowd-counting, localization, PET]
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Hierarchical Training on Partial Annotations Enables Density-Robust Crowd Counting and Localization
|
| 10 |
+
|
| 11 |
+
## Abstract
|
| 12 |
+
|
| 13 |
+
Reliable crowd analysis requires both accurate counting and precise head-point
|
| 14 |
+
localization under severe density and scale variation. In practice, dense
|
| 15 |
+
scenes exhibit heavy occlusion and perspective distortion, while the same
|
| 16 |
+
camera can undergo abrupt distribution shifts over time due to zoom and
|
| 17 |
+
viewpoint changes or event dynamics. We present the model, obtained by fine-tuning Point Query Transformer (PET) on a
|
| 18 |
+
curated, multi-source dataset with partial and heterogeneous annotations. Our
|
| 19 |
+
training recipe combines (i) a hierarchical iterative loop that aligns count
|
| 20 |
+
distributions across partial ground truth, fine-tuned predictions, and the
|
| 21 |
+
pre-trained baseline to guide outlier-driven data refinement, (ii)
|
| 22 |
+
multi-patch resolution training (128x128, 256x256, and 512x512) to reduce
|
| 23 |
+
scale sensitivity, (iii) count-aware patch sampling to mitigate long-tailed
|
| 24 |
+
density skew, and (iv) adaptive background-query loss weighting to prevent
|
| 25 |
+
resolution-dependent background dominance. This approach improves F1 scores
|
| 26 |
+
F1@4px and F1@8px on ShanghaiTech Part A (SHHA), ShanghaiTech Part B (SHHB),
|
| 27 |
+
JHU-Crowd++, and UCF-QNRF, and exhibits more stable behavior during
|
| 28 |
+
sparse-to-dense density transitions.
|
| 29 |
+
|
| 30 |
+
For detailed data curation and training recipe, refer to our technical
|
| 31 |
+
report: [Technical Report](TechnicalReport.pdf).
|
| 32 |
+
|
| 33 |
+
## Evaluation and Results
|
| 34 |
+
|
| 35 |
+
Across four benchmarks, PET-Finetuned shows the strongest overall transfer,
|
| 36 |
+
with consistent gains in both counting and localization on SHHB, UCF-QNRF, and
|
| 37 |
+
JHU-Crowd++. On SHHB, it reduces MAE/MSE to 13.794/22.163 from 19.472/29.651
|
| 38 |
+
(PET-SHHA) and 19.579/28.398 (APGCC-SHHA), while increasing F1@8 to 0.820.
|
| 39 |
+
The same pattern holds on UCF-QNRF (MAE 105.772, MSE 199.544, F1@8 0.738) and
|
| 40 |
+
JHU-Crowd++ (MAE 74.778, MSE 271.886, F1@8 0.698), where PET-Finetuned
|
| 41 |
+
outperforms both references by clear margins. On SHHA, counting error is higher
|
| 42 |
+
than PET-SHHA and APGCC-SHHA (MAE 62.742 vs 48.879/48.725), but localization is
|
| 43 |
+
best in the table (F1@4 0.614, F1@8 0.794), indicating a stronger precision-recall
|
| 44 |
+
balance for head-point prediction at both matching thresholds.
|
| 45 |
+
|
| 46 |
+
> **Note (evaluation protocol):** PET-SHHA and APGCC-SHHA numbers in this
|
| 47 |
+
> section can differ from values reported in the original papers. The original
|
| 48 |
+
> works typically train one model per target dataset and evaluate in-domain. In
|
| 49 |
+
> contrast, `PET-Finetuned(Ours)` is initialized from PET-SHHA weights and
|
| 50 |
+
> fine-tuned in our framework. For cross-dataset baseline comparison, we use
|
| 51 |
+
> the best public SHHA Part A checkpoints released by the authors for PET-SHHA
|
| 52 |
+
> and APGCC-SHHA (APGCC publicly provides only the SHHA-best checkpoint).
|
| 53 |
+
> Therefore, the PET-SHHA and APGCC-SHHA rows above reflect transfer from SHHA
|
| 54 |
+
> initialization rather than per-dataset retraining. All metrics in this
|
| 55 |
+
> section are evaluated at `threshold = 0.5`.
|
| 56 |
+
|
| 57 |
+
### ShanghaiTech Part A (SHHA)
|
| 58 |
+
|
| 59 |
+
| Model | MAE | MSE | AP@4px | AR@4px | F1@4px | AP@8px | AR@8px | F1@8px |
|
| 60 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 61 |
+
| PET-Finetuned(Ours) | 62.742 | 102.996 | **0.615** | **0.613** | **0.614** | **0.796** | **0.793** | **0.794** |
|
| 62 |
+
| PET-SHHA | 48.879 | **76.520** | 0.596 | 0.604 | 0.600 | 0.781 | 0.792 | 0.786 |
|
| 63 |
+
| APGCC-SHHA | **48.725** | 76.721 | 0.439 | 0.428 | 0.433 | 0.773 | 0.754 | 0.764 |
|
| 64 |
+
|
| 65 |
+
### ShanghaiTech Part B (SHHB)
|
| 66 |
+
|
| 67 |
+
| Model | MAE | MSE | AP@4px | AR@4px | F1@4px | AP@8px | AR@8px | F1@8px |
|
| 68 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 69 |
+
| PET-Finetuned(Ours) | **13.794** | **22.163** | **0.666** | **0.596** | **0.629** | **0.869** | **0.777** | **0.820** |
|
| 70 |
+
| PET-SHHA | 19.472 | 29.651 | 0.640 | 0.547 | 0.590 | 0.847 | 0.724 | 0.781 |
|
| 71 |
+
| APGCC-SHHA | 19.579 | 28.398 | 0.517 | 0.441 | 0.476 | 0.837 | 0.714 | 0.771 |
|
| 72 |
+
|
| 73 |
+
### UCF-QNRF
|
| 74 |
+
|
| 75 |
+
| Model | MAE | MSE | AP@4px | AR@4px | F1@4px | AP@8px | AR@8px | F1@8px |
|
| 76 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 77 |
+
| PET-Finetuned(Ours) | **105.772** | **199.544** | **0.533** | **0.505** | **0.519** | **0.759** | **0.719** | **0.738** |
|
| 78 |
+
| PET-SHHA | 123.135 | 240.943 | 0.495 | 0.487 | 0.491 | 0.708 | 0.696 | 0.702 |
|
| 79 |
+
| APGCC-SHHA | 126.763 | 228.998 | 0.311 | 0.284 | 0.297 | 0.638 | 0.583 | 0.609 |
|
| 80 |
+
|
| 81 |
+
### JHU-Crowd++
|
| 82 |
+
|
| 83 |
+
| Model | MAE | MSE | AP@4px | AR@4px | F1@4px | AP@8px | AR@8px | F1@8px |
|
| 84 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 85 |
+
| PET-Finetuned(Ours) | **74.778** | **271.886** | **0.467** | **0.491** | **0.479** | **0.681** | **0.715** | **0.698** |
|
| 86 |
+
| PET-SHHA | 115.861 | 393.281 | 0.379 | 0.449 | 0.411 | 0.582 | 0.690 | 0.632 |
|
| 87 |
+
| APGCC-SHHA | 102.461 | 331.883 | 0.303 | 0.330 | 0.316 | 0.578 | 0.630 | 0.603 |
|
| 88 |
+
|
| 89 |
+
## Qualitative Analysis
|
| 90 |
+
|
| 91 |
+
Full-resolution qualitative comparisons in the report use horizontal stacked
|
| 92 |
+
panels ordered as `PET-Finetuned(Ours)`, `PET-SHHA`, and `APGCC-SHHA`, with
|
| 93 |
+
point colors green, yellow, and red. Inference for these comparisons uses
|
| 94 |
+
`threshold = 0.5` and `upper_bound = -1`. Qualitatively,
|
| 95 |
+
`PET-Finetuned(Ours)` shows fewer sparse-scene false positives, stronger
|
| 96 |
+
dense-scene recall under occlusion, and more stable localization under
|
| 97 |
+
perspective and scale variation.
|
| 98 |
+
|
| 99 |
+
[](images/pexels-558331748-30295833.jpg)
|
| 100 |
+
|
| 101 |
+
[](images/pexels-ilyasajpg-7038431.jpg)
|
| 102 |
+
|
| 103 |
+
[](images/pexels-peter-almario-388108-19472286.jpg)
|
| 104 |
+
|
| 105 |
+
[](images/pexels-rafeeque-kodungookaran-374579689-18755903.jpg)
|
| 106 |
+
|
| 107 |
+
[](images/pexels-wendywei-4945353.jpg)
|
| 108 |
+
|
| 109 |
+
## Model Inference
|
| 110 |
+
|
| 111 |
+
Use the official PET repository to run single-image inference with this
|
| 112 |
+
release model.
|
| 113 |
+
|
| 114 |
+
1. Clone PET and move into the repository root.
|
| 115 |
+
```bash
|
| 116 |
+
git clone https://github.com/cxliu0/PET.git
|
| 117 |
+
cd PET
|
| 118 |
+
```
|
| 119 |
+
2. Install dependencies.
|
| 120 |
+
```bash
|
| 121 |
+
pip install -r requirements.txt
|
| 122 |
+
pip install safetensors pillow
|
| 123 |
+
```
|
| 124 |
+
3. Copy `test.py` from this release folder into the PET repository root.
|
| 125 |
+
4. Place `PET_Finetuned.safetensors` in the PET repository root.
|
| 126 |
+
5. Run inference (dummy example).
|
| 127 |
+
```bash
|
| 128 |
+
python test.py \
|
| 129 |
+
--image_path path/to/image.jpg \
|
| 130 |
+
--resume PET_Finetuned.safetensors \
|
| 131 |
+
--device cpu \
|
| 132 |
+
--output_json outputs/prediction.json \
|
| 133 |
+
--output_image outputs/prediction.jpg
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## Summary
|
| 137 |
+
|
| 138 |
+
We present a practical adaptation of PET for density-robust
|
| 139 |
+
crowd counting and head-point localization under partial and heterogeneous
|
| 140 |
+
annotations. The training framework combines a hierarchical iterative
|
| 141 |
+
fine-tuning loop with outlier-driven data refinement, mixed patch-resolution
|
| 142 |
+
optimization (128x128/256x256/512x512), count-aware sampling for dense-scene
|
| 143 |
+
emphasis, and adaptive background-query loss weighting to stabilize supervision
|
| 144 |
+
across scales.
|
| 145 |
+
|
| 146 |
+
Under the reported cross-dataset transfer protocol from SHHA initialization,
|
| 147 |
+
the model achieves the strongest overall transfer on SHHB, UCF-QNRF, and
|
| 148 |
+
JHU-Crowd++, while maintaining the best localization balance on SHHA at both
|
| 149 |
+
matching thresholds. Qualitative evidence is consistent with these trends,
|
| 150 |
+
showing fewer sparse-scene false positives and stronger dense-scene recall
|
| 151 |
+
under occlusion and perspective variation.
|
TechnicalReport.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20927c9dc1d5c4b1a32d5d11c7b804b24273bd345e22e3635ac615ed9b4dc4a2
|
| 3 |
+
size 9209917
|
images/pexels-558331748-30295833.jpg
ADDED
|
Git LFS Details
|
images/pexels-ilyasajpg-7038431.jpg
ADDED
|
Git LFS Details
|
images/pexels-peter-almario-388108-19472286.jpg
ADDED
|
Git LFS Details
|
images/pexels-rafeeque-kodungookaran-374579689-18755903.jpg
ADDED
|
Git LFS Details
|
images/pexels-wendywei-4945353.jpg
ADDED
|
Git LFS Details
|
test.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision.transforms as standard_transforms
|
| 11 |
+
|
| 12 |
+
import util.misc as utils
|
| 13 |
+
from models import build_model
|
| 14 |
+
|
| 15 |
+
# Preprocessing applied to every input image before the model sees it:
# PIL image -> float tensor in [0, 1], then channel-wise normalization
# with the ImageNet mean/std (the statistics the VGG backbone expects).
PET_TRANSFORM = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_args_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for single-image PET inference.

    Returns a parser with ``add_help=False`` so it can be used as a parent
    parser in the ``__main__`` guard. The model-architecture and loss
    arguments mirror the names ``build_model`` reads from ``args``; their
    defaults define the released checkpoint's configuration, so changing
    them changes which model is constructed.
    """
    parser = argparse.ArgumentParser('PET single-image inference (HF release)', add_help=False)

    # Core inference inputs.
    parser.add_argument('--image_path', required=True, type=str,
                        help='Path to a single input image.')
    parser.add_argument('--resume', default='PET_Finetuned.safetensors', type=str,
                        help='Path to model weights (.safetensors or .pth).')
    parser.add_argument('--device', default='cuda', type=str,
                        help='Device for inference, e.g. cuda or cpu.')

    # Model architecture (consumed by build_model; must match the checkpoint).
    parser.add_argument('--backbone', default='vgg16_bn', type=str)
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned', 'fourier'))
    parser.add_argument('--dec_layers', default=2, type=int)
    parser.add_argument('--dim_feedforward', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('--nheads', default=8, type=int)

    # Loss/matcher coefficients — unused at inference time but required by
    # build_model's criterion construction.
    parser.add_argument('--set_cost_class', default=1, type=float)
    parser.add_argument('--set_cost_point', default=0.05, type=float)
    parser.add_argument('--ce_loss_coef', default=1.0, type=float)
    parser.add_argument('--point_loss_coef', default=5.0, type=float)
    parser.add_argument('--eos_coef', default=0.5, type=float)

    # Dataset fields expected by the PET codebase (not read by this script's
    # own logic, only passed through to build_model).
    parser.add_argument('--dataset_file', default='SHA')
    parser.add_argument('--data_path', default='./data/ShanghaiTech/PartA', type=str)

    # Resizing and visualization options.
    parser.add_argument('--upper_bound', default=-1, type=int,
                        help='Max image side for inference; -1 means only cap at 2560 (same as compare_models).')
    parser.add_argument('--output_image', default='', type=str,
                        help='Optional path to save annotated image panel.')
    parser.add_argument('--title_text', default='PET-Finetuned', type=str,
                        help='Title prefix used in top panel text.')
    parser.add_argument('--radius', default=3, type=int)
    parser.add_argument('--point_color', default='0,255,0', type=str,
                        help='BGR color for points, e.g., 0,255,0')
    parser.add_argument('--panel_long_side', default=1600, type=int,
                        help='Resize annotated panel long side to this value.')
    parser.add_argument('--panel_pad', default=24, type=int,
                        help='Panel padding around the image and title area.')
    parser.add_argument('--panel_font_size', default=48, type=int,
                        help='Font size for panel title text.')

    # Output options.
    parser.add_argument('--output_json', default='', type=str,
                        help='Optional output JSON path for prediction details.')
    parser.add_argument('--seed', default=42, type=int)

    return parser
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def parse_color(color_str: str):
    """Parse a ``'B,G,R'`` string into a 3-tuple of ints.

    Whitespace around each component is tolerated (``' 0 , 255 , 0 '``).

    Args:
        color_str: Comma-separated blue, green, red channel values.

    Returns:
        A ``(b, g, r)`` tuple of ints, each in [0, 255].

    Raises:
        ValueError: if the string does not contain exactly three
            comma-separated integers, or a channel is outside [0, 255].
    """
    parts = color_str.split(',')
    if len(parts) != 3:
        raise ValueError('color must be B,G,R like 0,255,0')
    channels = tuple(int(p.strip()) for p in parts)
    # Reject out-of-range channels early; OpenCV/PIL would otherwise
    # silently wrap or clamp them downstream.
    if any(c < 0 or c > 255 for c in channels):
        raise ValueError('color channels must be in [0, 255]')
    return channels
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def resolve_device(device_str: str) -> torch.device:
    """Turn a device string into a ``torch.device``.

    Falls back to CPU (with a console notice) when a CUDA device is
    requested but CUDA is unavailable. When an explicit CUDA index is
    given (e.g. ``cuda:1``), it is made the current device.
    """
    wants_cuda = device_str.startswith('cuda')
    if wants_cuda and not torch.cuda.is_available():
        print('CUDA not available. Falling back to CPU.')
        return torch.device('cpu')

    resolved = torch.device(device_str)
    if resolved.type == 'cuda' and resolved.index is not None:
        torch.cuda.set_device(resolved.index)
    return resolved
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def resize_for_eval(frame_rgb, upper_bound):
    """Downscale an RGB frame so its long side fits the inference bound.

    ``upper_bound == -1`` means "no explicit bound", in which case the long
    side is still capped at 2560 px. Returns the (possibly resized) frame
    and the scale factor that was applied (1.0 when the frame is returned
    untouched).
    """
    height, width = frame_rgb.shape[:2]
    long_side = max(height, width)

    if upper_bound != -1 and long_side > upper_bound:
        factor = float(upper_bound) / float(long_side)
    elif long_side > 2560:
        factor = 2560.0 / float(long_side)
    else:
        factor = 1.0

    if factor == 1.0:
        return frame_rgb, factor

    target_size = (max(1, int(round(width * factor))), max(1, int(round(height * factor))))
    shrunk = cv2.resize(frame_rgb, target_size, interpolation=cv2.INTER_LINEAR)
    return shrunk, factor
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def load_font(font_size=40, bold=False, font_paths=None):
    """Load a TrueType font for panel text, trying common Linux paths.

    When ``font_paths`` is not given, a bold or regular candidate list of
    DejaVu/Liberation/FreeSans paths is used. Falls back to resolving the
    font by name, and finally to Pillow's built-in bitmap font when no
    TrueType font can be opened.
    """
    if font_paths is None:
        font_paths = (
            [
                '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
                '/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf',
                '/usr/share/fonts/truetype/freefont/FreeSansBold.ttf',
            ]
            if bold
            else [
                '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
                '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
                '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
            ]
        )

    for candidate in font_paths:
        if not os.path.exists(candidate):
            continue
        try:
            return ImageFont.truetype(candidate, font_size)
        except OSError:
            continue

    # Last resorts: let Pillow resolve the font by name, then the
    # built-in default bitmap font.
    try:
        by_name = 'DejaVuSans-Bold.ttf' if bold else 'DejaVuSans.ttf'
        return ImageFont.truetype(by_name, font_size)
    except OSError:
        return ImageFont.load_default()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def draw_text(draw, xy, text, font, fill, bold=False, stroke_width=0):
    """Draw text on a PIL draw surface, emulating bold via stroke.

    On Pillow versions whose ``ImageDraw.text`` does not accept
    ``stroke_width`` (raising TypeError), bold is approximated by
    overdrawing the text at four one-pixel offsets.
    """
    if bold and stroke_width <= 0:
        stroke_width = 2
    try:
        if bold:
            draw.text(
                xy,
                text,
                fill=fill,
                font=font,
                stroke_width=stroke_width,
                stroke_fill=fill,
            )
        else:
            draw.text(xy, text, fill=fill, font=font)
    except TypeError:
        if not bold:
            draw.text(xy, text, fill=fill, font=font)
            return
        base_x, base_y = xy
        for dx, dy in ((0, 0), (1, 0), (0, 1), (1, 1)):
            draw.text((base_x + dx, base_y + dy), text, fill=fill, font=font)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _get_text_size(draw, text, font, bold=False, stroke_width=0):
    """Measure rendered text dimensions, portable across Pillow versions.

    Prefers ``textbbox`` (Pillow >= 8); falls back to the deprecated
    ``textsize``, widening the result by the stroke when bold is on.
    Returns a ``(width, height)`` pair in pixels.
    """
    if hasattr(draw, 'textbbox'):
        effective_stroke = stroke_width if bold else 0
        try:
            left, top, right, bottom = draw.textbbox(
                (0, 0),
                text,
                font=font,
                stroke_width=effective_stroke,
            )
        except TypeError:
            # Older textbbox signatures lack stroke_width.
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top

    width, height = draw.textsize(text, font=font)
    if bold:
        width += 2 * stroke_width
        height += 2 * stroke_width
    return width, height
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def fit_text_to_width(draw, text, font, max_w, bold=False, stroke_width=0):
    """Truncate ``text`` with a trailing ``'...'`` so it fits in ``max_w`` px.

    Returns the text unchanged when it already fits, an empty string when
    even the ellipsis cannot fit, and otherwise the longest prefix that,
    with the ellipsis appended, measures at most ``max_w``.
    """
    text = text or ''
    if max_w <= 0:
        return ''

    full_w, _ = _get_text_size(draw, text, font, bold=bold, stroke_width=stroke_width)
    if full_w <= max_w:
        return text

    ellipsis = '...'
    ell_w, _ = _get_text_size(draw, ellipsis, font, bold=bold, stroke_width=stroke_width)
    if ell_w > max_w:
        return ''

    # Shrink the prefix one character at a time until the result fits.
    for cut in range(len(text) - 1, -1, -1):
        candidate = text[:cut] + ellipsis
        cand_w, _ = _get_text_size(draw, candidate, font, bold=bold, stroke_width=stroke_width)
        if cand_w <= max_w:
            return candidate
    return ellipsis
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def bgr_to_rgb(color):
    """Reorder a BGR channel triple into RGB."""
    blue, green, red = color[0], color[1], color[2]
    return (red, green, blue)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def resize_with_points(img, pts, target_long_side):
    """Resize a PIL image so its long side equals ``target_long_side``,
    scaling the point array by the same factor.

    Returns ``(img, pts)`` unchanged when the target is None/non-positive
    or the image already has that long side.
    """
    if target_long_side is None or target_long_side <= 0:
        return img, pts

    w, h = img.size
    longest = max(w, h)
    if longest <= 0 or longest == target_long_side:
        return img, pts

    factor = float(target_long_side) / float(longest)
    out_size = (max(1, int(round(w * factor))), max(1, int(round(h * factor))))
    img = img.resize(out_size, Image.BILINEAR)

    if pts is not None and pts.size > 0:
        pts = pts * factor
    return img, pts
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def add_padding_with_text(img, text, pad, font, text_color, bg_color, bold, stroke_width):
    """Place ``img`` on a padded canvas and draw a title in the top margin.

    The padding is enlarged when needed so the title plus a 24 px gap above
    and below it always fits. Returns ``img`` unchanged when ``pad`` is
    None or non-positive.
    """
    if pad is None or pad <= 0:
        return img

    measurer = ImageDraw.Draw(img)
    text = text or ''
    _, title_h = _get_text_size(measurer, text, font, bold=bold, stroke_width=stroke_width)

    gap = 24  # minimum vertical breathing room around the title
    pad = max(pad, title_h + 2 * gap)

    canvas = Image.new('RGB', (img.width + 2 * pad, img.height + 2 * pad), color=bg_color)
    canvas.paste(img, (pad, pad))

    painter = ImageDraw.Draw(canvas)
    usable_w = max(0, canvas.width - 2 * pad)
    text = fit_text_to_width(painter, text, font, usable_w, bold=bold, stroke_width=stroke_width)
    _, title_h = _get_text_size(painter, text, font, bold=bold, stroke_width=stroke_width)

    # Vertically center the title within the top margin, clamped so the
    # minimum gap is preserved on both sides.
    title_y = max(gap, (pad - title_h) // 2)
    title_y = min(title_y, max(0, pad - title_h - gap))
    draw_text(painter, (pad, title_y), text, font, text_color, bold=bold, stroke_width=stroke_width)
    return canvas
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def annotate_panel(
    img_bgr,
    pts,
    title_text,
    point_color_bgr,
    radius,
    font,
    text_color,
    title_bg,
    target_long_side,
    pad,
):
    """Render predicted points onto the image and wrap it in a titled panel.

    The BGR frame is converted to RGB, resized (with the points) so its
    long side equals ``target_long_side``, dotted with filled circles at
    each ``(x, y)`` point, and finally padded with a title strip via
    ``add_padding_with_text``. Returns the resulting PIL image.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb)
    img, pts = resize_with_points(img, pts, target_long_side)
    draw = ImageDraw.Draw(img)

    # Scale the dot radius with the image so points stay visible; the
    # caller's radius is only honored when it is at least this large.
    max_dim = max(img.width, img.height)
    auto_radius = max(3, int(round(max_dim * 0.004)))
    if radius is None or radius < auto_radius:
        radius = auto_radius

    if pts is not None and pts.size > 0:
        # Points are (x, y) in BGR color convention; PIL wants RGB.
        color = bgr_to_rgb(point_color_bgr)
        for x, y in pts:
            x0 = x - radius
            y0 = y - radius
            x1 = x + radius
            y1 = y + radius
            draw.ellipse((x0, y0, x1, y1), fill=color, outline=color)

    return add_padding_with_text(
        img,
        title_text or '',
        pad,
        font,
        text_color,
        title_bg,
        bold=False,
        stroke_width=0,
    )
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _load_state_dict(weight_path: Path):
|
| 288 |
+
if not weight_path.exists():
|
| 289 |
+
raise FileNotFoundError(f'Weights file not found: {weight_path}')
|
| 290 |
+
|
| 291 |
+
if weight_path.suffix == '.safetensors':
|
| 292 |
+
try:
|
| 293 |
+
from safetensors.torch import load_file as load_safetensors
|
| 294 |
+
except ImportError as exc:
|
| 295 |
+
raise ImportError(
|
| 296 |
+
'safetensors is required to load .safetensors weights. Install with: pip install safetensors'
|
| 297 |
+
) from exc
|
| 298 |
+
return load_safetensors(str(weight_path), device='cpu')
|
| 299 |
+
|
| 300 |
+
checkpoint = torch.load(str(weight_path), map_location='cpu')
|
| 301 |
+
if isinstance(checkpoint, dict) and 'model' in checkpoint and isinstance(checkpoint['model'], dict):
|
| 302 |
+
return checkpoint['model']
|
| 303 |
+
if isinstance(checkpoint, dict) and checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()):
|
| 304 |
+
return checkpoint
|
| 305 |
+
raise ValueError(
|
| 306 |
+
'Unsupported checkpoint format. Expected .safetensors or .pth containing a model state_dict.'
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@torch.no_grad()
def infer_pet_points(model, frame_bgr, device, upper_bound):
    """Run PET on one BGR frame and return head points in original coords.

    The frame is converted to RGB, optionally downscaled via
    ``resize_for_eval``, normalized, and batched into a NestedTensor.
    The model's ``pred_points`` are treated as normalized (row, col)
    pairs: column 0 is scaled by the padded tensor height and clipped to
    the resized frame height, column 1 likewise for width. Points are
    then mapped back to the original resolution by dividing out the
    resize scale and re-clipped to the original frame bounds.

    Returns:
        A ``(N, 2)`` float array of ``(x, y)`` points in original-image
        pixel coordinates (note the swap from the model's (row, col)
        order), and the resize scale that was applied.
    """
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    resized_rgb, scale = resize_for_eval(frame_rgb, upper_bound)
    resized_h, resized_w = resized_rgb.shape[:2]

    img = Image.fromarray(resized_rgb)
    img = PET_TRANSFORM(img)
    samples = utils.nested_tensor_from_tensor_list([img]).to(device)
    # img_h/img_w are the padded batch-tensor dims, which may exceed the
    # resized frame dims; that is why a clip to resized_h/resized_w follows.
    img_h, img_w = samples.tensors.shape[-2:]

    # NOTE(review): assumes model(samples, test=True) already applies the
    # score threshold and returns only kept points — confirm against the
    # PET repository's test-mode forward.
    outputs = model(samples, test=True)
    outputs_points = outputs['pred_points']
    if outputs_points.dim() == 3:
        outputs_points = outputs_points[0]
    pred_points = outputs_points.detach().cpu().numpy()

    if pred_points.size == 0:
        return np.zeros((0, 2), dtype=np.float32), scale

    # De-normalize: column 0 -> pixels along height (y), column 1 -> width (x).
    pred_points[:, 0] *= float(img_h)
    pred_points[:, 1] *= float(img_w)

    # Clip to the actual resized frame (padding may extend past it).
    pred_points[:, 0] = np.clip(pred_points[:, 0], 0.0, float(resized_h - 1))
    pred_points[:, 1] = np.clip(pred_points[:, 1], 0.0, float(resized_w - 1))

    # Undo the evaluation resize to get original-resolution coordinates.
    if scale != 1.0:
        pred_points = pred_points / float(scale)

    orig_h, orig_w = frame_bgr.shape[:2]
    pred_points[:, 0] = np.clip(pred_points[:, 0], 0.0, float(orig_h - 1))
    pred_points[:, 1] = np.clip(pred_points[:, 1], 0.0, float(orig_w - 1))

    # Swap (y, x) -> (x, y) for downstream drawing/JSON output.
    points_xy = np.stack([pred_points[:, 1], pred_points[:, 0]], axis=1)
    return points_xy, scale
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def main(args) -> None:
    """Run single-image PET inference end to end.

    Builds the model from ``args``, loads the released weights, predicts
    head points for ``args.image_path``, prints the count, and optionally
    writes a JSON result (``--output_json``) and an annotated panel image
    (``--output_image``).

    Raises:
        ValueError: when the input image cannot be read.
    """
    device = resolve_device(args.device)

    model, _ = build_model(args)
    model.to(device)
    model.eval()

    # strict=True so any architecture/weights mismatch fails loudly.
    state_dict = _load_state_dict(Path(args.resume))
    model.load_state_dict(state_dict, strict=True)

    image_path = Path(args.image_path)
    frame_bgr = cv2.imread(str(image_path))
    if frame_bgr is None:
        raise ValueError(f'Failed to read image: {image_path}')

    points_xy, scale = infer_pet_points(model, frame_bgr, device, args.upper_bound)
    count = int(points_xy.shape[0]) if points_xy.size > 0 else 0

    result = {
        'image': str(image_path),
        'count': count,
        'points_xy': points_xy.tolist(),
        'scale': scale,
    }

    print(f'image: {result["image"]}')
    print(f'predicted_count: {result["count"]}')

    if args.output_json:
        output_json = Path(args.output_json)
        output_json.parent.mkdir(parents=True, exist_ok=True)
        output_json.write_text(json.dumps(result, indent=2))
        print(f'json_saved_to: {output_json}')

    if args.output_image:
        output_image = Path(args.output_image)
        output_image.parent.mkdir(parents=True, exist_ok=True)

        panel = annotate_panel(
            frame_bgr,
            points_xy,
            f'{args.title_text} Count : {count}',
            parse_color(args.point_color),
            args.radius,
            load_font(font_size=args.panel_font_size, bold=False),
            text_color=(0, 0, 0),
            title_bg=(255, 255, 255),
            target_long_side=args.panel_long_side,
            pad=args.panel_pad,
        )
        panel.save(str(output_image))
        print(f'annotated_image_saved_to: {output_image}')
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
if __name__ == '__main__':
    # Compose the CLI from the shared argument parser and run inference.
    cli_parser = argparse.ArgumentParser(
        'PET single-image inference',
        parents=[get_args_parser()],
    )
    main(cli_parser.parse_args())
|