Commit ·
29cb4d8
1
Parent(s): 7de983f
First commit
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- .gitignore +1 -0
- LICENSE.md +66 -0
- README.md +118 -4
- assets/overview.png +3 -0
- assets/pikachu.png +3 -0
- assets/pikachu_seg.png +3 -0
- assets/qualitatives.png +3 -0
- assets/qualitatives/cityscapes/1_clipdinoiser.png +3 -0
- assets/qualitatives/cityscapes/1_freeda.png +3 -0
- assets/qualitatives/cityscapes/1_gt.png +3 -0
- assets/qualitatives/cityscapes/1_image.png +3 -0
- assets/qualitatives/cityscapes/1_proxyclip.png +3 -0
- assets/qualitatives/cityscapes/1_talk2dino.png +3 -0
- assets/qualitatives/cityscapes/1r_clipdinoiser.png +3 -0
- assets/qualitatives/cityscapes/1r_freeda.png +3 -0
- assets/qualitatives/cityscapes/1r_gt.png +3 -0
- assets/qualitatives/cityscapes/1r_image.png +3 -0
- assets/qualitatives/cityscapes/1r_proxyclip.png +3 -0
- assets/qualitatives/cityscapes/1r_talk2dino.png +3 -0
- assets/qualitatives/context/1r_clipdinoiser.png +3 -0
- assets/qualitatives/context/1r_freeda.png +3 -0
- assets/qualitatives/context/1r_gt.png +3 -0
- assets/qualitatives/context/1r_img.png +3 -0
- assets/qualitatives/context/1r_proxy.png +3 -0
- assets/qualitatives/context/1r_talk2dino.png +3 -0
- assets/qualitatives/object/2r_clipdinoiser.png +3 -0
- assets/qualitatives/object/2r_freeda.png +3 -0
- assets/qualitatives/object/2r_gt.png +3 -0
- assets/qualitatives/object/2r_img.png +3 -0
- assets/qualitatives/object/2r_proxy.png +3 -0
- assets/qualitatives/object/2r_talk2dino.png +3 -0
- assets/qualitatives/voc/1_clipdinoiser.png +3 -0
- assets/qualitatives/voc/1_freeda.png +3 -0
- assets/qualitatives/voc/1_gt.png +3 -0
- assets/qualitatives/voc/1_img.jpg +0 -0
- assets/qualitatives/voc/1_proxy.png +3 -0
- assets/qualitatives/voc/1_talk2dino.png +3 -0
- assets/qualitatives/voc/2_clipdinoiser.png +3 -0
- assets/qualitatives/voc/2_freeda.png +3 -0
- assets/qualitatives/voc/2_gt.png +3 -0
- assets/qualitatives/voc/2_img.jpg +0 -0
- assets/qualitatives/voc/2_proxy.png +3 -0
- assets/qualitatives/voc/2_talk2dino.png +3 -0
- config.json +6 -0
- configuration_talk2dino.py +49 -0
- dinotext.py +399 -0
- hf_demo.ipynb +0 -0
- hooks.py +52 -0
- masker.py +246 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
LICENSE.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DINOv3 License
|
| 2 |
+
|
| 3 |
+
*Last Updated: August 19, 2025*
|
| 4 |
+
|
| 5 |
+
**“Agreement”** means the terms and conditions for use, reproduction, distribution and modification of the DINO Materials set forth herein.
|
| 6 |
+
|
| 7 |
+
**“DINO Materials”** means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 8 |
+
|
| 9 |
+
**“Documentation”** means the specifications, manuals and documentation accompanying
|
| 10 |
+
DINO Materials distributed by Meta.
|
| 11 |
+
|
| 12 |
+
**“Licensee”** or **“you”** means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 13 |
+
|
| 14 |
+
**“Meta”** or **“we”** means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 15 |
+
|
| 16 |
+
**“Sanctions”** means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
|
| 17 |
+
|
| 18 |
+
**“Trade Controls”** means any of the following: Sanctions and applicable export and import controls.
|
| 19 |
+
|
| 20 |
+
By clicking “I Accept” below or by using or distributing any portion or element of the DINO Materials, you agree to be bound by this Agreement.
|
| 21 |
+
|
| 22 |
+
## 1. License Rights and Redistribution.
|
| 23 |
+
|
| 24 |
+
a. <ins>Grant of Rights</ins>. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the DINO Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DINO Materials.
|
| 25 |
+
|
| 26 |
+
b. <ins>Redistribution and Use</ins>.
|
| 27 |
+
|
| 28 |
+
i. Distribution of DINO Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the DINO Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such DINO Materials.
|
| 29 |
+
|
| 30 |
+
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with DINO Materials, you must acknowledge the use of DINO Materials in your publication.
|
| 31 |
+
|
| 32 |
+
iii. Your use of the DINO Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
|
| 33 |
+
|
| 34 |
+
iv. Your use of the DINO Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the DINO Materials.
|
| 35 |
+
|
| 36 |
+
v. You are not the target of Trade Controls and your use of DINO Materials must comply with Trade Controls. You agree not to use, or permit others to use, DINO Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
|
| 37 |
+
|
| 38 |
+
## 2. User Support.
|
| 39 |
+
|
| 40 |
+
Your use of the DINO Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the DINO Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 41 |
+
|
| 42 |
+
## 3. Disclaimer of Warranty.
|
| 43 |
+
|
| 44 |
+
UNLESS REQUIRED BY APPLICABLE LAW, THE DINO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DINO MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DINO MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 45 |
+
|
| 46 |
+
## 4. Limitation of Liability.
|
| 47 |
+
|
| 48 |
+
IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 49 |
+
|
| 50 |
+
## 5. Intellectual Property.
|
| 51 |
+
|
| 52 |
+
a. Subject to Meta’s ownership of DINO Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the DINO Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 53 |
+
|
| 54 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DINO Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the DINO Materials.
|
| 55 |
+
|
| 56 |
+
## 6. Term and Termination.
|
| 57 |
+
|
| 58 |
+
The term of this Agreement will commence upon your acceptance of this Agreement or access to the DINO Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DINO Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
| 59 |
+
|
| 60 |
+
## 7. Governing Law and Jurisdiction.
|
| 61 |
+
|
| 62 |
+
This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 63 |
+
|
| 64 |
+
## 8. Modifications and Amendments.
|
| 65 |
+
|
| 66 |
+
Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the DINO Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
README.md
CHANGED
|
@@ -1,10 +1,124 @@
|
|
| 1 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
tags:
|
| 3 |
- model_hub_mixin
|
| 4 |
- pytorch_model_hub_mixin
|
|
|
|
|
|
|
|
|
|
| 5 |
---
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
license: other
|
| 3 |
+
license_name: dinov3-license
|
| 4 |
+
pipeline_tag: image-segmentation
|
| 5 |
+
library_name: Pytorch
|
| 6 |
tags:
|
| 7 |
- model_hub_mixin
|
| 8 |
- pytorch_model_hub_mixin
|
| 9 |
+
- DINOv3
|
| 10 |
+
- CLIP
|
| 11 |
+
- open-vocabulary segmentation
|
| 12 |
---
|
| 13 |
|
| 14 |
+
<div align="center">
|
| 15 |
+
<h1>
|
| 16 |
+
Talking to DINO: Bridging Self-Supervised Vision Backbones with Language for Open-Vocabulary Segmentation (ICCV 2025)
|
| 17 |
+
</h1>
|
| 18 |
+
|
| 19 |
+
<h3>
|
| 20 |
+
<a href="https://www.linkedin.com/in/luca-barsellotti/">Luca Barsellotti*</a> 
|
| 21 |
+
<a href="https://www.linkedin.com/in/lorenzo-bianchi-893bb225a/">Lorenzo Bianchi*</a> 
|
| 22 |
+
<a href="https://www.linkedin.com/in/nicola-messina-a33848164/">Nicola Messina</a> 
|
| 23 |
+
<a href="https://www.linkedin.com/in/fabio-carrara-b28a2b111/">Fabio Carrara</a> 
|
| 24 |
+
<a href="https://aimagelab.ing.unimore.it/imagelab/person.asp?idpersona=90">Marcella Cornia</a> 
|
| 25 |
+
<a href="https://www.lorenzobaraldi.com/">Lorenzo Baraldi</a> 
|
| 26 |
+
<a href="https://fabriziofalchi.it">Fabrizio Falchi</a> 
|
| 27 |
+
<a href="https://www.linkedin.com/in/rita-cucchiara-a4653a13/">Rita Cucchiara</a>
|
| 28 |
+
</h3>
|
| 29 |
+
|
| 30 |
+
[Project Page](https://lorebianchi98.github.io/Talk2DINO/) | [Paper](http://arxiv.org/abs/2411.19331) | [Code](https://github.com/lorebianchi98/Talk2DINO)
|
| 31 |
+
|
| 32 |
+
</div>
|
| 33 |
+
|
| 34 |
+
<div align="center">
|
| 35 |
+
<figure>
|
| 36 |
+
<img alt="Overview of Talk2DINO" src="./assets/overview.png" width="90%">
|
| 37 |
+
</figure>
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
## About
|
| 41 |
+
Open-Vocabulary Segmentation (OVS) aims at segmenting images from free-form textual concepts without predefined training classes. While existing vision-language models such as CLIP can generate segmentation masks by leveraging coarse spatial information from Vision Transformers, they face challenges in spatial localization due to their global alignment of image and text features. Conversely, self-supervised visual models like DINO excel in fine-grained visual encoding but lack integration with language. To bridge this gap, we present Talk2DINO, a novel hybrid approach that combines the spatial accuracy of DINOv2 with the language understanding of CLIP. Our approach aligns the textual embeddings of CLIP to the patch-level features of DINOv2 through a learned mapping function without the need to fine-tune the underlying backbones. At training time, we exploit the attention maps of DINOv2 to selectively align local visual patches with textual embeddings. We show that the powerful semantic and localization abilities of Talk2DINO can enhance the segmentation process, resulting in more natural and less noisy segmentations, and that our approach can also effectively distinguish foreground objects from the background. Experimental results demonstrate that Talk2DINO achieves state-of-the-art performance across several unsupervised OVS benchmarks.
|
| 42 |
+
|
| 43 |
+
## Sample Usage
|
| 44 |
+
|
| 45 |
+
### Mapping CLIP Text Embeddings to DINOv3 space with Talk2DINO
|
| 46 |
+
We can use Talk2DINO to map CLIP text embeddings into the DINOv3 patch embedding space.
|
| 47 |
+
```python
|
| 48 |
+
import torch
from transformers import AutoModel
|
| 49 |
+
from torchvision.io import read_image
|
| 50 |
+
|
| 51 |
+
# Device setup
|
| 52 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 53 |
+
|
| 54 |
+
# Model Loading
|
| 55 |
+
model = AutoModel.from_pretrained("lorebianchi98/Talk2DINO_v3-ViTL").to(device).eval()
|
| 56 |
+
|
| 57 |
+
# Embedding generation
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
text_embed = model.encode_text("a pikachu")
|
| 60 |
+
image = read_image("assets/pikachu.png")  # load the input image as a tensor
    image_embed = model.encode_image(image)
|
| 61 |
+
|
| 62 |
+
# normalize the features to perform cosine similarity
|
| 63 |
+
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
|
| 64 |
+
image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
|
| 65 |
+
|
| 66 |
+
similarity = (image_embed @ text_embed.T).squeeze(0, -1).cpu().numpy()
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Demo
|
| 70 |
+
In `hf_demo.ipynb` we provide a simple example on how to use Talk2DINO for inference on a given image with custom textual categories.
|
| 71 |
+
Result:
|
| 72 |
+
<div align="center">
|
| 73 |
+
<table><tr><td><figure>
|
| 74 |
+
<img alt="" src="./assets/pikachu.png" width=300>
|
| 75 |
+
</figure></td><td><figure>
|
| 76 |
+
<img alt="" src="./assets/pikachu_seg.png" width=300>
|
| 77 |
+
</figure></td></tr></table>
|
| 78 |
+
</div>
|
| 79 |
+
|
| 80 |
+
## Installation
|
| 81 |
+
|
| 82 |
+
To use the **Hugging Face interface** for inference:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
# Clone the repository
|
| 86 |
+
git clone https://huggingface.co/lorebianchi98/Talk2DINO-ViTB
|
| 87 |
+
cd Talk2DINO-ViTB
|
| 88 |
+
|
| 89 |
+
# Install dependencies
|
| 90 |
+
pip install -r requirements.txt
|
| 91 |
+
|
| 92 |
+
# Install PyTorch and torchvision with the appropriate CUDA version
|
| 93 |
+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
For the **full MMCV interface** to perform evaluation on segmentation benchmarks, please refer to the [original Talk2DINO repository](https://github.com/lorebianchi98/Talk2DINO).
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
<details>
|
| 101 |
+
<summary>Qualitative Results</summary>
|
| 102 |
+
|
| 103 |
+
| **Image** | **Ground Truth** | **FreeDA** | **ProxyCLIP** | **CLIP-DINOiser** | **Ours (Talk2DINO)** |
|
| 104 |
+
|-----------|------------------|------------|---------------|-------------------|------------------|
|
| 105 |
+
|  |  |  |  |  |  |
|
| 106 |
+
|  |  |  |  |  |  |
|
| 107 |
+
|  |  |  |  |  |  |
|
| 108 |
+
|  |  |  |  |  |  |
|
| 109 |
+
</details>
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
## Reference
|
| 113 |
+
If you found this code useful, please cite the following paper:
|
| 114 |
+
```
|
| 115 |
+
@misc{barsellotti2024talkingdinobridgingselfsupervised,
|
| 116 |
+
title={Talking to DINO: Bridging Self-Supervised Vision Backbones with Language for Open-Vocabulary Segmentation},
|
| 117 |
+
author={Luca Barsellotti and Lorenzo Bianchi and Nicola Messina and Fabio Carrara and Marcella Cornia and Lorenzo Baraldi and Fabrizio Falchi and Rita Cucchiara},
|
| 118 |
+
year={2024},
|
| 119 |
+
eprint={2411.19331},
|
| 120 |
+
archivePrefix={arXiv},
|
| 121 |
+
primaryClass={cs.CV},
|
| 122 |
+
url={https://arxiv.org/abs/2411.19331},
|
| 123 |
+
}
|
| 124 |
+
```
|
assets/overview.png
ADDED
|
Git LFS Details
|
assets/pikachu.png
ADDED
|
Git LFS Details
|
assets/pikachu_seg.png
ADDED
|
Git LFS Details
|
assets/qualitatives.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1_clipdinoiser.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1_freeda.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1_gt.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1_image.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1_proxyclip.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1_talk2dino.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1r_clipdinoiser.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1r_freeda.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1r_gt.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1r_image.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1r_proxyclip.png
ADDED
|
Git LFS Details
|
assets/qualitatives/cityscapes/1r_talk2dino.png
ADDED
|
Git LFS Details
|
assets/qualitatives/context/1r_clipdinoiser.png
ADDED
|
Git LFS Details
|
assets/qualitatives/context/1r_freeda.png
ADDED
|
Git LFS Details
|
assets/qualitatives/context/1r_gt.png
ADDED
|
Git LFS Details
|
assets/qualitatives/context/1r_img.png
ADDED
|
Git LFS Details
|
assets/qualitatives/context/1r_proxy.png
ADDED
|
Git LFS Details
|
assets/qualitatives/context/1r_talk2dino.png
ADDED
|
Git LFS Details
|
assets/qualitatives/object/2r_clipdinoiser.png
ADDED
|
Git LFS Details
|
assets/qualitatives/object/2r_freeda.png
ADDED
|
Git LFS Details
|
assets/qualitatives/object/2r_gt.png
ADDED
|
Git LFS Details
|
assets/qualitatives/object/2r_img.png
ADDED
|
Git LFS Details
|
assets/qualitatives/object/2r_proxy.png
ADDED
|
Git LFS Details
|
assets/qualitatives/object/2r_talk2dino.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/1_clipdinoiser.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/1_freeda.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/1_gt.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/1_img.jpg
ADDED
|
assets/qualitatives/voc/1_proxy.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/1_talk2dino.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/2_clipdinoiser.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/2_freeda.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/2_gt.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/2_img.jpg
ADDED
|
assets/qualitatives/voc/2_proxy.png
ADDED
|
Git LFS Details
|
assets/qualitatives/voc/2_talk2dino.png
ADDED
|
Git LFS Details
|
config.json
CHANGED
|
@@ -1,4 +1,10 @@
|
|
| 1 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"avg_self_attn_token": false,
|
| 3 |
"clip_model_name": "ViT-B/16",
|
| 4 |
"disentangled_self_attn_token": true,
|
|
|
|
| 1 |
{
|
| 2 |
+
"architectures": ["Talk2DINO"],
|
| 3 |
+
"model_type": "talk2dino",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoConfig": "configuration_talk2dino.Talk2DINOConfig",
|
| 6 |
+
"AutoModel": "modeling_talk2dino.Talk2DINO"
|
| 7 |
+
},
|
| 8 |
"avg_self_attn_token": false,
|
| 9 |
"clip_model_name": "ViT-B/16",
|
| 10 |
"disentangled_self_attn_token": true,
|
configuration_talk2dino.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
class Talk2DINOConfig(PretrainedConfig):
    """HF configuration for Talk2DINO.

    Mirrors the keyword arguments of the DINOText module (visual backbone
    name, CLIP/BERT text-encoder name, projection-layer variant,
    self-attention-token options, ...) so the model can be instantiated
    through ``AutoConfig`` / ``AutoModel``. Every constructor argument is
    stored verbatim as an attribute of the same name.
    """

    model_type = "talk2dino"

    def __init__(
        self,
        avg_self_attn_token=False,
        clip_model_name="ViT-B/16",
        disentangled_self_attn_token=True,
        is_eval=True,
        keep_cls=False,
        keep_end_seq=False,
        loss=None,
        model_name="dinov2_vitb14_reg",
        pre_trained=True,
        proj_class="vitb_mlp_infonce",
        proj_name="vitb_mlp_infonce",
        proj_model="ProjectionLayer",
        resize_dim=518,
        type="DINOText",  # NOTE: shadows the builtin `type` locally; name kept for config-file compatibility
        unfreeze_last_image_layer=False,
        unfreeze_last_text_layer=False,
        use_avg_text_token=False,
        with_bg_clean=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Copy every hyper-parameter onto the config object verbatim.
        hyper_params = dict(
            avg_self_attn_token=avg_self_attn_token,
            clip_model_name=clip_model_name,
            disentangled_self_attn_token=disentangled_self_attn_token,
            is_eval=is_eval,
            keep_cls=keep_cls,
            keep_end_seq=keep_end_seq,
            loss=loss,
            model_name=model_name,
            pre_trained=pre_trained,
            proj_class=proj_class,
            proj_model=proj_model,
            proj_name=proj_name,
            resize_dim=resize_dim,
            type=type,
            unfreeze_last_image_layer=unfreeze_last_image_layer,
            unfreeze_last_text_layer=unfreeze_last_text_layer,
            use_avg_text_token=use_avg_text_token,
            with_bg_clean=with_bg_clean,
        )
        for attr_name, attr_value in hyper_params.items():
            setattr(self, attr_name, attr_value)
|
dinotext.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
from math import sqrt
|
| 5 |
+
import re
|
| 6 |
+
import yaml
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import timm
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torchvision
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from transformers import BertModel, AutoTokenizer
|
| 16 |
+
import torchvision.transforms as T
|
| 17 |
+
import clip
|
| 18 |
+
import importlib
|
| 19 |
+
from .us import normalize
|
| 20 |
+
|
| 21 |
+
from .pamr import PAMR
|
| 22 |
+
from .masker import DINOTextMasker
|
| 23 |
+
from .templates import get_template
|
| 24 |
+
|
| 25 |
+
from .model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
|
| 26 |
+
from .hooks import average_text_tokens, get_vit_out, feats
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DINOText(nn.Module):
|
| 32 |
+
|
| 33 |
+
def get_self_attention(self, module, input, output):
    """Forward hook: cache the raw qkv-projection output of the last backbone block.

    The cached tensor is later turned into CLS-to-patch attention by
    ``process_self_attention``.
    """
    # Store under a fixed key so downstream code can fetch it by name.
    self.feats.update(self_attn=output)
|
| 35 |
+
|
| 36 |
+
def get_clip_second_last_dense_out(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    """Forward hook: cache CLIP's second-to-last layer output as float32.

    Bug fix: the original called ``.to(dtype=torch.float32)`` on the cached
    tensor without assigning the result — ``Tensor.to`` is NOT in-place, so
    the stored tensor silently kept its original dtype (typically float16
    for CLIP). Store the converted tensor instead.
    """
    self.feats['clip_second_last_out'] = output.to(dtype=torch.float32)
|
| 39 |
+
|
| 40 |
+
def get_all_out_tokens(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    """Forward hook on ``clip_model.ln_final``: cache ALL output text tokens
    (not just the CLS/EOS token), for averaged-text-token mode."""
    key = 'clip_txt_out_tokens'
    self.feats[key] = output
|
| 42 |
+
|
| 43 |
+
def __init__(
    self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
    unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
):
    """Assemble Talk2DINO: a frozen ViT visual backbone, a frozen text
    encoder (CLIP or BERT), a learned text->patch projection, and a masker.

    Args:
        model_name: visual backbone id; 'dinov2*' loads via torch.hub,
            otherwise mae/sam/clip/dino variants load via timm.
        resize_dim: square side images are resized to before the backbone.
        clip_model_name: text encoder; contains 'bert' -> HF BertModel,
            otherwise an OpenAI CLIP model name.
        proj_class: selects the projection config ('vitb_mlp_infonce' /
            'vitl_mlp_infonce').
        avg_self_attn_token / disentangled_self_attn_token: how patch tokens
            are pooled with backbone self-attention (see encode_image).
        unfreeze_last_text_layer: unfreeze CLIP's last transformer block,
            final LayerNorm and text projection for training.
        is_eval: also registers the self-attention hook in eval-only setups.
        use_avg_text_token: hook ln_final to average all text tokens
            instead of using the CLS/EOS token only.
        proj_name, proj_model, loss, pre_trained, unfreeze_last_image_layer,
            keep_cls, keep_end_seq, with_bg_clean: stored/accepted for config
            compatibility; several are unused in this constructor.
    """
    nn.Module.__init__(self)
    self.feats = {}  # filled by forward hooks (self-attn qkv, CLIP token outputs)
    self.model_name = model_name
    # loading the model

    if 'dinov2' in model_name:
        # NOTE(review): the condition is always True inside this branch, so the
        # ternary always picks 'facebookresearch/dinov2'.
        self.model_family = 'facebookresearch/dinov2' if 'dinov2' in model_name else 'facebookresearch/dino:main'
        self.model = torch.hub.load(self.model_family, model_name)

    elif 'mae' in model_name or 'sam' in model_name or 'clip' in model_name or 'dino' in model_name:
        self.model = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
            img_size=resize_dim
        )

        if 'sam' in model_name:
            # SAM exposes no is_training path; grab the last block's output via hook.
            self.model.blocks[-1].register_forward_hook(get_vit_out)
    else:
        raise Exception("Unknown ViT model")
    # self.model.eval()
    # ImageNet normalization for most backbones; CLIP uses its own statistics.
    mean = (0.485, 0.456, 0.406) if not 'clip' in model_name else (0.4815, 0.4578, 0.4082)
    std = (0.229, 0.224, 0.225) if not 'clip' in model_name else (0.2686, 0.2613, 0.2758)
    self.image_transforms = T.Compose([
        T.Resize((resize_dim, resize_dim)),
        lambda x: T.ToTensor()(x) if not isinstance(x, torch.Tensor) else x / 255.0, # ensure tensor
        T.Normalize(mean, std),
    ])

    # NOTE(review): bare expression — has no effect; likely leftover from a
    # removed `.eval()`/`.to(...)` call.
    self.model
    # Visual backbone stays frozen: only the projection layer is trained.
    self.model.requires_grad_(False)

    self.clip_model_name = clip_model_name
    if 'bert' in self.clip_model_name:
        self.clip_model = BertModel.from_pretrained(self.clip_model_name, output_hidden_states = False)
        # load the corresponding wordtokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
    else:
        # NOTE(review): loads CLIP on the 'meta' device (weights are
        # placeholders) — presumably materialized/loaded later; confirm.
        self.clip_model, _ = clip.load(clip_model_name, device='meta')
    self.clip_model.eval()
    self.clip_model.requires_grad_(False)
    if unfreeze_last_text_layer:
        # Allow fine-tuning of the last text block, final LN and projection.
        for param in self.clip_model.transformer.resblocks[-1].parameters():
            param.requires_grad = True
        for param in self.clip_model.ln_final.parameters():
            param.requires_grad = True
        self.clip_model.text_projection.requires_grad = True
    # CLIP-style learnable temperature, initialized to 1/0.07.
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))


    if 'vitb_mlp_infonce' in proj_class:
        config = {
            'act': 'tanh', # None, tanh, relu or sigmoid
            'hidden_layer': True,
            'dino_embed_dim': 768
        }
    elif 'vitl_mlp_infonce' in proj_class:
        config = {
            'act': 'tanh', # None, tanh, relu or sigmoid
            'hidden_layer': True,
            'dino_embed_dim': 1024
        }
    # NOTE(review): `config` is unbound if proj_class matches neither branch
    # — from_config below would raise NameError; consider an explicit error.

    ProjClass = ProjectionLayer
    self.proj = ProjClass.from_config(config)

    self.masker = DINOTextMasker(similarity_type="cosine")
    self.masker = self.masker.eval()

    # PAMR mask refinement is created lazily elsewhere (None until needed).
    self.pamr = None

    self.avg_self_attn_token = avg_self_attn_token
    self.disentangled_self_attn_token = disentangled_self_attn_token

    if self.avg_self_attn_token or self.disentangled_self_attn_token or is_eval:
        # Capture the raw qkv of the last attention block to rebuild CLS->patch attention.
        self.model.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
    # Global (non-patch) tokens: CLS + 4 registers for *_reg / dinov3, else CLS only.
    self.num_global_tokens = 5 if 'reg' in model_name or 'dinov3' in model_name else 1
    if 'sam' in self.model_name:
        self.num_global_tokens = 0
    if 'dinov3' in self.model_name:
        # dinov3 checkpoints do not expose num_heads; hard-code per variant.
        if 'vit_base' in self.model_name:
            self.num_attn_heads = 12
        elif 'vit_large' in self.model_name:
            self.num_attn_heads = 16
        else:
            raise Exception("Unknown dinov3 model")
    else:
        self.num_attn_heads = self.model.num_heads
    # NOTE(review): fixed attention scale 1/sqrt(64) — assumes head_dim == 64; confirm.
    self.scale = 0.125

    self.use_avg_text_token = use_avg_text_token
    if self.use_avg_text_token:
        self.feats = {}
        # in this case we register a forward hook with the aim of getting all the tokens and not only the cls
        self.clip_model.ln_final.register_forward_hook(self.get_all_out_tokens)
    self.keep_cls = keep_cls
    self.keep_end_seq = keep_end_seq

    self.with_bg_clean = with_bg_clean
| 149 |
+
def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
|
| 150 |
+
qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
|
| 151 |
+
q, k, v = qkv[0] * scale, qkv[1], qkv[2]
|
| 152 |
+
attn = q @ k.transpose(-2, -1)
|
| 153 |
+
self_attn_maps = attn[:, : , 0, num_global_tokens:]
|
| 154 |
+
self_attn = self_attn_maps.mean(dim=1)
|
| 155 |
+
self_attn = self_attn.softmax(dim=-1)
|
| 156 |
+
if ret_self_attn_maps:
|
| 157 |
+
return self_attn, self_attn_maps
|
| 158 |
+
else:
|
| 159 |
+
return self_attn
|
| 160 |
+
|
| 161 |
+
def encode_text(self, tokenized_texts):
    """Encode tokenized texts with the CLIP text encoder.

    When the projection head is a CLIPLastLayer, the result is the hooked
    second-to-last CLIP layer output (cast to float32) rather than the
    encoder's final embedding.
    """
    if type(self.proj) != CLIPLastLayer:
        return self.clip_model.encode_text(tokenized_texts)
    # running the encoder fills self.feats['clip_second_last_out'] via a hook
    self.clip_model.encode_text(tokenized_texts)
    return self.feats['clip_second_last_out'].to(dtype=torch.float32)
|
| 169 |
+
|
| 170 |
+
def encode_image(self, images):
    """Encode a batch of images with the DINO backbone.

    Returns:
        (x, self_attn_maps): x is either the backbone's raw output or an
        attention-weighted patch embedding (see the two branches below);
        self_attn_maps is None unless a self-attention token mode is active.
    """
    batch_size, _, _, _ = images.shape
    self_attn_maps = None
    # is_training=True presumably makes the backbone return the full output
    # dict (with 'x_norm_patchtokens') instead of a pooled embedding — TODO confirm
    x = self.model(images, is_training=(self.avg_self_attn_token or self.disentangled_self_attn_token))
    batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
    # total token count seen by the qkv hook includes the global (CLS/register) tokens
    num_tokens = num_tokens + self.num_global_tokens
    if self.avg_self_attn_token or self.disentangled_self_attn_token:
        # self.feats['self_attn'] is populated by the qkv forward hook registered in __init__
        self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
        if self.avg_self_attn_token:
            # single embedding: patch tokens weighted by head-averaged CLS attention
            x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
        elif self.disentangled_self_attn_token:
            # one embedding per head, each weighted by that head's CLS attention
            self_attn_maps = self_attn_maps.softmax(dim=-1)
            x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)

    return x, self_attn_maps
|
| 185 |
+
|
| 186 |
+
def forward(self, image, text, return_logit_scale=False):
    """Embed an image/text batch and project both into the shared space.

    Returns (txt_embed, img_embed), plus self.logit_scale when requested.
    """
    # text encoding never needs gradients here
    with torch.no_grad():
        txt_embed = self.encode_text(text)

    img_embed, self_attn_maps = self.encode_image(image)

    proj_kwargs = dict(ret_embeds=True, self_attn_maps=self_attn_maps)
    if type(self.proj) == CLIPLastLayer:
        # this projection head additionally needs the EOS-token positions
        proj_kwargs['text_argmax'] = text.argmax(dim=-1)
    img_embed, txt_embed = self.proj(img_embed, txt_embed, **proj_kwargs)

    if return_logit_scale:
        return txt_embed, img_embed, self.logit_scale
    return txt_embed, img_embed
|
| 201 |
+
|
| 202 |
+
def compute_loss(self, image, text, cosine=True, ret_similarity_matrix=True):
    """Compute the contrastive loss between a batch of images and texts.

    Args:
        image: batch of input images.
        text: batch of tokenized texts (paired with the images).
        cosine: if True, L2-normalize both embeddings so the similarity is
            cosine similarity rather than a raw dot product.
        ret_similarity_matrix: if True, feed the full [B, B] similarity
            matrix to the loss; otherwise only its diagonal (matched pairs).

    Returns:
        dict with key 'contrastive_loss'.
    """
    ret = {}
    # BUG FIX: the original body used img_embed/txt_embed without ever
    # computing them (guaranteed NameError); obtain them via forward().
    txt_embed, img_embed = self.forward(image, text)
    if cosine:
        img_embed = F.normalize(img_embed, p=2, dim=1)
        txt_embed = F.normalize(txt_embed, p=2, dim=1)
    sim = img_embed @ txt_embed.transpose(1, 0)
    if not ret_similarity_matrix:
        # diagonal() replaces the original CPU torch.eye boolean mask,
        # which would fail when `sim` lives on a CUDA device
        sim = sim.diagonal()

    ret['contrastive_loss'] = self.contrastive_loss.compute_contrastive_loss(sim)

    return ret
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@torch.no_grad()
def build_dataset_class_tokens(self, template_set, classnames):
    """Tokenize every (template, classname) prompt combination.

    Returns:
        [N, T, L] tensor — N classes, T templates, L sequence length.
    """
    templates = get_template(template_set)
    use_bert = 'bert' in self.clip_model_name
    tokens = []
    for classname in classnames:
        prompts = [template.format(classname) for template in templates]
        if use_bert:
            tokens.append(self.tokenizer(prompts, return_tensors='pt', padding='max_length')['input_ids'])
        else:
            tokens.append(clip.tokenize(prompts))
    # [N, T, L], N: number of instance, T: number of captions (including ensembled), L: sequence length
    return torch.stack(tokens)
|
| 231 |
+
|
| 232 |
+
@torch.no_grad()
def build_text_embedding(self, text):
    """
    Args:
        text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] text tokens

    Returns:
        text_embs: [NUM_CLASSES, C] L2-normalized class embeddings,
            averaged over the prompt templates.
    """
    text = text.to(next(self.parameters()).device)
    num_classes, num_templates = text.shape[:2]
    # position of the argmax (EOS) token per prompt, required by CLIPLastLayer
    text_argmax = text.argmax(dim=-1)
    text_argmax = rearrange(text_argmax, 'n t -> (n t)', n=num_classes, t=num_templates)
    text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
    # chunked inference for memory limitation
    chunk_size = 32
    N = text.size(0)
    if type(self.proj) == CLIPLastLayer:
        # project the hooked second-to-last CLIP layer output, chunk by chunk
        text_embs = torch.cat([
            self.proj.project_clip_txt(self.encode_text(text[i:i + chunk_size]).permute(1, 0, 2), text_argmax=text_argmax[i:i + chunk_size])
            for i in range(0, N, chunk_size)
        ])
    else:
        if not self.use_avg_text_token:
            # performing classification using CLS textual token
            if 'bert' not in self.clip_model_name:
                text_embs = torch.cat([
                    self.clip_model.encode_text(text[i:i + chunk_size])
                    for i in range(0, N, chunk_size)
                ])
            else:
                # encoding with BERT
                text_embs = []
                for i in range(0, N, chunk_size):
                    outputs = self.clip_model(text[i:i + chunk_size])
                    text_embs.append(outputs['pooler_output'])
                text_embs = torch.cat(text_embs)
        else:
            # using text token average
            text_embs = []
            for i in range(0, N, chunk_size):
                # encode_text fills self.feats['clip_txt_out_tokens'] via the
                # ln_final forward hook registered in __init__
                self.clip_model.encode_text(text[i:i + chunk_size])
                # token id > 0 marks valid (non-padding) positions
                text_embs.append(average_text_tokens(self.feats['clip_txt_out_tokens'] @ self.clip_model.text_projection, text[i:i + chunk_size] > 0, self.keep_cls, self.keep_end_seq))
            text_embs = torch.cat(text_embs)
    # [N, T, C]
    text_embs = rearrange(text_embs, '(n t) c -> n t c', n=num_classes, t=num_templates)
    # [N, C] — prompt-template ensembling by averaging
    text_embs = text_embs.mean(dim=1).float()
    if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
        text_embs = self.proj.project_clip_txt(text_embs)
    text_embs = normalize(text_embs, dim=-1)

    return text_embs
|
| 285 |
+
|
| 286 |
+
def apply_pamr(self, image, mask):
    """Refine a soft mask with PAMR, lazily constructing the module on first use."""
    image = F.interpolate(image, mask.shape[-2:], mode="bilinear", align_corners=True)
    if self.pamr is None:
        # lazy init: 10 refinement iterations over multi-dilation affinity kernels
        self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24])
        self.pamr.eval()
        self.pamr.to(next(self.parameters()).device)

    return self.pamr(image, mask)
|
| 297 |
+
|
| 298 |
+
def compute_padsize(self, H: int, W: int, patch_size: int):
    """Return (left, right, top, bottom) padding making H and W multiples
    of patch_size, splitting any excess as evenly as possible (the extra
    pixel, if odd, goes to the right/bottom side)."""
    def _split(extent: int) -> tuple:
        # padding needed along one axis, split into (first, second) halves
        rem = extent % patch_size
        if not rem:
            return 0, 0
        total = patch_size - rem
        first = total // 2
        return first, total - first

    l, r = _split(W)
    t, b = _split(H)
    return l, r, t, b
|
| 311 |
+
|
| 312 |
+
@torch.no_grad()
def generate_masks(
        self, image, img_metas, text_emb, classnames, text_is_token=False, apply_pamr=False, background_func="weighted_average_sigmoid", lambda_bg=0.2,
        # kp_w=0.3,
):
    """Generate masks for each text embeddings

    Args:
        image [B, 3, H, W]
        text_emb [N, C]: one embedding per class.
        apply_pamr: refine masks with PAMR (in chunks of 30 classes).
        lambda_bg: blending weight for the background-cleaning prior.

    Returns:
        softmask [B, N, H, W]: softmasks for each text embeddings, plus the
        raw similarity map.
    """

    H, W = image.shape[2:]  # original image shape

    # padded image size (currently identical to H, W — no padding applied here)
    pH, pW = image.shape[2:]
    num_classes = text_emb.shape[0]
    batch_size = image.shape[0]

    image = image[:, [2, 1, 0], :, :]  # BGR to RGB
    ori_image = image.clone()

    img_preprocessed = self.image_transforms(image).to(next(self.parameters()).device)
    # select patch tokens, dropping each backbone's global tokens
    if 'dinov2' in self.model_name:
        image_feat = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
    elif 'dinov3' in self.model_name:
        # dinov3: CLS + 4 register tokens occupy the first 5 positions
        image_feat = self.model.forward_features(img_preprocessed)[:, 5:, :]
    elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
        image_feat = self.model.forward_features(img_preprocessed)[:, 1:, :]
    elif 'sam' in self.model_name:
        self.model.forward_features(img_preprocessed)
        # NOTE(review): `feats` is not defined in this method — presumably the
        # module-level hook dict (see hooks.py); verify, otherwise this branch raises NameError
        image_feat = feats['vit_out'].reshape(feats['vit_out'].shape[0], feats['vit_out'].shape[1]**2, feats['vit_out'].shape[-1])  # BS x N_PATCHES x EMBED_DIM

    batch_size, num_tokens, embed_dim = image_feat.shape
    # optionally map DINO features into the shared text space
    if type(self.proj) == VisualProjectionLayer:
        image_feat = self.proj.project_dino(image_feat.float())
    if type(self.proj) == DoubleMLP:
        image_feat = self.proj.project_visual(image_feat.float())
    b, np, c = image_feat.shape
    # assumes a square patch grid — TODO confirm for non-square inputs
    np_h = np_w = int(sqrt(np))
    image_feat = image_feat.reshape(b, np_h, np_w, c).permute(0, 3, 1, 2)

    # per-head CLS attention from the qkv hook, used for background cleaning
    self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens + self.num_global_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
    mask, simmap = self.masker.forward_seg(image_feat, text_emb, hard=False)  # [B, N, H', W']

    if self.with_bg_clean:
        mask = self.similarity_assignment_weighted(mask, image_feat, self_attn_maps, text_emb, lambda_bg)

    # resize
    mask = F.interpolate(mask, (pH, pW), mode='bilinear', align_corners=True)  # [B, N, H, W]

    if apply_pamr:
        # chunk classes to bound PAMR memory usage
        for c in range(0, mask.shape[1], 30):
            mask[:, c:c + 30] = self.apply_pamr(ori_image, mask[:, c:c + 30])

    assert mask.shape[2] == H and mask.shape[3] == W, f"shape mismatch: ({H}, {W}) / {mask.shape}"

    return mask, simmap
|
| 372 |
+
|
| 373 |
+
def similarity_assignment_weighted(self, mask, image_feat, self_attn_maps, text_emb, lambda_bg=0.2):
    """Blend per-class masks with an attention prior built from the
    self-attention heads, weighted by text-to-head similarity.

    Args:
        mask [B, N, H, W]: raw per-class soft masks.
        image_feat [B, C, H, W]: patch embedding map.
        self_attn_maps [B, M, H*W]: per-head CLS attention over patches.
        text_emb [N, C]: class text embeddings.
        lambda_bg: blending weight for the attention prior.
    """
    bs, c, h, w = image_feat.shape
    bs, num_classes, h, w = mask.shape
    bs, num_heads, hw = self_attn_maps.shape
    image_feat = image_feat.reshape(bs, c, hw)
    num_classes, c = text_emb.shape
    # attention-weighted average patch embedding for each head: [B, M, C]
    avg_head_embed = (self_attn_maps.unsqueeze(2) * image_feat.unsqueeze(1)).mean(dim=-1)
    avg_head_embed = avg_head_embed / avg_head_embed.norm(dim=-1, keepdim=True)
    avg_head_embed = avg_head_embed.permute(0, 2, 1)  # [B, C, M]
    head_text_sim = text_emb.unsqueeze(0) @ avg_head_embed  # [B, N, M]
    # softmax over heads: how much each head "belongs" to each class
    head_text_sim = (head_text_sim).softmax(dim=-1)
    head_text_sim_sum = head_text_sim.sum(dim=-1)

    # similarity-weighted average of head attention maps, per class: [B, N, H*W]
    self_attn_maps_repeat = self_attn_maps.unsqueeze(1).repeat(1, num_classes, 1, 1)
    head_text_sim_repeat = head_text_sim.unsqueeze(-1).repeat(1, 1, 1, hw)
    avg_self_attn_per_class = (self_attn_maps_repeat * head_text_sim_repeat).sum(dim=2) / head_text_sim_sum.unsqueeze(-1).repeat(1, 1, hw)
    avg_self_attn_per_class = avg_self_attn_per_class.softmax(dim=-1)

    # rescale the attention prior into the value range of `mask`
    min_self_attn = avg_self_attn_per_class.min().item()
    max_self_attn = avg_self_attn_per_class.max().item()
    max_self_attn = max(max_self_attn, max_self_attn - min_self_attn)
    avg_self_attn_per_class = avg_self_attn_per_class - min_self_attn
    avg_self_attn_per_class = avg_self_attn_per_class / max_self_attn
    avg_self_attn_per_class = avg_self_attn_per_class * (mask.max() - mask.min()) + mask.min()
    # NOTE(review): this reshape drops the batch dimension and is only valid
    # for bs == 1 — confirm callers never pass larger batches
    mask = mask.reshape(num_classes, hw)  # [N, P]
    mask_output = (mask + lambda_bg * avg_self_attn_per_class).reshape(bs, num_classes, h, w) / (1 + lambda_bg)
    return mask_output
|
hf_demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hooks.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
feats = {}
|
| 3 |
+
def get_self_attention(module, inputs, out):
    """Forward hook: stash the attention block's raw qkv output for later use."""
    feats['self_attn'] = out
|
| 5 |
+
|
| 6 |
+
def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
|
| 7 |
+
qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
|
| 8 |
+
q, k, v = qkv[0] * scale, qkv[1], qkv[2]
|
| 9 |
+
attn = q @ k.transpose(-2, -1)
|
| 10 |
+
self_attn_maps = attn[:, : , 0, num_global_tokens:]
|
| 11 |
+
self_attn = self_attn_maps.mean(dim=1)
|
| 12 |
+
self_attn = self_attn.softmax(dim=-1)
|
| 13 |
+
if ret_self_attn_maps:
|
| 14 |
+
return self_attn, self_attn_maps
|
| 15 |
+
else:
|
| 16 |
+
return self_attn
|
| 17 |
+
|
| 18 |
+
def get_vit_out(model: torch.nn.Module, inputs: torch.Tensor, out: torch.Tensor):
    """Forward hook: stash the ViT trunk output (read back for the SAM backbone)."""
    feats['vit_out'] = out
|
| 20 |
+
|
| 21 |
+
def get_second_last_out(model: torch.nn.Module, inputs: torch.Tensor, out: torch.Tensor):
    """Forward hook: stash the second-to-last block's output."""
    feats['second_last_out'] = out
|
| 23 |
+
|
| 24 |
+
def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    # Forward hook: store ALL output tokens (not just CLS) of the CLIP text encoder.
    # NOTE(review): this function is redefined identically further down in this
    # module; the later definition shadows this one — consider removing one copy.
    feats['clip_txt_out_tokens'] = output
|
| 26 |
+
|
| 27 |
+
def get_clip_second_last_dense_out(model: torch.nn.Module, inputs: torch.Tensor, out: torch.Tensor):
    """Forward hook: store CLIP's second-to-last dense output as [batch, seq, dim]
    (the transformer emits LND, so the first two dims are swapped)."""
    batch_first = out.permute(1, 0, 2)
    feats['clip_second_last_out'] = batch_first
|
| 29 |
+
|
| 30 |
+
def get_dinov1_patches(model: torch.nn.Module, inputs: torch.Tensor, out: torch.Tensor):
    """Forward hook: stash DINOv1 patch tokens."""
    feats['dinov1_patches'] = out
|
| 32 |
+
|
| 33 |
+
def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    # Forward hook: store ALL output tokens of the CLIP text encoder.
    # NOTE(review): duplicate of the identical definition earlier in this module
    # (this one wins at import time) — consider removing one copy.
    feats['clip_txt_out_tokens'] = output
|
| 35 |
+
|
| 36 |
+
def average_text_tokens(text_embeddings, mask, keep_cls=False, keep_end_seq=False):
    """Average the valid token embeddings of each sequence.

    Args:
        text_embeddings: [BS, SEQ_LEN, C] token embeddings.
        mask: [BS, SEQ_LEN] boolean validity mask. Assumed True on a
            contiguous prefix of each row (CLS ... EOS), so the last valid
            token sits at index mask.sum(dim=1) - 1 — TODO confirm callers
            always pass such masks.
        keep_cls: include the first (CLS) token in the average.
        keep_end_seq: include the end-of-sequence token in the average.

    Returns:
        [BS, C] mean embedding over the selected tokens.
    """
    # BUG FIX: operate on a copy so the caller's mask is not mutated in place
    # (the original wrote False into the tensor it was handed).
    mask = mask.clone()
    if not keep_end_seq:
        # excluding end of sequence
        mask[torch.arange(mask.shape[0]), mask.sum(dim=1) - 1] = False
    if not keep_cls:
        mask[:, 0] = False  # excluding CLS token

    # NOTE(review): a row with zero remaining valid tokens divides by zero
    # and yields NaN/inf, exactly as the original implementation did.
    masked_embeddings = text_embeddings * mask.unsqueeze(-1)  # [BS, SEQ_LEN, C]

    sum_embeddings = masked_embeddings.sum(dim=1)  # [BS, C]

    valid_elements = mask.sum(dim=1, keepdim=True)  # [BS, 1]

    mean_embeddings = sum_embeddings / valid_elements  # [BS, C]

    return mean_embeddings
|
| 52 |
+
|
masker.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------
|
| 2 |
+
# Talk2DINO
|
| 3 |
+
# ------------------------------------------------------------------------------
|
| 4 |
+
import copy
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from .us import normalize
|
| 12 |
+
from einops import rearrange, repeat
|
| 13 |
+
|
| 14 |
+
# from models.dinotext.gumbel import gumbel_sigmoid
|
| 15 |
+
from .modules import FeatureEncoder
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_model(config):
    """Resolve an OmegaConf config into a plain Python container."""
    resolved = OmegaConf.to_container(config, resolve=True)
    return resolved
|
| 22 |
+
|
| 23 |
+
class Sim2Mask(nn.Module):
    """Map similarity logits to (hard, soft) masks via a learnable affine
    transform followed by a (gumbel) sigmoid."""

    def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        # w and b must be provided (or omitted) together
        assert not ((init_w is None) ^ (init_b is None))
        if learnable:
            # scalar (0-d) learnable scale and bias
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    def forward(self, x, deterministic=False):
        logits = x * self.w + self.b

        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            # NOTE(review): gumbel_sigmoid is not imported here (the import at
            # the top of the module is commented out) — this branch raises
            # NameError; confirm the intended source of gumbel_sigmoid.
            hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MaskerBackbone(nn.Module):
    """Masker image encoder backbone.

    Deep-copies the tail of a CLIP visual transformer: only the resblocks
    from `freeze_idx` onwards are kept (the earlier, frozen blocks are run
    elsewhere and their output is fed in as `x`).
    """
    def __init__(self, clip_visual, freeze_idx):
        super().__init__()
        self.transformer = copy.deepcopy(clip_visual.transformer)
        self.transformer.resblocks = self.transformer.resblocks[freeze_idx:]

        # drop forward hooks carried over by the deepcopy
        for block in self.transformer.resblocks:
            if hasattr(block, "hook_handler"):
                block.hook_handler.remove()

        self.ln_post = copy.deepcopy(clip_visual.ln_post)
        self.proj = copy.deepcopy(clip_visual.proj)

        self.layers = len(self.transformer.resblocks)
        self.patch_size = clip_visual.patch_size

        # without a projection matrix the output width is the transformer width
        self.output_dim = clip_visual.output_dim if self.proj is not None else clip_visual.width

    def forward(self, x, spatial=True, ignore_last_attn=True):
        """Run the kept resblocks on x (CLIP's [L, N, D] layout).

        Args:
            spatial: keep all tokens (True) or only the CLS token (False).
            ignore_last_attn: forwarded to the transformer — presumably a
                patched-CLIP flag altering the last attention; not visible here.
        """
        if self.layers:
            x = self.transformer(x, ignore_last_attn=ignore_last_attn)

        x = x.permute(1, 0, 2)  # LND -> NLD

        if spatial:
            x = self.ln_post(x)
        else:
            x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
|
| 89 |
+
|
| 90 |
+
class MaskerImageFeatureEncoder(FeatureEncoder):
    """Feature encoder that runs the masker backbone plus a decoder head.

    Registers a forward hook (`self.hook`, inherited from FeatureEncoder —
    not visible in this file) on every backbone resblock so intermediate
    features can be collected.
    """
    def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn
        self.patch_size = backbone.patch_size
        self.backbone = backbone
        self.decoder = decoder

        for resblock in self.backbone.transformer.resblocks:
            resblock.hook_handler = resblock.register_forward_hook(self.hook)

    def _encode(self, image, image_feat):
        # `image` is only used to recover the spatial patch-grid size
        H, W = image.shape[-2:]
        h = H // self.patch_size
        w = W // self.patch_size

        x = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn)  # BLC
        # drop the CLS token (index 0) and fold tokens back into a feature map
        x = rearrange(x[:, 1:], "B (H W) C -> B C H W", H=h, W=w)
        x = self.decoder(x)

        return x
|
| 111 |
+
|
| 112 |
+
class Masker(nn.Module):
    """Combines a backbone+decoder image encoder with Sim2Mask to turn
    image/text similarity into segmentation masks."""

    def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn

        # `decoder` arrives as a config dict; its input width must match the backbone
        decoder["C"] = backbone.output_dim
        # NOTE(review): MODELS is not defined/imported in this module —
        # presumably an mm-style registry; verify before relying on this class.
        decoder = MODELS.build(decoder)
        decoder = nn.Sequential(OrderedDict([
            ("decoder", decoder),
            ("image_proj", image_proj)
        ]))

        self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)

        self.sim2mask = Sim2Mask(**sim2mask)

    def forward(self, image, image_feat, text_emb, deterministic=False):
        """Training forward: 1:N matching of each image against the text
        embeddings gathered from all devices (requires torch.distributed
        to be initialized)."""
        B = image.size(0)
        image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [BCHW]

        image_emb_norm = normalize(image_emb, dim=1)
        text_emb_norm = normalize(text_emb, dim=-1)

        H, W = image_emb.shape[2:]
        D = dist.get_world_size()

        # simmap [B, B*D, H, W] where D is #devices
        # NOTE(review): gather_cat is not defined/imported in this module — verify.
        all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
        simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
        mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)

        # mask [B, B*D, H, W] where D is #devices
        # positive global label: index of each image's own caption across devices
        pos_indices = torch.arange(B, dtype=torch.long, device=image_emb.device) + B * dist.get_rank()
        pos_mask = mask[torch.arange(B), pos_indices].unsqueeze(1)  # [B, 1, H, W]

        # boolean matrix selecting every non-matching (image, text) pair
        offdiag = torch.ones(B, B*D, dtype=torch.bool, device=mask.device)
        offdiag[torch.arange(B), pos_indices] = False

        soft_pos_mask = soft_mask[torch.arange(B), pos_indices].unsqueeze(1)
        soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B*D-1, H, W)

        masks = {
            "pos": pos_mask,  # [B, 1, H, W]

            "soft_pos": soft_pos_mask,
            "soft_neg": soft_neg_mask,
            "soft_all": soft_mask,  # [B, N, H, W]
        }

        return masks, image_emb, text_emb, feats

    @torch.no_grad()
    def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image [B, 3, H, W]
            image_feat [L, B, C]: CLIP features
            text_emb [N, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): decide hard or soft returning segmentation mask.
                Note that soft mask is required for proper evaluation

        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W)
        """
        image_emb = self.image_encoder(image, image_feat)  # [BCHW]

        image_emb = normalize(image_emb, dim=1)  # BCHW
        text_emb = normalize(text_emb, dim=-1)  # NC

        simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap
|
| 190 |
+
|
| 191 |
+
class DINOTextMasker(nn.Module):
    """Cosine-similarity masker matching DINO patch features against text
    embeddings; binarization is delegated to DINOTextSim2Mask."""

    def __init__(self, similarity_type="cosine"):
        super().__init__()
        self.sim2mask = DINOTextSim2Mask()
        self.sim2mask = self.sim2mask.eval()
        self.similarity_type = similarity_type

    def forward(self, image, image_feat, text_emb, deterministic=False):
        # training forward intentionally unimplemented: inference-only module
        pass

    @torch.no_grad()
    def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image_feat [B, C, H, W]: patch feature map
            text_emb [N, C]: one embedding per class (the code below unpacks
                two dims; text is expected pre-normalized — the normalize
                call for it is commented out)
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): decide hard or soft returning segmentation mask.
                Note that soft mask is required for proper evaluation

        Return:
            mask [B, N, H, W] and the raw similarity map of the same shape
        """
        b, c, h, w = image_feat.shape
        n, c = text_emb.shape

        if self.similarity_type == "cosine":
            image_feat = normalize(image_feat, dim=1)  # BCHW
            # text_emb = normalize(text_emb, dim=-1)  # NKC
            simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
        else:
            raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class DINOTextSim2Mask(nn.Module):
    """Turn a similarity map into (hard, soft) masks with a plain sigmoid
    (no learnable affine, unlike Sim2Mask)."""

    def __init__(self, gumbel_tau=1.0):
        super().__init__()
        self.gumbel_tau = gumbel_tau

    def forward(self, x, deterministic=False):
        soft_mask = torch.sigmoid(x)
        if not deterministic:
            # gumbel_sigmoid must be resolvable in the importing module's scope
            hard_mask = gumbel_sigmoid(x, hard=True, tau=self.gumbel_tau)
        else:
            # deterministic binarization at probability 0.5
            hard_mask = soft_mask.gt(0.5).to(x.dtype)
        return hard_mask, soft_mask
|