lorebianchi98 committed
Commit 29cb4d8 · Parent: 7de983f

First commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +1 -0
  3. LICENSE.md +66 -0
  4. README.md +118 -4
  5. assets/overview.png +3 -0
  6. assets/pikachu.png +3 -0
  7. assets/pikachu_seg.png +3 -0
  8. assets/qualitatives.png +3 -0
  9. assets/qualitatives/cityscapes/1_clipdinoiser.png +3 -0
  10. assets/qualitatives/cityscapes/1_freeda.png +3 -0
  11. assets/qualitatives/cityscapes/1_gt.png +3 -0
  12. assets/qualitatives/cityscapes/1_image.png +3 -0
  13. assets/qualitatives/cityscapes/1_proxyclip.png +3 -0
  14. assets/qualitatives/cityscapes/1_talk2dino.png +3 -0
  15. assets/qualitatives/cityscapes/1r_clipdinoiser.png +3 -0
  16. assets/qualitatives/cityscapes/1r_freeda.png +3 -0
  17. assets/qualitatives/cityscapes/1r_gt.png +3 -0
  18. assets/qualitatives/cityscapes/1r_image.png +3 -0
  19. assets/qualitatives/cityscapes/1r_proxyclip.png +3 -0
  20. assets/qualitatives/cityscapes/1r_talk2dino.png +3 -0
  21. assets/qualitatives/context/1r_clipdinoiser.png +3 -0
  22. assets/qualitatives/context/1r_freeda.png +3 -0
  23. assets/qualitatives/context/1r_gt.png +3 -0
  24. assets/qualitatives/context/1r_img.png +3 -0
  25. assets/qualitatives/context/1r_proxy.png +3 -0
  26. assets/qualitatives/context/1r_talk2dino.png +3 -0
  27. assets/qualitatives/object/2r_clipdinoiser.png +3 -0
  28. assets/qualitatives/object/2r_freeda.png +3 -0
  29. assets/qualitatives/object/2r_gt.png +3 -0
  30. assets/qualitatives/object/2r_img.png +3 -0
  31. assets/qualitatives/object/2r_proxy.png +3 -0
  32. assets/qualitatives/object/2r_talk2dino.png +3 -0
  33. assets/qualitatives/voc/1_clipdinoiser.png +3 -0
  34. assets/qualitatives/voc/1_freeda.png +3 -0
  35. assets/qualitatives/voc/1_gt.png +3 -0
  36. assets/qualitatives/voc/1_img.jpg +0 -0
  37. assets/qualitatives/voc/1_proxy.png +3 -0
  38. assets/qualitatives/voc/1_talk2dino.png +3 -0
  39. assets/qualitatives/voc/2_clipdinoiser.png +3 -0
  40. assets/qualitatives/voc/2_freeda.png +3 -0
  41. assets/qualitatives/voc/2_gt.png +3 -0
  42. assets/qualitatives/voc/2_img.jpg +0 -0
  43. assets/qualitatives/voc/2_proxy.png +3 -0
  44. assets/qualitatives/voc/2_talk2dino.png +3 -0
  45. config.json +6 -0
  46. configuration_talk2dino.py +49 -0
  47. dinotext.py +399 -0
  48. hf_demo.ipynb +0 -0
  49. hooks.py +52 -0
  50. masker.py +246 -0
.gitattributes CHANGED
```diff
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
```
.gitignore ADDED
```diff
@@ -0,0 +1 @@
+__pycache__/
```
LICENSE.md ADDED
@@ -0,0 +1,66 @@

# DINOv3 License

*Last Updated: August 19, 2025*

**“Agreement”** means the terms and conditions for use, reproduction, distribution and modification of the DINO Materials set forth herein.

**“DINO Materials”** means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.

**“Documentation”** means the specifications, manuals and documentation accompanying DINO Materials distributed by Meta.

**“Licensee”** or **“you”** means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

**“Meta”** or **“we”** means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).

**“Sanctions”** means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.

**“Trade Controls”** means any of the following: Sanctions and applicable export and import controls.

By clicking “I Accept” below or by using or distributing any portion or element of the DINO Materials, you agree to be bound by this Agreement.

## 1. License Rights and Redistribution.

a. <ins>Grant of Rights</ins>. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the DINO Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DINO Materials.

b. <ins>Redistribution and Use</ins>.

i. Distribution of DINO Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the DINO Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such DINO Materials.

ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with DINO Materials, you must acknowledge the use of DINO Materials in your publication.

iii. Your use of the DINO Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.

iv. Your use of the DINO Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the DINO Materials.

v. You are not the target of Trade Controls and your use of DINO Materials must comply with Trade Controls. You agree not to use, or permit others to use, DINO Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.

## 2. User Support.

Your use of the DINO Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the DINO Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.

## 3. Disclaimer of Warranty.

UNLESS REQUIRED BY APPLICABLE LAW, THE DINO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DINO MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DINO MATERIALS AND ANY OUTPUT AND RESULTS.

## 4. Limitation of Liability.

IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.

## 5. Intellectual Property.

a. Subject to Meta’s ownership of DINO Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the DINO Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.

b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DINO Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the DINO Materials.

## 6. Term and Termination.

The term of this Agreement will commence upon your acceptance of this Agreement or access to the DINO Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DINO Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.

## 7. Governing Law and Jurisdiction.

This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.

## 8. Modifications and Amendments.

Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the DINO Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
README.md CHANGED
```diff
@@ -1,10 +1,124 @@
 ---
 tags:
 - model_hub_mixin
 - pytorch_model_hub_mixin
 ---
 
-This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
-- Code: [More Information Needed]
-- Paper: [More Information Needed]
-- Docs: [More Information Needed]
```
---
license: other
license_name: dinov3-license
pipeline_tag: image-segmentation
library_name: pytorch
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
- DINOv3
- CLIP
- open-vocabulary segmentation
---
<div align="center">
<h1>
Talking to DINO: Bridging Self-Supervised Vision Backbones with Language for Open-Vocabulary Segmentation (ICCV 2025)
</h1>

<h3>
<a href="https://www.linkedin.com/in/luca-barsellotti/">Luca Barsellotti*</a>&ensp;
<a href="https://www.linkedin.com/in/lorenzo-bianchi-893bb225a/">Lorenzo Bianchi*</a>&ensp;
<a href="https://www.linkedin.com/in/nicola-messina-a33848164/">Nicola Messina</a>&ensp;
<a href="https://www.linkedin.com/in/fabio-carrara-b28a2b111/">Fabio Carrara</a>&ensp;
<a href="https://aimagelab.ing.unimore.it/imagelab/person.asp?idpersona=90">Marcella Cornia</a>&ensp;
<a href="https://www.lorenzobaraldi.com/">Lorenzo Baraldi</a>&ensp;
<a href="https://fabriziofalchi.it">Fabrizio Falchi</a>&ensp;
<a href="https://www.linkedin.com/in/rita-cucchiara-a4653a13/">Rita Cucchiara</a>
</h3>

[Project Page](https://lorebianchi98.github.io/Talk2DINO/) | [Paper](http://arxiv.org/abs/2411.19331) | [Code](https://github.com/lorebianchi98/Talk2DINO)

</div>

<div align="center">
<figure>
<img alt="Overview of Talk2DINO" src="./assets/overview.png" width="90%">
</figure>
</div>

## About
Open-Vocabulary Segmentation (OVS) aims at segmenting images from free-form textual concepts without predefined training classes. While existing vision-language models such as CLIP can generate segmentation masks by leveraging coarse spatial information from Vision Transformers, they face challenges in spatial localization due to their global alignment of image and text features. Conversely, self-supervised visual models like DINO excel in fine-grained visual encoding but lack integration with language. To bridge this gap, we present Talk2DINO, a novel hybrid approach that combines the spatial accuracy of DINOv2 with the language understanding of CLIP. Our approach aligns the textual embeddings of CLIP to the patch-level features of DINOv2 through a learned mapping function without the need to fine-tune the underlying backbones. At training time, we exploit the attention maps of DINOv2 to selectively align local visual patches with textual embeddings. We show that the powerful semantic and localization abilities of Talk2DINO can enhance the segmentation process, resulting in more natural and less noisy segmentations, and that our approach can also effectively distinguish foreground objects from the background. Experimental results demonstrate that Talk2DINO achieves state-of-the-art performance across several unsupervised OVS benchmarks.
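The learned mapping can be pictured as a small projection head on top of two frozen encoders. Below is a minimal sketch with random stand-in features; the two-layer tanh MLP mirrors the `act: tanh` / `hidden_layer: true` projection config shipped in this repo, but the dimensions and architecture details here are illustrative assumptions, not the exact implementation.

```python
import torch
import torch.nn as nn

# Illustrative widths: CLIP ViT-B/16 text embeddings are 512-d,
# DINOv2 ViT-B/14 patch tokens are 768-d.
clip_dim, dino_dim = 512, 768

# Learned mapping from CLIP text space into DINO patch space;
# both backbones stay frozen, only this head is trained.
proj = nn.Sequential(
    nn.Linear(clip_dim, dino_dim),
    nn.Tanh(),
    nn.Linear(dino_dim, dino_dim),
)

text_embed = torch.randn(1, clip_dim)            # stand-in CLIP text embedding
patch_feats = torch.randn(1, 37 * 37, dino_dim)  # stand-in DINO patch tokens (37x37 grid)

mapped = proj(text_embed)                        # (1, dino_dim), now in DINO space
# Per-patch cosine similarity between the mapped text and every patch token
sim = torch.cosine_similarity(patch_feats, mapped[:, None, :], dim=-1)  # (1, 1369)
```

Thresholding or arg-maxing this per-patch similarity map is what ultimately yields the segmentation mask.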
## Sample Usage

### Mapping CLIP Text Embeddings to DINOv3 Space with Talk2DINO
We can use Talk2DINO to map CLIP text embeddings into the DINOv3 patch embedding space.
```python
import torch
from transformers import AutoModel
from torchvision.io import read_image

# Device setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model loading (the repo ships custom code, so trust_remote_code is required)
model = AutoModel.from_pretrained("lorebianchi98/Talk2DINO_v3-ViTL", trust_remote_code=True).to(device).eval()

# Load the input image
image = read_image("assets/pikachu.png")

# Embedding generation
with torch.no_grad():
    text_embed = model.encode_text("a pikachu")
    image_embed = model.encode_image(image)

# Normalize the features to perform cosine similarity
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)

similarity = (image_embed @ text_embed.T).squeeze().cpu().numpy()
```
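The patch-level similarity above can be turned into a dense segmentation by scoring every patch against several class embeddings and taking the arg-max per patch. The sketch below uses toy tensors so it runs standalone; the 16×16 patch grid, feature width, and class count are illustrative assumptions, not values produced by the model.

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: a 16x16 grid of 768-d patch features, 3 class text embeddings
patch_feats = torch.randn(16, 16, 768)
text_embeds = torch.randn(3, 768)

# L2-normalize so the dot product below is a cosine similarity
patch_feats = F.normalize(patch_feats, dim=-1)
text_embeds = F.normalize(text_embeds, dim=-1)

# Per-patch similarity to every class: (H, W, num_classes)
sim = patch_feats @ text_embeds.T

# Hard class assignment per patch, then upsample scores to image resolution
seg = sim.argmax(dim=-1)  # (16, 16) class indices
sim_up = F.interpolate(sim.permute(2, 0, 1)[None], size=(224, 224),
                       mode='bilinear', align_corners=False)[0]  # (3, 224, 224)
```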
### Demo
In `hf_demo.ipynb` we provide a simple example of how to use Talk2DINO for inference on a given image with custom textual categories.
Result:
<div align="center">
<table><tr><td><figure>
<img alt="" src="./assets/pikachu.png" width=300>
</figure></td><td><figure>
<img alt="" src="./assets/pikachu_seg.png" width=300>
</figure></td></tr></table>
</div>
## Installation

To use the **Hugging Face interface** for inference:

```bash
# Clone the repository
git clone https://huggingface.co/lorebianchi98/Talk2DINO-ViTB
cd Talk2DINO-ViTB

# Install dependencies
pip install -r requirements.txt

# Install PyTorch and torchvision with the appropriate CUDA version
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
```

For the **full MMCV interface** to perform evaluation on segmentation benchmarks, please refer to the [original Talk2DINO repository](https://github.com/lorebianchi98/Talk2DINO).
<details>
<summary>Qualitative Results</summary>

| **Image** | **Ground Truth** | **FreeDA** | **ProxyCLIP** | **CLIP-DINOiser** | **Ours (Talk2DINO)** |
|-----------|------------------|------------|---------------|-------------------|------------------|
| ![Image](assets/qualitatives/voc/2_img.jpg) | ![Ground Truth](assets/qualitatives/voc/2_gt.png) | ![FreeDA](assets/qualitatives/voc/2_freeda.png) | ![ProxyCLIP](assets/qualitatives/voc/2_proxy.png) | ![CLIP-DINOiser](assets/qualitatives/voc/2_clipdinoiser.png) | ![Ours](assets/qualitatives/voc/2_talk2dino.png) |
| ![Image](assets/qualitatives/object/2r_img.png) | ![Ground Truth](assets/qualitatives/object/2r_gt.png) | ![FreeDA](assets/qualitatives/object/2r_freeda.png) | ![ProxyCLIP](assets/qualitatives/object/2r_proxy.png) | ![CLIP-DINOiser](assets/qualitatives/object/2r_clipdinoiser.png) | ![Ours](assets/qualitatives/object/2r_talk2dino.png) |
| ![Image](assets/qualitatives/cityscapes/1r_image.png) | ![Ground Truth](assets/qualitatives/cityscapes/1r_gt.png) | ![FreeDA](assets/qualitatives/cityscapes/1r_freeda.png) | ![ProxyCLIP](assets/qualitatives/cityscapes/1r_proxyclip.png) | ![CLIP-DINOiser](assets/qualitatives/cityscapes/1r_clipdinoiser.png) | ![Ours](assets/qualitatives/cityscapes/1r_talk2dino.png) |
| ![Image](assets/qualitatives/context/1r_img.png) | ![Ground Truth](assets/qualitatives/context/1r_gt.png) | ![FreeDA](assets/qualitatives/context/1r_freeda.png) | ![ProxyCLIP](assets/qualitatives/context/1r_proxy.png) | ![CLIP-DINOiser](assets/qualitatives/context/1r_clipdinoiser.png) | ![Ours](assets/qualitatives/context/1r_talk2dino.png) |
</details>
## Reference
If you found this code useful, please cite the following paper:
```
@misc{barsellotti2024talkingdinobridgingselfsupervised,
      title={Talking to DINO: Bridging Self-Supervised Vision Backbones with Language for Open-Vocabulary Segmentation},
      author={Luca Barsellotti and Lorenzo Bianchi and Nicola Messina and Fabio Carrara and Marcella Cornia and Lorenzo Baraldi and Fabrizio Falchi and Rita Cucchiara},
      year={2024},
      eprint={2411.19331},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.19331},
}
```
assets/overview.png ADDED

Git LFS Details

  • SHA256: fcefc8c68cf95a966f769852ea51e7efa7ea2398b21936cacaa2eb5c6fff0358
  • Pointer size: 130 Bytes
  • Size of remote file: 89.5 kB
assets/pikachu.png ADDED

Git LFS Details

  • SHA256: 7a5efcbce11e4a293ebb743c8857c0654c6bce0b89beb59f6ca71d64311c4106
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB
assets/pikachu_seg.png ADDED

Git LFS Details

  • SHA256: 4b200c8a069a9d277073989a1bf3398a432fc47ade91ab3e90599fa04e0db33b
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
assets/qualitatives.png ADDED

Git LFS Details

  • SHA256: 7aafda1e9b4816d125c7a9a2294da44ca08f8be90b496e2bc1f72a58c5fbc859
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
assets/qualitatives/cityscapes/1_clipdinoiser.png ADDED

Git LFS Details

  • SHA256: dc7d50518fa3fb82c9ffe101c37ef71430636065d9b3287831d5deac74d2e958
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
assets/qualitatives/cityscapes/1_freeda.png ADDED

Git LFS Details

  • SHA256: 021db152eea34d9140c2c113fcbc4883c4a8a2b7714c5c6ae26f8494addc5d4b
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
assets/qualitatives/cityscapes/1_gt.png ADDED

Git LFS Details

  • SHA256: d4319e6ff02331c3bd80f23ffa2a15c99c4381a60770f2a147f3ac3410a4d4c1
  • Pointer size: 131 Bytes
  • Size of remote file: 215 kB
assets/qualitatives/cityscapes/1_image.png ADDED

Git LFS Details

  • SHA256: e78e72e601e9113dfc900743c7f1ea483c8cf86cd424789dded0d50f90e2a4c1
  • Pointer size: 131 Bytes
  • Size of remote file: 732 kB
assets/qualitatives/cityscapes/1_proxyclip.png ADDED

Git LFS Details

  • SHA256: 110c7e9507e95e5ecc0c0dedbbed9fe43fb3cd92961a3edb710f79a666088035
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
assets/qualitatives/cityscapes/1_talk2dino.png ADDED

Git LFS Details

  • SHA256: c0fe49fd3056deca04dd67146101b910e20286312d2dcccac59536ff37082e1d
  • Pointer size: 131 Bytes
  • Size of remote file: 223 kB
assets/qualitatives/cityscapes/1r_clipdinoiser.png ADDED

Git LFS Details

  • SHA256: 850ff27eddf5ed958440f76b51b59551c40ff53527349292c5d3f5e6784f966c
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
assets/qualitatives/cityscapes/1r_freeda.png ADDED

Git LFS Details

  • SHA256: 97c3d6a6ff0ca34429727ca57d86f5db97473c67f035f364091b09a542869419
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB
assets/qualitatives/cityscapes/1r_gt.png ADDED

Git LFS Details

  • SHA256: a5f53f91cbba23d5a9865f5674f3827e6abec5165d68f1d799abbcfdb31aa148
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
assets/qualitatives/cityscapes/1r_image.png ADDED

Git LFS Details

  • SHA256: 441e3846e7b9adf5b6f66e394f6c2988c12cd306dc20a21488b0080cb13f94c0
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
assets/qualitatives/cityscapes/1r_proxyclip.png ADDED

Git LFS Details

  • SHA256: 565024c62febde5ecbb90c428914341b100a7ec6a8cd4045b35b427d87acba65
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
assets/qualitatives/cityscapes/1r_talk2dino.png ADDED

Git LFS Details

  • SHA256: 7737bdf861eaeb8b7595279d3cd4ba5eb8b3f14e6b3bc45f275430005f783426
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
assets/qualitatives/context/1r_clipdinoiser.png ADDED

Git LFS Details

  • SHA256: e494d3a59eef1360f76756c2b6e5c240fcdf2e964328c95d31166ae538733c00
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
assets/qualitatives/context/1r_freeda.png ADDED

Git LFS Details

  • SHA256: 897c90477ce48ca384952138f2cbafc9273e18e855592828cc3d61621d33a4a5
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB
assets/qualitatives/context/1r_gt.png ADDED

Git LFS Details

  • SHA256: 3a32121bae796e860c460c8d10fd15e0bf89e396aa498af4bedcab31d467eab3
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
assets/qualitatives/context/1r_img.png ADDED

Git LFS Details

  • SHA256: 782c879d32d409d87bfb8458f815f780f3ce0c4320eac50bd80877568049a129
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
assets/qualitatives/context/1r_proxy.png ADDED

Git LFS Details

  • SHA256: dbdc03938f29c23c7e2adf04a2b4aecc4ca4897ed316f50412ef07c926cbea8b
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
assets/qualitatives/context/1r_talk2dino.png ADDED

Git LFS Details

  • SHA256: 7b24878e048c1024795dccf81a8ff2efce7d7fc410adc02e8d0270fc448b8da4
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
assets/qualitatives/object/2r_clipdinoiser.png ADDED

Git LFS Details

  • SHA256: 947b12f655f028f0dbf1028531d8a3a79792d6d8ad476c30ea33adef3ef49e67
  • Pointer size: 131 Bytes
  • Size of remote file: 259 kB
assets/qualitatives/object/2r_freeda.png ADDED

Git LFS Details

  • SHA256: f4e02e1578410da7db1d5709d568d9597f915da8b6783f4352cb96da3fc27d5d
  • Pointer size: 131 Bytes
  • Size of remote file: 261 kB
assets/qualitatives/object/2r_gt.png ADDED

Git LFS Details

  • SHA256: 3202de65eb235c709a9a5bfcd9275b2786c7802cdc5014eee9d33533e5cce4a0
  • Pointer size: 131 Bytes
  • Size of remote file: 258 kB
assets/qualitatives/object/2r_img.png ADDED

Git LFS Details

  • SHA256: 20437921ab37fd684adf5af51d26016d9e081a1c80f2b8f5a4572f2a8d699a7e
  • Pointer size: 131 Bytes
  • Size of remote file: 345 kB
assets/qualitatives/object/2r_proxy.png ADDED

Git LFS Details

  • SHA256: 3d2bbdf80167d8e2924ffffe5f097bdcafd046de801c126cfb4fd71727aef995
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
assets/qualitatives/object/2r_talk2dino.png ADDED

Git LFS Details

  • SHA256: 97dec62cdb8919aa30325a0a2bf15d07d2f7c644829b5183f97e21faa6fdcd4e
  • Pointer size: 131 Bytes
  • Size of remote file: 258 kB
assets/qualitatives/voc/1_clipdinoiser.png ADDED

Git LFS Details

  • SHA256: 1b77e7692a0f3d70636a9a5b0efc6718216af5c762b9a96b68ec91fb9e0570f7
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
assets/qualitatives/voc/1_freeda.png ADDED

Git LFS Details

  • SHA256: d849bab0384a9d661a32a87faea349d047649d0272df9e16f1994ce7f9b51b3c
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
assets/qualitatives/voc/1_gt.png ADDED

Git LFS Details

  • SHA256: d3060aabbe6a4fa6fd1895f96d9f09c4e66537c94ce59bd4b37a3a1e11825c94
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
assets/qualitatives/voc/1_img.jpg ADDED
assets/qualitatives/voc/1_proxy.png ADDED

Git LFS Details

  • SHA256: e88055ef402a7ec39ec582fd9d5fe5e85d4dc092e981ee1291d999c4709c435e
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
assets/qualitatives/voc/1_talk2dino.png ADDED

Git LFS Details

  • SHA256: a4dd7ba3df8cc21bb9dafc5eabe545f8af2619d096ed42da04461c0d0a35fc7d
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
assets/qualitatives/voc/2_clipdinoiser.png ADDED

Git LFS Details

  • SHA256: 731afb02c5190b7cbacf0c18a0f3cedac5b47903d78c0aef8afa7c33d9fe4ec8
  • Pointer size: 131 Bytes
  • Size of remote file: 311 kB
assets/qualitatives/voc/2_freeda.png ADDED

Git LFS Details

  • SHA256: d790a1f9ae79b83fc761f33a4821558dd1c2f3251cbb3e702aa46d73a89dd823
  • Pointer size: 131 Bytes
  • Size of remote file: 308 kB
assets/qualitatives/voc/2_gt.png ADDED

Git LFS Details

  • SHA256: 0d7a9abef94bd64a4452c837749d5d97770a10da3552c90549b0ae9c59b9754c
  • Pointer size: 131 Bytes
  • Size of remote file: 308 kB
assets/qualitatives/voc/2_img.jpg ADDED
assets/qualitatives/voc/2_proxy.png ADDED

Git LFS Details

  • SHA256: 0271de330e49dfdc82d934f808d19e81d8c0f2ec0d0c8bc9149ec26e59ea4237
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
assets/qualitatives/voc/2_talk2dino.png ADDED

Git LFS Details

  • SHA256: addb7639db5064a086666a76345a13ea74bceefb33870bb7a74ee3a540f9d361
  • Pointer size: 131 Bytes
  • Size of remote file: 312 kB
config.json CHANGED
```diff
@@ -1,4 +1,10 @@
 {
+    "architectures": ["Talk2DINO"],
+    "model_type": "talk2dino",
+    "auto_map": {
+        "AutoConfig": "configuration_talk2dino.Talk2DINOConfig",
+        "AutoModel": "modeling_talk2dino.Talk2DINO"
+    },
     "avg_self_attn_token": false,
     "clip_model_name": "ViT-B/16",
     "disentangled_self_attn_token": true,
```
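A small sketch of how the new `auto_map` entries are consumed: transformers splits each value into a module file shipped inside the repo and a class name, which is why loading this model needs `trust_remote_code=True`. The string handling below is a simplified illustration of that mechanism, not the library's actual loader.

```python
import json

# The keys added to config.json (reproduced from the diff above)
config = json.loads("""{
  "architectures": ["Talk2DINO"],
  "model_type": "talk2dino",
  "auto_map": {
    "AutoConfig": "configuration_talk2dino.Talk2DINOConfig",
    "AutoModel": "modeling_talk2dino.Talk2DINO"
  }
}""")

# Each auto_map value is "<module file>.<class name>": AutoModel imports the
# module from the downloaded repo, then instantiates the named class.
module_name, class_name = config["auto_map"]["AutoModel"].rsplit(".", 1)
```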
configuration_talk2dino.py ADDED
@@ -0,0 +1,49 @@

```python
from transformers import PretrainedConfig

class Talk2DINOConfig(PretrainedConfig):
    model_type = "talk2dino"

    def __init__(
        self,
        avg_self_attn_token=False,
        clip_model_name="ViT-B/16",
        disentangled_self_attn_token=True,
        is_eval=True,
        keep_cls=False,
        keep_end_seq=False,
        loss=None,
        model_name="dinov2_vitb14_reg",
        pre_trained=True,
        proj_class="vitb_mlp_infonce",
        proj_model="ProjectionLayer",
        proj_name="vitb_mlp_infonce",
        resize_dim=518,
        type="DINOText",
        unfreeze_last_image_layer=False,
        unfreeze_last_text_layer=False,
        use_avg_text_token=False,
        with_bg_clean=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Store all parameters
        self.avg_self_attn_token = avg_self_attn_token
        self.clip_model_name = clip_model_name
        self.disentangled_self_attn_token = disentangled_self_attn_token
        self.is_eval = is_eval
        self.keep_cls = keep_cls
        self.keep_end_seq = keep_end_seq
        self.loss = loss
        self.model_name = model_name
        self.pre_trained = pre_trained
        self.proj_class = proj_class
        self.proj_model = proj_model
        self.proj_name = proj_name
        self.resize_dim = resize_dim
        self.type = type
        self.unfreeze_last_image_layer = unfreeze_last_image_layer
        self.unfreeze_last_text_layer = unfreeze_last_text_layer
        self.use_avg_text_token = use_avg_text_token
        self.with_bg_clean = with_bg_clean
```
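A quick way to see how this config pattern behaves (defaults in `__init__`, stored as attributes, extra kwargs forwarded to the base class) is the sketch below, which uses a hypothetical stand-in for `PretrainedConfig` so it runs without transformers installed; the real base class also handles serialization and hub I/O.

```python
class PretrainedConfigStub:
    """Minimal stand-in for transformers.PretrainedConfig (illustration only)."""
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

class Talk2DINOConfigSketch(PretrainedConfigStub):
    model_type = "talk2dino"

    def __init__(self, clip_model_name="ViT-B/16", resize_dim=518, **kwargs):
        super().__init__(**kwargs)  # unknown kwargs become attributes
        self.clip_model_name = clip_model_name
        self.resize_dim = resize_dim

# Overrides replace defaults; unlisted parameters keep their default values
cfg = Talk2DINOConfigSketch(resize_dim=224, with_bg_clean=False)
```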
dinotext.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import os
3
+ import pickle
4
+ from math import sqrt
5
+ import re
6
+ import yaml
7
+
8
+ import numpy as np
9
+ import timm
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torchvision
14
+ from einops import rearrange
15
+ from transformers import BertModel, AutoTokenizer
16
+ import torchvision.transforms as T
17
+ import clip
18
+ import importlib
19
+ from .us import normalize
20
+
21
+ from .pamr import PAMR
22
+ from .masker import DINOTextMasker
23
+ from .templates import get_template
24
+
25
+ from .model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
26
+ from .hooks import average_text_tokens, get_vit_out, feats
27
+
28
+
29
+
30
+
31
+ class DINOText(nn.Module):
32
+
33
+ def get_self_attention(self, module, input, output):
34
+ self.feats['self_attn'] = output
35
+
36
+ def get_clip_second_last_dense_out(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
37
+ self.feats['clip_second_last_out'] = output
38
+ self.feats['clip_second_last_out'].to(dtype=torch.float32)
39
+
40
+ def get_all_out_tokens(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
41
+ self.feats['clip_txt_out_tokens'] = output
42
+
43
+ def __init__(
44
+ self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
45
+ unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
46
+ ):
47
+ nn.Module.__init__(self)
48
+ self.feats = {}
49
+ self.model_name = model_name
50
+ # loading the model
51
+
52
+ if 'dinov2' in model_name:
53
+ self.model_family = 'facebookresearch/dinov2' if 'dinov2' in model_name else 'facebookresearch/dino:main'
54
+ self.model = torch.hub.load(self.model_family, model_name)
55
+
56
+ elif 'mae' in model_name or 'sam' in model_name or 'clip' in model_name or 'dino' in model_name:
57
+ self.model = timm.create_model(
58
+ model_name,
59
+ pretrained=True,
60
+ num_classes=0, # remove classifier nn.Linear
61
+ img_size=resize_dim
62
+ )
63
+
64
+ if 'sam' in model_name:
65
+ self.model.blocks[-1].register_forward_hook(get_vit_out)
66
+ else:
67
+ raise Exception("Unknown ViT model")
68
+         # self.model.eval()
+         mean = (0.485, 0.456, 0.406) if 'clip' not in model_name else (0.4815, 0.4578, 0.4082)
+         std = (0.229, 0.224, 0.225) if 'clip' not in model_name else (0.2686, 0.2613, 0.2758)
+         self.image_transforms = T.Compose([
+             T.Resize((resize_dim, resize_dim)),
+             lambda x: T.ToTensor()(x) if not isinstance(x, torch.Tensor) else x / 255.0,  # ensure tensor
+             T.Normalize(mean, std),
+         ])
+
+         self.model.requires_grad_(False)
+
+         self.clip_model_name = clip_model_name
+         if 'bert' in self.clip_model_name:
+             self.clip_model = BertModel.from_pretrained(self.clip_model_name, output_hidden_states=False)
+             # load the corresponding word tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
+         else:
+             self.clip_model, _ = clip.load(clip_model_name, device='meta')
+         self.clip_model.eval()
+         self.clip_model.requires_grad_(False)
+         if unfreeze_last_text_layer:
+             for param in self.clip_model.transformer.resblocks[-1].parameters():
+                 param.requires_grad = True
+             for param in self.clip_model.ln_final.parameters():
+                 param.requires_grad = True
+             self.clip_model.text_projection.requires_grad = True
+         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+         if 'vitb_mlp_infonce' in proj_class:
+             config = {
+                 'act': 'tanh',  # None, tanh, relu or sigmoid
+                 'hidden_layer': True,
+                 'dino_embed_dim': 768
+             }
+         elif 'vitl_mlp_infonce' in proj_class:
+             config = {
+                 'act': 'tanh',  # None, tanh, relu or sigmoid
+                 'hidden_layer': True,
+                 'dino_embed_dim': 1024
+             }
+         else:
+             # fail explicitly: otherwise `config` would be unbound below
+             raise Exception(f"Unknown projection class: {proj_class}")
+
+         ProjClass = ProjectionLayer
+         self.proj = ProjClass.from_config(config)
+
+         self.masker = DINOTextMasker(similarity_type="cosine")
+         self.masker = self.masker.eval()
+
+         self.pamr = None
+
+         self.avg_self_attn_token = avg_self_attn_token
+         self.disentangled_self_attn_token = disentangled_self_attn_token
+
+         if self.avg_self_attn_token or self.disentangled_self_attn_token or is_eval:
+             self.model.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
+         self.num_global_tokens = 5 if 'reg' in model_name or 'dinov3' in model_name else 1
+         if 'sam' in self.model_name:
+             self.num_global_tokens = 0
+         if 'dinov3' in self.model_name:
+             if 'vit_base' in self.model_name:
+                 self.num_attn_heads = 12
+             elif 'vit_large' in self.model_name:
+                 self.num_attn_heads = 16
+             else:
+                 raise Exception("Unknown dinov3 model")
+         else:
+             self.num_attn_heads = self.model.num_heads
+         self.scale = 0.125
+
+         self.use_avg_text_token = use_avg_text_token
+         if self.use_avg_text_token:
+             self.feats = {}
+             # register a forward hook to capture all output tokens, not only the CLS token
+             self.clip_model.ln_final.register_forward_hook(self.get_all_out_tokens)
+         self.keep_cls = keep_cls
+         self.keep_end_seq = keep_end_seq
+
+         self.with_bg_clean = with_bg_clean
+
+     def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
+         # split the fused qkv projection into per-head queries, keys and values
+         qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0] * scale, qkv[1], qkv[2]
+         attn = q @ k.transpose(-2, -1)
+         # attention from the CLS token to the patch tokens, one map per head
+         self_attn_maps = attn[:, :, 0, num_global_tokens:]
+         self_attn = self_attn_maps.mean(dim=1)
+         self_attn = self_attn.softmax(dim=-1)
+         if ret_self_attn_maps:
+             return self_attn, self_attn_maps
+         return self_attn
+
+     def encode_text(self, tokenized_texts):
+         if type(self.proj) == CLIPLastLayer:
+             # the forward hook stores the second-to-last dense output in self.feats
+             self.clip_model.encode_text(tokenized_texts)
+             x = self.feats['clip_second_last_out']
+             x = x.to(dtype=torch.float32)
+         else:
+             x = self.clip_model.encode_text(tokenized_texts)
+         return x
+
+     def encode_image(self, images):
+         batch_size, _, _, _ = images.shape
+         self_attn_maps = None
+         x = self.model(images, is_training=(self.avg_self_attn_token or self.disentangled_self_attn_token))
+         batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
+         num_tokens = num_tokens + self.num_global_tokens
+         if self.avg_self_attn_token or self.disentangled_self_attn_token:
+             self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
+             if self.avg_self_attn_token:
+                 # weight the patch tokens by the head-averaged CLS self-attention
+                 x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
+             elif self.disentangled_self_attn_token:
+                 # one weighted token per attention head
+                 self_attn_maps = self_attn_maps.softmax(dim=-1)
+                 x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)
+
+         return x, self_attn_maps
+
+     def forward(self, image, text, return_logit_scale=False):
+         with torch.no_grad():
+             txt_embed = self.encode_text(text)
+
+         img_embed, self_attn_maps = self.encode_image(image)
+
+         if type(self.proj) == CLIPLastLayer:
+             img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps, text_argmax=text.argmax(dim=-1))
+         else:
+             img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps)
+
+         if return_logit_scale:
+             return txt_embed, img_embed, self.logit_scale
+
+         return txt_embed, img_embed
+
+     def compute_loss(self, image, text, cosine=True, ret_similarity_matrix=True):
+         ret = {}
+         # embed the batch of images and texts
+         txt_embed, img_embed = self.forward(image, text)
+         if cosine:
+             img_embed = F.normalize(img_embed, p=2, dim=1)
+             txt_embed = F.normalize(txt_embed, p=2, dim=1)
+         sim = img_embed @ txt_embed.transpose(1, 0)
+         if not ret_similarity_matrix:
+             sim = sim[torch.eye(len(sim)) > 0.5]  # keep only the diagonal elements
+
+         ret['contrastive_loss'] = self.contrastive_loss.compute_contrastive_loss(sim)
+
+         return ret
+
+     @torch.no_grad()
+     def build_dataset_class_tokens(self, template_set, classnames):
+         tokens = []
+         templates = get_template(template_set)
+         for classname in classnames:
+             if 'bert' not in self.clip_model_name:
+                 tokens.append(
+                     clip.tokenize([template.format(classname) for template in templates])
+                 )
+             else:
+                 tokens.append(self.tokenizer([template.format(classname) for template in templates], return_tensors='pt', padding='max_length')['input_ids'])
+         # [N, T, L] - N: number of classes, T: number of templates (including ensembled), L: sequence length
+         tokens = torch.stack(tokens)
+
+         return tokens
+
+     @torch.no_grad()
+     def build_text_embedding(self, text):
+         """
+         Args:
+             text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] text tokens
+
+         Returns:
+             text_embs
+         """
+         text = text.to(next(self.parameters()).device)
+         num_classes, num_templates = text.shape[:2]
+         text_argmax = text.argmax(dim=-1)
+         text_argmax = rearrange(text_argmax, 'n t -> (n t)', n=num_classes, t=num_templates)
+         text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
+         # chunked inference to limit memory usage
+         chunk_size = 32
+         N = text.size(0)
+         if type(self.proj) == CLIPLastLayer:
+             text_embs = torch.cat([
+                 self.proj.project_clip_txt(self.encode_text(text[i:i + chunk_size]).permute(1, 0, 2), text_argmax=text_argmax[i:i + chunk_size])
+                 for i in range(0, N, chunk_size)
+             ])
+         else:
+             if not self.use_avg_text_token:
+                 # classification using the CLS textual token
+                 if 'bert' not in self.clip_model_name:
+                     text_embs = torch.cat([
+                         self.clip_model.encode_text(text[i:i + chunk_size])
+                         for i in range(0, N, chunk_size)
+                     ])
+                 else:
+                     # encoding with BERT
+                     text_embs = []
+                     for i in range(0, N, chunk_size):
+                         outputs = self.clip_model(text[i:i + chunk_size])
+                         text_embs.append(outputs['pooler_output'])
+                     text_embs = torch.cat(text_embs)
+             else:
+                 # average the text tokens captured by the ln_final forward hook
+                 text_embs = []
+                 for i in range(0, N, chunk_size):
+                     self.clip_model.encode_text(text[i:i + chunk_size])
+                     text_embs.append(average_text_tokens(self.feats['clip_txt_out_tokens'] @ self.clip_model.text_projection, text[i:i + chunk_size] > 0, self.keep_cls, self.keep_end_seq))
+                 text_embs = torch.cat(text_embs)
+         # [N, T, C]
+         text_embs = rearrange(text_embs, '(n t) c -> n t c', n=num_classes, t=num_templates)
+         # [N, C]
+         text_embs = text_embs.mean(dim=1).float()
+         if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
+             text_embs = self.proj.project_clip_txt(text_embs)
+         text_embs = normalize(text_embs, dim=-1)
+
+         return text_embs
+
+     def apply_pamr(self, image, mask):
+         image = F.interpolate(image, mask.shape[-2:], mode="bilinear", align_corners=True)
+         if self.pamr is None:
+             # lazily build the PAMR refinement module on first use
+             pamr_iter = 10
+             pamr_kernel = [1, 2, 4, 8, 12, 24]
+             self.pamr = PAMR(pamr_iter, pamr_kernel)
+             self.pamr.eval()
+             self.pamr.to(next(self.parameters()).device)
+
+         mask = self.pamr(image, mask)
+         return mask
+
+     def compute_padsize(self, H: int, W: int, patch_size: int):
+         # left/right/top/bottom padding needed to make H and W multiples of patch_size
+         l, r, t, b = 0, 0, 0, 0
+         if W % patch_size:
+             lr = patch_size - (W % patch_size)
+             l = lr // 2
+             r = lr - l
+
+         if H % patch_size:
+             tb = patch_size - (H % patch_size)
+             t = tb // 2
+             b = tb - t
+
+         return l, r, t, b
+
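`compute_padsize` splits the padding required to reach the next multiple of `patch_size` as evenly as possible between the two sides of each dimension. A minimal standalone sketch of the same arithmetic, outside the class:

```python
def compute_padsize(H: int, W: int, patch_size: int):
    # symmetric padding so that H and W become multiples of patch_size
    l, r, t, b = 0, 0, 0, 0
    if W % patch_size:
        lr = patch_size - (W % patch_size)
        l = lr // 2
        r = lr - l
    if H % patch_size:
        tb = patch_size - (H % patch_size)
        t = tb // 2
        b = tb - t
    return l, r, t, b

# a 100x100 image with 14-pixel patches needs 12 extra pixels per dimension
print(compute_padsize(100, 100, 14))  # -> (6, 6, 6, 6)
```

When the needed padding is odd, the extra pixel goes to the right/bottom side (`r = lr - l`).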
+     @torch.no_grad()
+     def generate_masks(
+         self, image, img_metas, text_emb, classnames, text_is_token=False, apply_pamr=False, background_func="weighted_average_sigmoid", lambda_bg=0.2,
+         # kp_w=0.3,
+     ):
+         """Generate masks for each text embedding
+
+         Args:
+             image [B, 3, H, W]
+
+         Returns:
+             softmask [B, N, H, W]: softmasks for each text embedding
+         """
+
+         H, W = image.shape[2:]  # original image shape
+
+         # padded image size
+         pH, pW = image.shape[2:]
+         num_classes = text_emb.shape[0]
+         batch_size = image.shape[0]
+
+         image = image[:, [2, 1, 0], :, :]  # BGR to RGB
+         ori_image = image.clone()
+
+         img_preprocessed = self.image_transforms(image).to(next(self.parameters()).device)
+         if 'dinov2' in self.model_name:
+             image_feat = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
+         elif 'dinov3' in self.model_name:
+             image_feat = self.model.forward_features(img_preprocessed)[:, 5:, :]
+         elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
+             image_feat = self.model.forward_features(img_preprocessed)[:, 1:, :]
+         elif 'sam' in self.model_name:
+             self.model.forward_features(img_preprocessed)
+             image_feat = feats['vit_out'].reshape(feats['vit_out'].shape[0], feats['vit_out'].shape[1]**2, feats['vit_out'].shape[-1])  # BS x N_PATCHES x EMBED_DIM
+
+         batch_size, num_tokens, embed_dim = image_feat.shape
+         if type(self.proj) == VisualProjectionLayer:
+             image_feat = self.proj.project_dino(image_feat.float())
+         if type(self.proj) == DoubleMLP:
+             image_feat = self.proj.project_visual(image_feat.float())
+         b, n_patches, c = image_feat.shape
+         np_h = np_w = int(sqrt(n_patches))
+         image_feat = image_feat.reshape(b, np_h, np_w, c).permute(0, 3, 1, 2)
+
+         self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens + self.num_global_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
+         mask, simmap = self.masker.forward_seg(image_feat, text_emb, hard=False)  # [B, N, H', W']
+
+         if self.with_bg_clean:
+             mask = self.similarity_assignment_weighted(mask, image_feat, self_attn_maps, text_emb, lambda_bg)
+
+         # resize to the (padded) input resolution
+         mask = F.interpolate(mask, (pH, pW), mode='bilinear', align_corners=True)  # [B, N, H, W]
+
+         if apply_pamr:
+             # refine in chunks of 30 classes to bound memory usage
+             for c in range(0, mask.shape[1], 30):
+                 mask[:, c:c + 30] = self.apply_pamr(ori_image, mask[:, c:c + 30])
+
+         assert mask.shape[2] == H and mask.shape[3] == W, f"shape mismatch: ({H}, {W}) / {mask.shape}"
+
+         return mask, simmap
+
+     def similarity_assignment_weighted(self, mask, image_feat, self_attn_maps, text_emb, lambda_bg=0.2):
+         bs, c, h, w = image_feat.shape
+         bs, num_classes, h, w = mask.shape
+         bs, num_heads, hw = self_attn_maps.shape
+         image_feat = image_feat.reshape(bs, c, hw)
+         num_classes, c = text_emb.shape
+         # average patch embedding per attention head, weighted by the head's self-attention
+         avg_head_embed = (self_attn_maps.unsqueeze(2) * image_feat.unsqueeze(1)).mean(dim=-1)
+         avg_head_embed = avg_head_embed / avg_head_embed.norm(dim=-1, keepdim=True)
+         avg_head_embed = avg_head_embed.permute(0, 2, 1)  # [B, C, M]
+         head_text_sim = text_emb.unsqueeze(0) @ avg_head_embed  # [B, M, N]
+         head_text_sim = head_text_sim.softmax(dim=-1)
+         head_text_sim_sum = head_text_sim.sum(dim=-1)
+
+         # per-class self-attention map: heads weighted by their similarity to the class text
+         self_attn_maps_repeat = self_attn_maps.unsqueeze(1).repeat(1, num_classes, 1, 1)
+         head_text_sim_repeat = head_text_sim.unsqueeze(-1).repeat(1, 1, 1, hw)
+         avg_self_attn_per_class = (self_attn_maps_repeat * head_text_sim_repeat).sum(dim=2) / head_text_sim_sum.unsqueeze(-1).repeat(1, 1, hw)
+         avg_self_attn_per_class = avg_self_attn_per_class.softmax(dim=-1)
+
+         # rescale the per-class attention into the value range of the mask
+         min_self_attn = avg_self_attn_per_class.min().item()
+         max_self_attn = avg_self_attn_per_class.max().item()
+         max_self_attn = max(max_self_attn, max_self_attn - min_self_attn)
+         avg_self_attn_per_class = avg_self_attn_per_class - min_self_attn
+         avg_self_attn_per_class = avg_self_attn_per_class / max_self_attn
+         avg_self_attn_per_class = avg_self_attn_per_class * (mask.max() - mask.min()) + mask.min()
+         mask = mask.reshape(num_classes, hw)  # [N, P] (assumes batch size 1)
+         mask_output = (mask + lambda_bg * avg_self_attn_per_class).reshape(bs, num_classes, h, w) / (1 + lambda_bg)
+         return mask_output
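`process_self_attention` above recovers per-head CLS-to-patch attention from the raw fused qkv projection captured by the forward hook. A dependency-free numpy sketch of the same tensor manipulation (the shapes and the helper name are illustrative, not the model's real dimensions):

```python
import numpy as np

def cls_patch_attention(qkv_out, batch_size, num_tokens, num_heads, embed_dim, scale, num_global_tokens):
    # qkv_out: [B, N, 3*C] fused queries/keys/values, as stored by the qkv forward hook
    head_dim = embed_dim // num_heads
    qkv = qkv_out.reshape(batch_size, num_tokens, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k = qkv[0] * scale, qkv[1]
    attn = q @ k.transpose(0, 1, 3, 2)            # [B, H, N, N] attention logits
    maps = attn[:, :, 0, num_global_tokens:]      # CLS row, patch columns: [B, H, P]
    mean = maps.mean(axis=1)                      # average over heads: [B, P]
    mean = np.exp(mean - mean.max(axis=-1, keepdims=True))
    return mean / mean.sum(axis=-1, keepdims=True)  # softmax over patches

B, N, H, C = 2, 1 + 16, 4, 32                     # 1 global (CLS) token + 16 patches
attn = cls_patch_attention(np.random.randn(B, N, 3 * C), B, N, H, C, 0.125, 1)
```

The resulting `[B, P]` distribution sums to 1 over the patches, matching what `encode_image` uses to weight the patch tokens.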
hf_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
hooks.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+
+ feats = {}
+
+
+ def get_self_attention(module, input, output):
+     feats['self_attn'] = output
+
+
+ def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
+     qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
+     q, k, v = qkv[0] * scale, qkv[1], qkv[2]
+     attn = q @ k.transpose(-2, -1)
+     self_attn_maps = attn[:, :, 0, num_global_tokens:]
+     self_attn = self_attn_maps.mean(dim=1)
+     self_attn = self_attn.softmax(dim=-1)
+     if ret_self_attn_maps:
+         return self_attn, self_attn_maps
+     return self_attn
+
+
+ def get_vit_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+     feats['vit_out'] = output
+
+
+ def get_second_last_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+     feats['second_last_out'] = output
+
+
+ def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+     feats['clip_txt_out_tokens'] = output
+
+
+ def get_clip_second_last_dense_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+     feats['clip_second_last_out'] = output.permute(1, 0, 2)
+
+
+ def get_dinov1_patches(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+     feats['dinov1_patches'] = output
+
+
+ def average_text_tokens(text_embeddings, mask, keep_cls=False, keep_end_seq=False):
+     if not keep_end_seq:
+         mask[torch.arange(mask.shape[0]), mask.sum(dim=1) - 1] = False  # exclude the end-of-sequence token
+     if not keep_cls:
+         mask[:, 0] = False  # exclude the CLS token
+
+     masked_embeddings = text_embeddings * mask.unsqueeze(-1)  # [BS, SEQ_LEN, C]
+     sum_embeddings = masked_embeddings.sum(dim=1)  # [BS, C]
+     valid_elements = mask.sum(dim=1, keepdim=True)  # [BS, 1]
+     mean_embeddings = sum_embeddings / valid_elements  # [BS, C]
+
+     return mean_embeddings
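`average_text_tokens` computes a mean over the valid (non-padding) token embeddings, optionally dropping the CLS and end-of-sequence positions. The same masked average in numpy, as a quick sanity check (the dimensions are made up):

```python
import numpy as np

def masked_mean(token_embs, valid_mask):
    # token_embs: [B, L, C]; valid_mask: [B, L] boolean, True for tokens to keep
    masked = token_embs * valid_mask[..., None]
    return masked.sum(axis=1) / valid_mask.sum(axis=1, keepdims=True)

embs = np.stack([np.full((4, 3), v) for v in (1.0, 2.0)])  # [2, 4, 3]
mask = np.array([[True, True, False, False],
                 [True, True, True, False]])
out = masked_mean(embs, mask)  # each row averages only its valid tokens
```

Dividing by the per-row count of valid tokens (rather than the sequence length) is what keeps padded sequences from diluting the embedding.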
masker.py ADDED
@@ -0,0 +1,246 @@
+ # ------------------------------------------------------------------------------
+ # Talk2DINO
+ # ------------------------------------------------------------------------------
+ import copy
+ from collections import OrderedDict
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .us import normalize
+ from einops import rearrange, repeat
+
+ # gumbel_sigmoid is only needed for the stochastic (non-deterministic) mask path below
+ # from models.dinotext.gumbel import gumbel_sigmoid
+ from .modules import FeatureEncoder
+ from omegaconf import OmegaConf
+
+
+ def build_model(config):
+     model = OmegaConf.to_container(config, resolve=True)
+     return model
+
+
+ class Sim2Mask(nn.Module):
+     def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
+         super().__init__()
+         self.init_w = init_w
+         self.init_b = init_b
+         self.gumbel_tau = gumbel_tau
+         self.learnable = learnable
+
+         # init_w and init_b must be given (or omitted) together
+         assert not ((init_w is None) ^ (init_b is None))
+         if learnable:
+             self.w = nn.Parameter(torch.full([], float(init_w)))
+             self.b = nn.Parameter(torch.full([], float(init_b)))
+         else:
+             self.w = init_w
+             self.b = init_b
+
+     def forward(self, x, deterministic=False):
+         logits = x * self.w + self.b
+
+         soft_mask = torch.sigmoid(logits)
+         if deterministic:
+             hard_mask = soft_mask.gt(0.5).type(logits.dtype)
+         else:
+             hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)
+
+         return hard_mask, soft_mask
+
+     def extra_repr(self):
+         return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
+
+ class MaskerBackbone(nn.Module):
+     """Masker image encoder backbone."""
+     def __init__(self, clip_visual, freeze_idx):
+         super().__init__()
+         self.transformer = copy.deepcopy(clip_visual.transformer)
+         self.transformer.resblocks = self.transformer.resblocks[freeze_idx:]
+
+         # drop any hooks that were registered on the copied blocks
+         for block in self.transformer.resblocks:
+             if hasattr(block, "hook_handler"):
+                 block.hook_handler.remove()
+
+         self.ln_post = copy.deepcopy(clip_visual.ln_post)
+         self.proj = copy.deepcopy(clip_visual.proj)
+
+         self.layers = len(self.transformer.resblocks)
+         self.patch_size = clip_visual.patch_size
+
+         self.output_dim = clip_visual.output_dim if self.proj is not None else clip_visual.width
+
+     def forward(self, x, spatial=True, ignore_last_attn=True):
+         if self.layers:
+             x = self.transformer(x, ignore_last_attn=ignore_last_attn)
+
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         if spatial:
+             x = self.ln_post(x)
+         else:
+             x = self.ln_post(x[:, 0, :])
+
+         if self.proj is not None:
+             x = x @ self.proj
+
+         return x
+
+ class MaskerImageFeatureEncoder(FeatureEncoder):
+     def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
+         super().__init__()
+         self.ignore_last_attn = ignore_last_attn
+         self.patch_size = backbone.patch_size
+         self.backbone = backbone
+         self.decoder = decoder
+
+         for resblock in self.backbone.transformer.resblocks:
+             resblock.hook_handler = resblock.register_forward_hook(self.hook)
+
+     def _encode(self, image, image_feat):
+         H, W = image.shape[-2:]
+         h = H // self.patch_size
+         w = W // self.patch_size
+
+         x = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn)  # BLC
+         x = rearrange(x[:, 1:], "B (H W) C -> B C H W", H=h, W=w)
+         x = self.decoder(x)
+
+         return x
+
+ class Masker(nn.Module):
+     def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
+         super().__init__()
+         self.ignore_last_attn = ignore_last_attn
+
+         decoder["C"] = backbone.output_dim
+         decoder = MODELS.build(decoder)
+         decoder = nn.Sequential(OrderedDict([
+             ("decoder", decoder),
+             ("image_proj", image_proj)
+         ]))
+
+         self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)
+
+         self.sim2mask = Sim2Mask(**sim2mask)
+
+     def forward(self, image, image_feat, text_emb, deterministic=False):
+         B = image.size(0)
+         image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [BCHW]
+
+         image_emb_norm = normalize(image_emb, dim=1)
+         text_emb_norm = normalize(text_emb, dim=-1)
+
+         H, W = image_emb.shape[2:]
+         D = dist.get_world_size()
+
+         # simmap [B, B*D, H, W] where D is the number of devices
+         all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
+         simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
+         mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
+
+         # mask [B, B*D, H, W] where D is the number of devices
+         # positive global label
+         pos_indices = torch.arange(B, dtype=torch.long, device=image_emb.device) + B * dist.get_rank()
+         pos_mask = mask[torch.arange(B), pos_indices].unsqueeze(1)  # [B, 1, H, W]
+
+         offdiag = torch.ones(B, B * D, dtype=torch.bool, device=mask.device)
+         offdiag[torch.arange(B), pos_indices] = False
+
+         soft_pos_mask = soft_mask[torch.arange(B), pos_indices].unsqueeze(1)
+         soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B * D - 1, H, W)
+
+         masks = {
+             "pos": pos_mask,  # [B, 1, H, W]
+             "soft_pos": soft_pos_mask,
+             "soft_neg": soft_neg_mask,
+             "soft_all": soft_mask,  # [B, N, H, W]
+         }
+
+         return masks, image_emb, text_emb, feats
+
+     @torch.no_grad()
+     def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
+         """Make masks by 1:N matching
+
+         Args:
+             image [B, 3, H, W]
+             image_feat [L, B, C]: CLIP features
+             text_emb [N, C]
+             deterministic (bool): deterministic inference flag for gumbel noise
+             hard (bool): return a hard or a soft segmentation mask.
+                 Note that the soft mask is required for proper evaluation
+
+         Return:
+             mask [B, N, H', W'] (H' and W' are the downsampled H/W)
+         """
+         image_emb = self.image_encoder(image, image_feat)  # [BCHW]
+
+         image_emb = normalize(image_emb, dim=1)  # BCHW
+         text_emb = normalize(text_emb, dim=-1)  # NC
+
+         simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)
+
+         hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
+         mask = hard_mask if hard else soft_mask
+
+         return mask, simmap
+
+ class DINOTextMasker(nn.Module):
+     def __init__(self, similarity_type="cosine"):
+         super().__init__()
+         self.sim2mask = DINOTextSim2Mask()
+         self.sim2mask = self.sim2mask.eval()
+         self.similarity_type = similarity_type
+
+     def forward(self, image, image_feat, text_emb, deterministic=False):
+         pass
+
+     @torch.no_grad()
+     def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
+         """Make masks by 1:N matching
+
+         Args:
+             image_feat [B, C, H, W]: patch features
+             text_emb [N, C]
+             deterministic (bool): deterministic inference flag for gumbel noise
+             hard (bool): return a hard or a soft segmentation mask.
+                 Note that the soft mask is required for proper evaluation
+
+         Return:
+             mask [B, N, H', W'] (H' and W' are the downsampled H/W)
+         """
+         b, c, h, w = image_feat.shape
+         n, c = text_emb.shape
+
+         if self.similarity_type == "cosine":
+             image_feat = normalize(image_feat, dim=1)  # BCHW
+             # text_emb = normalize(text_emb, dim=-1)  # NC
+             simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
+         else:
+             raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))
+
+         hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
+         mask = hard_mask if hard else soft_mask
+
+         return mask, simmap
+
+
+ class DINOTextSim2Mask(nn.Module):
+     def __init__(self, gumbel_tau=1.0):
+         super().__init__()
+         self.gumbel_tau = gumbel_tau
+
+     def forward(self, x, deterministic=False):
+         soft_mask = torch.sigmoid(x)
+         if deterministic:
+             hard_mask = soft_mask.gt(0.5).type(x.dtype)
+         else:
+             hard_mask = gumbel_sigmoid(x, hard=True, tau=self.gumbel_tau)
+
+         return hard_mask, soft_mask
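The deterministic inference path in `forward_seg` reduces to: normalize the patch features, take a feature-text similarity via einsum, squash it with a sigmoid, and threshold at 0.5 for a hard mask. A numpy sketch of that pipeline with toy shapes (function name and dimensions are illustrative only):

```python
import numpy as np

def forward_seg_sketch(image_feat, text_emb):
    # image_feat: [B, C, H, W] patch features; text_emb: [N, C] class embeddings
    image_feat = image_feat / np.linalg.norm(image_feat, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=-1, keepdims=True)
    simmap = np.einsum("bchw,nc->bnhw", image_feat, text_emb)  # cosine similarity
    soft_mask = 1.0 / (1.0 + np.exp(-simmap))                  # sigmoid -> soft mask
    hard_mask = (soft_mask > 0.5).astype(simmap.dtype)         # deterministic threshold
    return hard_mask, soft_mask, simmap

hard, soft, sim = forward_seg_sketch(np.random.randn(1, 8, 4, 4), np.random.randn(3, 8))
```

Because cosine similarity lies in [-1, 1], the plain sigmoid here produces soft values in roughly (0.27, 0.73); the learnable scale/shift in `Sim2Mask` exists precisely to sharpen this range during training.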