Student0809 committed
Commit 341b9e4 · verified · 1 parent: d377feb

Add files using upload-large-folder tool

.dev_scripts/build_docs.sh ADDED
@@ -0,0 +1,8 @@
pip install -r requirements/docs.txt
cd docs
rm -rf build

# update api rst
#rm -rf source/api/
#sphinx-apidoc --module-first -o source/api/ ../modelscope/
make html
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,19 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
Describe the bug and how to reproduce it, ideally with screenshots (描述bug以及复现过程,最好有截图)


**Your hardware and system info**
Write your system info here, such as CUDA version, OS, GPU model, and torch version (在这里给出硬件信息和系统信息,如CUDA版本,系统,GPU型号和torch版本等)


**Additional context**
Add any other context about the problem here (在这里补充其他信息)
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,423 @@
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)

<p align="center">
    <br>
    <img src="asset/banner.png"/>
    <br>
</p>
<p align="center">
<a href="https://modelscope.cn/home">ModelScope Community Website</a>
<br>
<a href="README_CN.md">中文</a> &nbsp | &nbsp English &nbsp
</p>

<p align="center">
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
</p>

<p align="center">
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>

<p align="center">
<a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
</p>

## 📖 Table of Contents
- [Groups](#-groups)
- [Introduction](#-introduction)
- [News](#-news)
- [Installation](#%EF%B8%8F-installation)
- [Quick Start](#-quick-start)
- [Usage](#-usage)
- [License](#-license)
- [Citation](#-citation)


## ☎ Groups

You can contact us and join our community through the following groups:

[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
:-------------------------:|:-------------------------:
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">


## 📝 Introduction
🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment for 500+ large language models and 200+ multi-modal large models. The large language models (LLMs) include Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal models include Qwen2.5-VL, Qwen2-Audio, Llama3.4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.

🍔 Additionally, ms-swift incorporates the latest training technologies: lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of the inference, evaluation, and deployment modules with vLLM and LMDeploy, and supports model quantization with GPTQ, AWQ, and BNB. It also offers a Gradio-based Web UI and a wealth of best practices.

**Why choose ms-swift?**

- 🍎 **Model Types**: Supports 500+ pure-text large models and **200+ multi-modal large models**, as well as all-to-all multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
- **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human-alignment, and multi-modal datasets, and supports custom datasets.
- **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods such as LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-GaLore, LISA, UnSloth, and Liger-Kernel.
- **Distributed Training**: Supports distributed data parallelism (DDP), simple model parallelism via device_map, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- **Quantization Training**: Supports training quantized models with BNB, AWQ, GPTQ, AQLM, HQQ, and EETQ.
- **RLHF Training**: Supports human alignment methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO for both pure-text and multi-modal large models.
- 🍓 **Multi-Modal Training**: Supports training on different modalities such as images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
- **Interface Training**: Provides training, inference, evaluation, and quantization through a web interface, completing the whole large-model pipeline.
- **Plugins and Extensions**: Supports custom model and dataset extensions, as well as customization of components such as loss, metric, trainer, loss-scale, callback, and optimizer.
- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also the entire process of inference, evaluation, quantization, and deployment.
- **Inference Acceleration**: Supports inference engines such as PyTorch, vLLM, and LMDeploy, and provides an OpenAI API to accelerate the inference, deployment, and evaluation modules.
- **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure-text and multi-modal models.
- **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports; the exported models can use vLLM/LMDeploy for inference acceleration and can continue training.


## 🎉 News
- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models).
- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
- 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
- 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
- 🎁 2025.03.15: Fine-tuning of embedding models, both pure-text and multimodal, is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
- 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vLLM is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
- 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WandB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
- 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
- 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
- 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
<details><summary>More</summary>

- 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
</details>

## 🛠️ Installation
To install using pip:
```shell
pip install ms-swift -U
```

To install from source:
```shell
# pip install git+https://github.com/modelscope/ms-swift.git

git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e .
```

Running environment:

|              | Range        | Recommended | Notes                                     |
| ------------ | ------------ | ----------- | ----------------------------------------- |
| python       | >=3.9        | 3.10        |                                           |
| cuda         |              | cuda12      | No need to install if using CPU, NPU, MPS |
| torch        | >=2.0        |             |                                           |
| transformers | >=4.33       | 4.51        |                                           |
| modelscope   | >=1.23       |             |                                           |
| peft         | >=0.11,<0.16 |             |                                           |
| trl          | >=0.13,<0.18 | 0.17        | RLHF                                      |
| deepspeed    | >=0.14       | 0.14.5      | Training                                  |
| vllm         | >=0.5.1      | 0.7.3/0.8   | Inference/Deployment/Evaluation           |
| lmdeploy     | >=0.5        | 0.8         | Inference/Deployment/Evaluation           |
| evalscope    | >=0.11       |             | Evaluation                                |

For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).


## 🚀 Quick Start

Ten minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:

### Command Line Interface

```shell
# 22GB
CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model Qwen/Qwen2.5-7B-Instruct \
    --train_type lora \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
              'AI-ModelScope/alpaca-gpt4-data-en#500' \
              'swift/self-cognition#500' \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --system 'You are a helpful assistant.' \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --model_author swift \
    --model_name swift-robot
```

Tips:

- If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>` (a minimal sketch follows this list).
- The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
- To train with a different model, simply modify `--model <model_id/model_path>`.
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
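
As an example, a supervised fine-tuning dataset can be a JSONL file in the `messages` format. The snippet below is a minimal sketch (the file name and contents are illustrative; the custom dataset guide above documents all supported schemas):

```python
# Write a tiny SFT dataset in the "messages" JSONL format (illustrative content).
import json

rows = [
    {"messages": [
        {"role": "user", "content": "What is ms-swift?"},
        {"role": "assistant", "content": "An official ModelScope framework for fine-tuning and deploying LLMs."},
    ]},
]
with open("my_dataset.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
# Then train with: swift sft ... --dataset my_dataset.jsonl
```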

After training is complete, use the following command to infer with the trained weights:

- Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model` or `--system` separately; Swift reads these parameters automatically. To disable this behavior, set `--load_args false`.

```shell
# Use an interactive command line for inference.
CUDA_VISIBLE_DEVICES=0 \
swift infer \
    --adapters output/vx-xxx/checkpoint-xxx \
    --stream true \
    --temperature 0 \
    --max_new_tokens 2048

# merge-lora and use vLLM for inference acceleration
CUDA_VISIBLE_DEVICES=0 \
swift infer \
    --adapters output/vx-xxx/checkpoint-xxx \
    --stream true \
    --merge_lora true \
    --infer_backend vllm \
    --max_model_len 8192 \
    --temperature 0 \
    --max_new_tokens 2048
```

Finally, use the following command to push the model to ModelScope:

```shell
CUDA_VISIBLE_DEVICES=0 \
swift export \
    --adapters output/vx-xxx/checkpoint-xxx \
    --push_to_hub true \
    --hub_model_id '<your-model-id>' \
    --hub_token '<your-sdk-token>' \
    --use_hf false
```


### Web-UI
The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio. For more details, see [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).

```shell
SWIFT_UI_LANG=en swift web-ui
```

![image.png](./docs/resources/web-ui-en.jpg)

### Using Python

ms-swift also supports training and inference from Python. Below is pseudocode for training and inference; for a runnable version, refer to [this notebook](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).

Training:

```python
# Assumed imports (check the linked notebook for the exact module paths), e.g.:
#   from swift.llm import get_model_tokenizer, get_template, load_dataset, EncodePreprocessor
#   from swift.tuners import Swift

# Retrieve the model and template, and add a trainable LoRA module
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
template = get_template(model.model_meta.template, tokenizer, ...)
model = Swift.prepare_model(model, lora_config)

# Download and load the dataset, and encode the text into tokens
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)

# Train the model
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=template.data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    template=template,
)
trainer.train()
```

Inference:

```python
# Perform inference using the native PyTorch engine
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)

resp_list = engine.infer([infer_request], request_config)
print(f'response: {resp_list[0].choices[0].message.content}')
```

## ✨ Usage
Here is a minimal example that goes from training to deployment with ms-swift. For more details, check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).

- To use other models or datasets (including multimodal ones), modify `--model` to the corresponding model ID or path, and modify `--dataset` to the corresponding dataset ID or path.
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.

| Useful Links |
| ------ |
| [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
| [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
| [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
| [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |

### Training

Supported training methods:

| Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
| Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
| DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
| GRPO Training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
| Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
| PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
| KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
| CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
| SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
| ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
| Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
| Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |



Pre-training:
```shell
# 8*A100
NPROC_PER_NODE=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
swift pt \
    --model Qwen/Qwen2.5-7B \
    --dataset swift/chinese-c4 \
    --streaming true \
    --train_type full \
    --deepspeed zero2 \
    --output_dir output \
    --max_steps 10000 \
    ...
```

Fine-tuning:
```shell
CUDA_VISIBLE_DEVICES=0 swift sft \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset AI-ModelScope/alpaca-gpt4-data-en \
    --train_type lora \
    --output_dir output \
    ...
```

RLHF:
```shell
CUDA_VISIBLE_DEVICES=0 swift rlhf \
    --rlhf_type dpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
    --train_type lora \
    --output_dir output \
    ...
```


### Inference
```shell
CUDA_VISIBLE_DEVICES=0 swift infer \
    --model Qwen/Qwen2.5-7B-Instruct \
    --stream true \
    --infer_backend pt \
    --max_new_tokens 2048

# LoRA
CUDA_VISIBLE_DEVICES=0 swift infer \
    --model Qwen/Qwen2.5-7B-Instruct \
    --adapters swift/test_lora \
    --stream true \
    --infer_backend pt \
    --temperature 0 \
    --max_new_tokens 2048
```

### Interface Inference
```shell
CUDA_VISIBLE_DEVICES=0 swift app \
    --model Qwen/Qwen2.5-7B-Instruct \
    --stream true \
    --infer_backend pt \
    --max_new_tokens 2048
```

### Deployment
```shell
CUDA_VISIBLE_DEVICES=0 swift deploy \
    --model Qwen/Qwen2.5-7B-Instruct \
    --infer_backend vllm
```
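
`swift deploy` exposes an OpenAI-compatible API, so any OpenAI client can talk to it. A minimal sketch (the host, port, and served model name below are assumptions; match them to your deployment):

```python
# Query the OpenAI-compatible endpoint started by `swift deploy`.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")  # assumed host/port
resp = client.chat.completions.create(
    model="Qwen2.5-7B-Instruct",  # assumed served model name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)
```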

### Sampling
```shell
CUDA_VISIBLE_DEVICES=0 swift sample \
    --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
    --sampler_engine pt \
    --num_return_sequences 5 \
    --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```

### Evaluation
```shell
CUDA_VISIBLE_DEVICES=0 swift eval \
    --model Qwen/Qwen2.5-7B-Instruct \
    --infer_backend lmdeploy \
    --eval_backend OpenCompass \
    --eval_dataset ARC_c
```

### Quantization
```shell
CUDA_VISIBLE_DEVICES=0 swift export \
    --model Qwen/Qwen2.5-7B-Instruct \
    --quant_bits 4 --quant_method awq \
    --dataset AI-ModelScope/alpaca-gpt4-data-zh \
    --output_dir Qwen2.5-7B-Instruct-AWQ
```

### Push Model
```shell
swift export \
    --model <model-path> \
    --push_to_hub true \
    --hub_model_id '<model-id>' \
    --hub_token '<sdk-token>'
```

## 🏛 License

This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource pages and follow the corresponding licenses.

## 📎 Citation

```bibtex
@misc{zhao2024swiftascalablelightweightinfrastructure,
      title={SWIFT: A Scalable lightWeight Infrastructure for Fine-Tuning},
      author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
      year={2024},
      eprint={2408.05517},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2408.05517},
}
```

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date)
COT_TRAIN.jsonl ADDED
The diff for this file is too large to render.
 
docs/transformers/build/lib/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py ADDED
@@ -0,0 +1,209 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a RWKV checkpoint from BlinkDL to the Hugging Face format."""

import argparse
import gc
import json
import os
import re

import torch
from huggingface_hub import hf_hub_download, split_torch_state_dict_into_shards

from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, RwkvConfig
from transformers.modeling_utils import WEIGHTS_INDEX_NAME


NUM_HIDDEN_LAYERS_MAPPING = {
    "169M": 12,
    "430M": 24,
    "1B5": 24,
    "3B": 32,
    "7B": 32,
    "14B": 40,
}

HIDEN_SIZE_MAPPING = {
    "169M": 768,
    "430M": 1024,
    "1B5": 2048,
    "3B": 2560,
    "7B": 4096,
    "14B": 5120,
}


def convert_state_dict(state_dict):
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # emb -> embeddings
        if name.startswith("emb."):
            name = name.replace("emb.", "embeddings.")
        # ln_0 -> pre_ln (only present at block 0)
        if name.startswith("blocks.0.ln0"):
            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
        # att -> attention
        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
        # ffn -> feed_forward
        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
        # time_mix_k -> time_mix_key
        if name.endswith(".time_mix_k"):
            name = name.replace(".time_mix_k", ".time_mix_key")
        # time_mix_v -> time_mix_value
        if name.endswith(".time_mix_v"):
            name = name.replace(".time_mix_v", ".time_mix_value")
        # time_mix_r -> time_mix_receptance
        if name.endswith(".time_mix_r"):
            name = name.replace(".time_mix_r", ".time_mix_receptance")

        if name != "head.weight":
            name = "rwkv." + name

        state_dict[name] = weight
    return state_dict
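
For reference, a small illustration of what the renaming above does to a few BlinkDL-style keys (tensor values are elided; only the key mapping matters):

```python
# Key mapping produced by convert_state_dict (values elided with Ellipsis).
sd = {"emb.weight": ..., "blocks.0.att.time_mix_k": ..., "head.weight": ...}
converted = convert_state_dict(sd)
# Resulting keys:
#   "rwkv.embeddings.weight"
#   "rwkv.blocks.0.attention.time_mix_key"
#   "head.weight"   (the LM head keeps its name and is not prefixed with "rwkv.")
```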

def convert_rmkv_checkpoint_to_hf_format(
    repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None
):
    # 1. If possible, build the tokenizer.
    if tokenizer_file is None:
        print("No `--tokenizer_file` provided, we will use the default tokenizer.")
        vocab_size = 50277
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    else:
        tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
        vocab_size = len(tokenizer)
    tokenizer.save_pretrained(output_dir)

    # 2. Build the config
    possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys())
    if size is None:
        # Try to infer size from the checkpoint name
        for candidate in possible_sizes:
            if candidate in checkpoint_file:
                size = candidate
                break
        if size is None:
            raise ValueError("Could not infer the size, please provide it with the `--size` argument.")
    if size not in possible_sizes:
        raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.")

    config = RwkvConfig(
        vocab_size=vocab_size,
        num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size],
        hidden_size=HIDEN_SIZE_MAPPING[size],
    )
    config.save_pretrained(output_dir)

    # 3. Download model file then convert state_dict
    model_file = hf_hub_download(repo_id, checkpoint_file)
    state_dict = torch.load(model_file, map_location="cpu", weights_only=True)
    state_dict = convert_state_dict(state_dict)

    # 4. Split in shards and save
    state_dict_split = split_torch_state_dict_into_shards(state_dict)
    index = None
    shards = {
        filename: {tensor: state_dict[tensor] for tensor in tensors}
        for filename, tensors in state_dict_split.filename_to_tensors.items()
    }
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }

    for shard_file, shard in shards.items():
        torch.save(shard, os.path.join(output_dir, shard_file))

    if index is not None:
        save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)

    # 5. Clean up shards (for some reason the files PyTorch saves take the same space as the whole state_dict).
    print(
        "Cleaning up shards. This may error with an OOM error, if this is the case don't worry you still have converted the model."
    )
    shard_files = list(shards.keys())

    del state_dict
    del shards
    gc.collect()

    for shard_file in shard_files:
        state_dict = torch.load(os.path.join(output_dir, shard_file), weights_only=True)
        torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file))

        del state_dict
        gc.collect()

    if push_to_hub:
        if model_name is None:
            raise ValueError("Please provide a `model_name` to push the model to the Hub.")
        model = AutoModelForCausalLM.from_pretrained(output_dir)
        model.push_to_hub(model_name, max_shard_size="2GB")
        tokenizer.push_to_hub(model_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint."
    )
    parser.add_argument(
        "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo."
    )
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True, help="Where to save the converted model."
    )
    parser.add_argument(
        "--tokenizer_file",
        default=None,
        type=str,
        help="Path to the tokenizer file to use (if not provided, only the model is converted).",
    )
    parser.add_argument(
        "--size",
        default=None,
        type=str,
        help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push to the Hub the converted model.",
    )
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help="Name of the pushed model on the Hub, including the username / organization.",
    )

    args = parser.parse_args()
    convert_rmkv_checkpoint_to_hf_format(
        args.repo_id,
        args.checkpoint_file,
        args.output_dir,
        size=args.size,
        tokenizer_file=args.tokenizer_file,
        push_to_hub=args.push_to_hub,
        model_name=args.model_name,
    )
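
A minimal sketch of driving the converter from Python instead of the CLI; the repo and checkpoint names below are illustrative stand-ins for a BlinkDL RWKV-4 checkpoint whose filename contains a recognizable size tag:

```python
# Hypothetical invocation; swap in a real BlinkDL repo/checkpoint.
convert_rmkv_checkpoint_to_hf_format(
    repo_id="BlinkDL/rwkv-4-pile-169m",                    # assumed source repo
    checkpoint_file="RWKV-4-Pile-169M-20220807-8023.pth",  # assumed file name
    output_dir="./rwkv-4-169m-hf",
    # `size` is omitted: it is inferred from the "169M" tag in the file name.
)
```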
docs/transformers/build/lib/transformers/models/rwkv/modeling_rwkv.py ADDED
@@ -0,0 +1,850 @@
# coding=utf-8
# Copyright 2023 Bo Peng and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RWKV model."""

import math
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_bitsandbytes_available,
    is_ninja_available,
    is_torch_cuda_available,
    logging,
)
from .configuration_rwkv import RwkvConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "RWKV/rwkv-4-169m-pile"
_CONFIG_FOR_DOC = "RwkvConfig"


rwkv_cuda_kernel = None


def load_wkv_cuda_kernel(context_length):
    from torch.utils.cpp_extension import load as load_kernel

    global rwkv_cuda_kernel

    kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv"
    cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]]

    # Only load the kernel if it's not been loaded yet or if we changed the context length
    if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length:
        return

    logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.")

    flags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
        f"-DTmax={context_length}",
    ]
    rwkv_cuda_kernel = load_kernel(
        name=f"wkv_{context_length}",
        sources=cuda_kernel_files,
        verbose=(logging.get_verbosity() == logging.DEBUG),
        extra_cuda_cflags=flags,
    )
    rwkv_cuda_kernel.max_seq_length = context_length


class RwkvLinearAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
        batch_size, seq_len, hidden_size = key.size()
        if seq_len > rwkv_cuda_kernel.max_seq_length:
            raise ValueError(
                f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
                f"{rwkv_cuda_kernel.max_seq_length} with this model."
            )
        if batch_size * hidden_size % min(hidden_size, 32) != 0:
            raise ValueError(
                f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
                f"multiple of {min(hidden_size, 32)}."
            )

        ctx.input_dtype = key.dtype

        if (
            time_decay.device.type != "cuda"
            or time_first.device.type != "cuda"
            or key.device.type != "cuda"
            or value.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")

        time_decay = -torch.exp(time_decay.float().contiguous())
        if key.dtype == torch.float16:
            time_first = time_first.float()
            key = key.float()
            value = value.float()
        time_first = time_first.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        # The CUDA kernel will fill this tensor.
        output = torch.empty_like(key, memory_format=torch.contiguous_format)
        if return_state or state is not None:
            if state is None:
                state = torch.zeros(
                    batch_size,
                    hidden_size,
                    3,
                    dtype=torch.float32,
                    device=key.device,
                    memory_format=torch.contiguous_format,
                )
                state[:, :, 2] -= 1e38
            else:
                state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
            if key.dtype == torch.bfloat16:
                forward_func = rwkv_cuda_kernel.forward_with_state_bf16
            else:
                forward_func = rwkv_cuda_kernel.forward_with_state
            forward_func(time_decay, time_first, key, value, output, state)
        else:
            forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
            forward_func(time_decay, time_first, key, value, output)

        ctx.save_for_backward(time_decay, time_first, key, value, output)

        if state is not None:
            state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]

        return output.to(ctx.input_dtype), state

    @staticmethod
    # g stands for grad
    def backward(ctx, g_output, g_state=None):
        input_dtype = ctx.input_dtype

        time_decay, time_first, key, value, output = ctx.saved_tensors
        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty_like(
            time_decay,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
        g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        g_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        if input_dtype == torch.float16:
            g_output = g_output.float()
        backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
        backward_func(
            time_decay,
            time_first,
            key,
            value,
            output,
            g_output.contiguous(),
            g_time_decay,
            g_time_first,
            g_key,
            g_value,
        )

        return (
            g_time_decay.to(input_dtype),
            g_time_first.to(input_dtype),
            g_key.to(input_dtype),
            g_value.to(input_dtype),
            None,
            None,
        )


def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
    # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
    # within a torch.no_grad.
    _, seq_length, _ = key.size()
    output = torch.zeros_like(key)

    if state is None:
        num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
        den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
        max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
    else:
        num_state, den_state, max_state = state
    # For numerical stability
    #     real_numerator_state = num_state * torch.exp(max_state)
    #     real_denominator_state = den_state * torch.exp(max_state)

    time_decay = -torch.exp(time_decay)

    for current_index in range(seq_length):
        current_key = key[:, current_index].float()
        current_value = value[:, current_index]

        # wkv computation at time t
        max_for_output = torch.maximum(max_state, current_key + time_first)
        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp(current_key + time_first - max_for_output)
        numerator = e1 * num_state + e2 * current_value
        denominator = e1 * den_state + e2
        output[:, current_index] = (numerator / denominator).to(output.dtype)

        # Update state for next iteration
        max_for_state = torch.maximum(max_state + time_decay, current_key)
        e1 = torch.exp(max_state + time_decay - max_for_state)
        e2 = torch.exp(current_key - max_for_state)
        num_state = e1 * num_state + e2 * current_value
        den_state = e1 * den_state + e2
        max_state = max_for_state

    if return_state or state is not None:
        state = [num_state, den_state, max_state]

    return output, state
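
In exact arithmetic, the running-max bookkeeping in the loop above cancels out, and the function computes the standard RWKV-4 WKV recurrence per channel. A sketch of the closed form, with w = -exp(time_decay), u = time_first, and per-step keys/values k_i, v_i:

```latex
% Closed form of the output computed by rwkv_linear_attention_cpu at step t:
\[
\mathrm{wkv}_t \;=\;
\frac{\sum_{i=1}^{t-1} e^{(t-1-i)\,w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}
     {\sum_{i=1}^{t-1} e^{(t-1-i)\,w + k_i} \;+\; e^{u + k_t}}
\]
% num_state / den_state carry the numerator and denominator sums between steps,
% and max_state tracks the largest exponent seen so far so that every exp()
% evaluated in the loop stays <= 1 (log-sum-exp-style stabilization).
```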

def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
    # in this case).
    one_token = key.size(1) == 1
    if rwkv_cuda_kernel is None or no_cuda or one_token:
        return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
    else:
        return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)


class RwkvSelfAttention(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.config = config
        kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
        if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
            try:
                load_wkv_cuda_kernel(config.context_length)
            except Exception:
                logger.info("Could not load the custom CUDA kernel for RWKV attention.")
        self.layer_id = layer_id
        hidden_size = config.hidden_size
        attention_hidden_size = (
            config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
        )
        self.attention_hidden_size = attention_hidden_size

        self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
        self.time_first = nn.Parameter(torch.empty(attention_hidden_size))

        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)

    # TODO: maybe jit, otherwise move inside forward
    def extract_key_value(self, hidden, state=None):
        # Mix hidden with the previous timestep to produce key, value, receptance
        if hidden.size(1) == 1 and state is not None:
            shifted = state[1][:, :, self.layer_id]
        else:
            shifted = self.time_shift(hidden)
            if state is not None:
                shifted[:, 0] = state[1][:, :, self.layer_id]
        key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
        value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
        receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

        key = self.key(key)
        value = self.value(value)
        receptance = torch.sigmoid(self.receptance(receptance))
        if state is not None:
            state[1][:, :, self.layer_id] = hidden[:, -1]
        return receptance, key, value, state
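
Written out, `extract_key_value` interpolates each token with its predecessor (the "token shift") before the linear projections. A sketch with learned per-channel mixing weights mu_k, mu_v, mu_r:

```latex
% Token-shift mixing and projections performed by extract_key_value:
\[
k_t = W_K\big(\mu_k \odot x_t + (1-\mu_k) \odot x_{t-1}\big), \qquad
v_t = W_V\big(\mu_v \odot x_t + (1-\mu_v) \odot x_{t-1}\big),
\]
\[
r_t = \sigma\big(W_R\,(\mu_r \odot x_t + (1-\mu_r) \odot x_{t-1})\big)
\]
% x_{t-1} is the "shifted" tensor: zeros (or the cached state) at t = 0.
```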

    def forward(self, hidden, state=None, use_cache=False):
        receptance, key, value, state = self.extract_key_value(hidden, state=state)
        layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
        rwkv, layer_state = rwkv_linear_attention(
            self.time_decay,
            self.time_first,
            key,
            value,
            state=layer_state,
            return_state=use_cache,
        )

        if layer_state is not None:
            state[2][:, :, self.layer_id] = layer_state[0]
            state[3][:, :, self.layer_id] = layer_state[1]
            state[4][:, :, self.layer_id] = layer_state[2]

        return self.output(receptance * rwkv), state


class RwkvFeedForward(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.config = config
        self.layer_id = layer_id
        hidden_size = config.hidden_size
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
        )

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden, state=None):
        if hidden.size(1) == 1 and state is not None:
            shifted = state[0][:, :, self.layer_id]
        else:
            shifted = self.time_shift(hidden)
            if state is not None:
                shifted[:, 0] = state[0][:, :, self.layer_id]
        key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
        receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

        key = torch.square(torch.relu(self.key(key)))
        value = self.value(key)
        receptance = torch.sigmoid(self.receptance(receptance))

        if state is not None:
            state[0][:, :, self.layer_id] = hidden[:, -1]

        return receptance * value, state
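
The feed-forward block is therefore a gated squared-ReLU MLP over the token-shifted input; a sketch of what `forward` computes:

```latex
% Channel mixing computed by RwkvFeedForward.forward:
\[
\mathrm{ffn}_t \;=\; \sigma\big(W_R\, x^{(r)}_t\big) \;\odot\;
W_V\,\big(\max(W_K\, x^{(k)}_t,\; 0)\big)^{2}
\]
% x^{(k)} and x^{(r)} are the token-shift mixes of x_t and x_{t-1} using
% time_mix_key and time_mix_receptance respectively.
```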
350
+
351
+
352
+ class RwkvBlock(nn.Module):
353
+ def __init__(self, config, layer_id):
354
+ super().__init__()
355
+ self.config = config
356
+ self.layer_id = layer_id
357
+
358
+ if layer_id == 0:
359
+ self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
360
+
361
+ self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
362
+ self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
363
+
364
+ self.attention = RwkvSelfAttention(config, layer_id)
365
+ self.feed_forward = RwkvFeedForward(config, layer_id)
366
+
367
+ def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
368
+ if self.layer_id == 0:
369
+ hidden = self.pre_ln(hidden)
370
+
371
+ attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
372
+ hidden = hidden + attention
373
+
374
+ feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
375
+ hidden = hidden + feed_forward
376
+
377
+ outputs = (hidden, state)
378
+ if output_attentions:
379
+ outputs += (attention,)
380
+ else:
381
+ outputs += (None,)
382
+
383
+ return outputs
384
+
385
+
386
+ class RwkvPreTrainedModel(PreTrainedModel):
387
+ """
388
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
389
+ models.
390
+ """
391
+
392
+ config_class = RwkvConfig
393
+ base_model_prefix = "rwkv"
394
+ _no_split_modules = ["RwkvBlock"]
395
+ _keep_in_fp32_modules = ["time_decay", "time_first"]
396
+ supports_gradient_checkpointing = True
397
+ _is_stateful = True
398
+
399
+ def _init_weights(self, module):
400
+ """Initialize the weights."""
401
+ if isinstance(module, RwkvSelfAttention):
402
+ layer_id = module.layer_id
403
+ num_hidden_layers = module.config.num_hidden_layers
404
+ hidden_size = module.config.hidden_size
405
+ attention_hidden_size = module.attention_hidden_size
406
+
407
+ ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
408
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
409
+
410
+ time_weight = torch.tensor(
411
+ [i / hidden_size for i in range(hidden_size)],
412
+ dtype=module.time_mix_key.dtype,
413
+ device=module.time_mix_key.device,
414
+ )
415
+ time_weight = time_weight[None, None, :]
416
+
417
+ decay_speed = [
418
+ -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
419
+ for h in range(attention_hidden_size)
420
+ ]
421
+ decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
422
+ zigzag = (
423
+ torch.tensor(
424
+ [(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
425
+ dtype=module.time_first.dtype,
426
+ device=module.time_first.device,
427
+ )
428
+ * 0.5
429
+ )
430
+
431
+ with torch.no_grad():
432
+ module.time_decay.data = decay_speed
433
+ module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag)
434
+
435
+ module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
436
+ module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
437
+ module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
438
+ elif isinstance(module, RwkvFeedForward):
439
+ layer_id = module.layer_id
440
+ num_hidden_layers = module.config.num_hidden_layers
441
+ hidden_size = module.config.hidden_size
442
+
443
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
444
+
445
+ time_weight = torch.tensor(
446
+ [i / hidden_size for i in range(hidden_size)],
447
+ dtype=module.time_mix_key.dtype,
448
+ device=module.time_mix_key.device,
449
+ )
450
+ time_weight = time_weight[None, None, :]
451
+
452
+ with torch.no_grad():
453
+ module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
454
+ module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
455
+
456
+
457
+ @dataclass
+ class RwkvOutput(ModelOutput):
+     """
+     Class for the RWKV model outputs.
+
+     Args:
+         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+             Sequence of hidden-states at the output of the last layer of the model.
+         state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+             avoid providing the old `input_ids`.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`.
+
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     last_hidden_state: Optional[torch.FloatTensor] = None
+     state: Optional[List[torch.FloatTensor]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+ @dataclass
+ class RwkvCausalLMOutput(ModelOutput):
+     """
+     Base class for causal language model (or autoregressive) outputs.
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Language modeling loss (for next-token prediction).
+         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+             avoid providing the old `input_ids`.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`.
+
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[torch.FloatTensor] = None
+     state: Optional[List[torch.FloatTensor]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+ RWKV_START_DOCSTRING = r"""
+
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`RwkvConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ RWKV_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+             `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+             `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+             sequence tokens in the vocabulary.
+
+             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+             `input_ids`.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             This is currently not used by `RwkvModel`, but will be supported in the future.
+
+             [What are attention masks?](../glossary#attention-mask)
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+         state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
+             If passed along, the model uses the previous state in all the blocks (which will give the output for the
+             `input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
+         use_cache (`bool`, *optional*):
+             If set to `True`, the last state is returned and can be used to quickly generate the next logits.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
+     RWKV_START_DOCSTRING,
+ )
+ class RwkvModel(RwkvPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
+         self.ln_out = nn.LayerNorm(config.hidden_size)
+
+         self.layers_are_rescaled = False
+
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     def set_input_embeddings(self, new_embeddings):
+         self.embeddings = new_embeddings
+
+     @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=RwkvOutput,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,  # noqa
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         state: Optional[List[torch.FloatTensor]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, RwkvOutput]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if attention_mask is not None:
+             logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
+
+         if self.training == self.layers_are_rescaled:
+             self._rescale_layers()
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is None and inputs_embeds is None:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embeddings(input_ids)
+
+         if use_cache and state is None:
+             shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
+             state = [
+                 torch.zeros(
+                     *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
+                 )
+                 for i in range(5)
+             ]
+             state[4] -= 1e30
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         hidden_states = inputs_embeds
+
+         all_self_attentions = () if output_attentions else None
+         all_hidden_states = () if output_hidden_states else None
+         for idx, block in enumerate(self.blocks):
+             if self.gradient_checkpointing and self.training:
+                 hidden_states, state, attentions = self._gradient_checkpointing_func(
+                     block.__call__, hidden_states, state, use_cache, output_attentions
+                 )
+             else:
+                 hidden_states, state, attentions = block(
+                     hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
+                 )
+
+             if (
+                 self.layers_are_rescaled
+                 and self.config.rescale_every > 0
+                 and (idx + 1) % self.config.rescale_every == 0
+             ):
+                 hidden_states = hidden_states / 2
+
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (attentions,)
+
+         hidden_states = self.ln_out(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None)
+
+         return RwkvOutput(
+             last_hidden_state=hidden_states,
+             state=state,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
+
+     def _rescale_layers(self):
+         # Layers should be rescaled for inference only.
+         if self.layers_are_rescaled == (not self.training):
+             return
+         if self.config.rescale_every > 0:
+             with torch.no_grad():
+                 for block_id, block in enumerate(self.blocks):
+                     if self.training:
+                         block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
+                         block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
+                     else:
+                         # Deal with quantization statistics
+                         if hasattr(block.attention.output.weight, "SCB"):
+                             block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
+                             block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
+                         elif hasattr(block.attention.output.weight, "quant_state"):
+                             self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
+                             self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
+                         else:
+                             block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
+                             block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
+
+         self.layers_are_rescaled = not self.training
+
+     def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
+         r"""
+         Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
+         be quantized again.
+         """
+         if not is_bitsandbytes_available():
+             raise ImportError("Please install bitsandbytes to use this method.")
+         import bitsandbytes as bnb
+
+         dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
+
+         dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
+
+         # re-quantize the model:
+         # we need to put it first on CPU then back to the device
+         # this will create an overhead :/
+         # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
+         # bugs with bnb
+         quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
+         setattr(target_layer, "weight", quant_weight)
+
+
+ @add_start_docstrings(
+     """
+     The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
+     embeddings).
+     """,
+     RWKV_START_DOCSTRING,
+ )
+ class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.rwkv = RwkvModel(config)
+         self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_output_embeddings(self):
+         return self.head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.head = new_embeddings
+
+     def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
+         # Overwritten -- this model uses `state`, but doesn't have a cache (`past_key_values`)
+
+         # only keep the last token of `input_ids` if the state is passed along
+         if state is not None:
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and state is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs["state"] = state
+         model_inputs["use_cache"] = use_cache
+         return model_inputs
+
+     @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=RwkvCausalLMOutput,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,  # noqa
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         state: Optional[List[torch.FloatTensor]] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, RwkvCausalLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         rwkv_outputs = self.rwkv(
+             input_ids,
+             inputs_embeds=inputs_embeds,
+             state=state,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = rwkv_outputs[0]
+
+         logits = self.head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(
+                 logits,
+                 labels,
+                 vocab_size=self.config.vocab_size,
+                 **kwargs,
+             )
+
+         if not return_dict:
+             output = (logits,) + rwkv_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return RwkvCausalLMOutput(
+             loss=loss,
+             logits=logits,
+             state=rwkv_outputs.state,
+             hidden_states=rwkv_outputs.hidden_states,
+             attentions=rwkv_outputs.attentions,
+         )
+
+
+ __all__ = ["RwkvForCausalLM", "RwkvModel", "RwkvPreTrainedModel"]
docs/transformers/build/lib/transformers/models/sam/__init__.py ADDED
@@ -0,0 +1,30 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_sam import *
+     from .image_processing_sam import *
+     from .modeling_sam import *
+     from .modeling_tf_sam import *
+     from .processing_sam import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
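A note on the `_LazyModule` indirection above: importing the package stays cheap because the torch-heavy submodules are only executed when one of their attributes is first accessed. A small sketch against the installed `transformers` package (illustrative, not part of this commit):

```python
# Importing the package does not execute configuration_sam or modeling_sam yet;
# _LazyModule replaces the module object and resolves attributes on demand.
import transformers.models.sam as sam

config = sam.SamConfig()  # first attribute access triggers the real import
print(config.model_type)  # -> "sam"
```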
docs/transformers/build/lib/transformers/models/sam/configuration_sam.py ADDED
@@ -0,0 +1,337 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """SAM model configuration"""
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class SamPromptEncoderConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`]
+     module is used to encode the input 2D points and bounding boxes. Instantiating a configuration with the defaults
+     will yield a similar configuration to that of the SAM-vit-h
+     [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         hidden_size (`int`, *optional*, defaults to 256):
+             Dimensionality of the hidden states.
+         image_size (`int`, *optional*, defaults to 1024):
+             The expected output resolution of the image.
+         patch_size (`int`, *optional*, defaults to 16):
+             The size (resolution) of each patch.
+         mask_input_channels (`int`, *optional*, defaults to 16):
+             The number of channels to be fed to the `MaskDecoder` module.
+         num_point_embeddings (`int`, *optional*, defaults to 4):
+             The number of point embeddings to be used.
+         hidden_act (`str`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function in the encoder and pooler.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the layer normalization layers.
+     """
+
+     base_config_key = "prompt_encoder_config"
+
+     def __init__(
+         self,
+         hidden_size=256,
+         image_size=1024,
+         patch_size=16,
+         mask_input_channels=16,
+         num_point_embeddings=4,
+         hidden_act="gelu",
+         layer_norm_eps=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.image_embedding_size = image_size // patch_size
+         self.mask_input_channels = mask_input_channels
+         self.num_point_embeddings = num_point_embeddings
+         self.hidden_act = hidden_act
+         self.layer_norm_eps = layer_norm_eps
+
+
74
+ r"""
75
+ This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM
76
+ mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults
77
+ will yield a similar configuration to that of the SAM-vit-h
78
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
79
+
80
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
81
+ documentation from [`PretrainedConfig`] for more information.
82
+
83
+ Args:
84
+ hidden_size (`int`, *optional*, defaults to 256):
85
+ Dimensionality of the hidden states.
86
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
87
+ The non-linear activation function used inside the `SamMaskDecoder` module.
88
+ mlp_dim (`int`, *optional*, defaults to 2048):
89
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
90
+ num_hidden_layers (`int`, *optional*, defaults to 2):
91
+ Number of hidden layers in the Transformer encoder.
92
+ num_attention_heads (`int`, *optional*, defaults to 8):
93
+ Number of attention heads for each attention layer in the Transformer encoder.
94
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
95
+ The downsampling rate of the attention layer.
96
+ num_multimask_outputs (`int`, *optional*, defaults to 3):
97
+ The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3.
98
+ iou_head_depth (`int`, *optional*, defaults to 3):
99
+ The number of layers in the IoU head module.
100
+ iou_head_hidden_dim (`int`, *optional*, defaults to 256):
101
+ The dimensionality of the hidden states in the IoU head module.
102
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
103
+ The epsilon used by the layer normalization layers.
104
+
105
+ """
106
+
107
+ base_config_key = "mask_decoder_config"
108
+
109
+ def __init__(
110
+ self,
111
+ hidden_size=256,
112
+ hidden_act="relu",
113
+ mlp_dim=2048,
114
+ num_hidden_layers=2,
115
+ num_attention_heads=8,
116
+ attention_downsample_rate=2,
117
+ num_multimask_outputs=3,
118
+ iou_head_depth=3,
119
+ iou_head_hidden_dim=256,
120
+ layer_norm_eps=1e-6,
121
+ **kwargs,
122
+ ):
123
+ super().__init__(**kwargs)
124
+ self.hidden_size = hidden_size
125
+ self.hidden_act = hidden_act
126
+ self.mlp_dim = mlp_dim
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.num_attention_heads = num_attention_heads
129
+ self.attention_downsample_rate = attention_downsample_rate
130
+ self.num_multimask_outputs = num_multimask_outputs
131
+ self.iou_head_depth = iou_head_depth
132
+ self.iou_head_hidden_dim = iou_head_hidden_dim
133
+ self.layer_norm_eps = layer_norm_eps
134
+
135
+
136
+ class SamVisionConfig(PretrainedConfig):
137
+ r"""
138
+ This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM
139
+ vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
140
+ defaults will yield a similar configuration to that of the SAM ViT-h
141
+ [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
142
+
143
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
144
+ documentation from [`PretrainedConfig`] for more information.
145
+
146
+ Args:
147
+ hidden_size (`int`, *optional*, defaults to 768):
148
+ Dimensionality of the encoder layers and the pooler layer.
149
+ output_channels (`int`, *optional*, defaults to 256):
150
+ Dimensionality of the output channels in the Patch Encoder.
151
+ num_hidden_layers (`int`, *optional*, defaults to 12):
152
+ Number of hidden layers in the Transformer encoder.
153
+ num_attention_heads (`int`, *optional*, defaults to 12):
154
+ Number of attention heads for each attention layer in the Transformer encoder.
155
+ num_channels (`int`, *optional*, defaults to 3):
156
+ Number of channels in the input image.
157
+ image_size (`int`, *optional*, defaults to 1024):
158
+ Expected resolution. Target size of the resized input image.
159
+ patch_size (`int`, *optional*, defaults to 16):
160
+ Size of the patches to be extracted from the input image.
161
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
162
+ The non-linear activation function (function or string)
163
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
164
+ The epsilon used by the layer normalization layers.
165
+ attention_dropout (`float`, *optional*, defaults to 0.0):
166
+ The dropout ratio for the attention probabilities.
167
+ initializer_range (`float`, *optional*, defaults to 1e-10):
168
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
169
+ qkv_bias (`bool`, *optional*, defaults to `True`):
170
+ Whether to add a bias to query, key, value projections.
171
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
172
+ Ratio of mlp hidden dim to embedding dim.
173
+ use_abs_pos (`bool`, *optional*, defaults to `True`):
174
+ Whether to use absolute position embedding.
175
+ use_rel_pos (`bool`, *optional*, defaults to `True`):
176
+ Whether to use relative position embedding.
177
+ window_size (`int`, *optional*, defaults to 14):
178
+ Window size for relative position.
179
+ global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
180
+ The indexes of the global attention layers.
181
+ num_pos_feats (`int`, *optional*, defaults to 128):
182
+ The dimensionality of the position embedding.
183
+ mlp_dim (`int`, *optional*):
184
+ The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
185
+ hidden_size`.
186
+
187
+ Example:
188
+
189
+ ```python
190
+ >>> from transformers import (
191
+ ... SamVisionConfig,
192
+ ... SamVisionModel,
193
+ ... )
194
+
195
+ >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
196
+ >>> configuration = SamVisionConfig()
197
+
198
+ >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
199
+ >>> model = SamVisionModel(configuration)
200
+
201
+ >>> # Accessing the model configuration
202
+ >>> configuration = model.config
203
+ ```"""
204
+
205
+ base_config_key = "vision_config"
206
+ model_type = "sam_vision_model"
207
+
208
+ def __init__(
209
+ self,
210
+ hidden_size=768,
211
+ output_channels=256,
212
+ num_hidden_layers=12,
213
+ num_attention_heads=12,
214
+ num_channels=3,
215
+ image_size=1024,
216
+ patch_size=16,
217
+ hidden_act="gelu",
218
+ layer_norm_eps=1e-06,
219
+ attention_dropout=0.0,
220
+ initializer_range=1e-10,
221
+ qkv_bias=True,
222
+ mlp_ratio=4.0,
223
+ use_abs_pos=True,
224
+ use_rel_pos=True,
225
+ window_size=14,
226
+ global_attn_indexes=[2, 5, 8, 11],
227
+ num_pos_feats=128,
228
+ mlp_dim=None,
229
+ **kwargs,
230
+ ):
231
+ super().__init__(**kwargs)
232
+
233
+ self.hidden_size = hidden_size
234
+ self.output_channels = output_channels
235
+ self.num_hidden_layers = num_hidden_layers
236
+ self.num_attention_heads = num_attention_heads
237
+ self.num_channels = num_channels
238
+ self.image_size = image_size
239
+ self.patch_size = patch_size
240
+ self.hidden_act = hidden_act
241
+ self.layer_norm_eps = layer_norm_eps
242
+ self.attention_dropout = attention_dropout
243
+ self.initializer_range = initializer_range
244
+ self.qkv_bias = qkv_bias
245
+ self.mlp_ratio = mlp_ratio
246
+ self.use_abs_pos = use_abs_pos
247
+ self.use_rel_pos = use_rel_pos
248
+ self.window_size = window_size
249
+ self.global_attn_indexes = global_attn_indexes
250
+ self.num_pos_feats = num_pos_feats
251
+ self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim
252
+
253
+
254
+ class SamConfig(PretrainedConfig):
255
+ r"""
256
+ [`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a
257
+ SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder
258
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
259
+ SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
260
+
261
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
262
+ documentation from [`PretrainedConfig`] for more information.
263
+
264
+ Args:
265
+ vision_config (Union[`dict`, `SamVisionConfig`], *optional*):
266
+ Dictionary of configuration options used to initialize [`SamVisionConfig`].
267
+ prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*):
268
+ Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`].
269
+ mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*):
270
+ Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`].
271
+
272
+ kwargs (*optional*):
273
+ Dictionary of keyword arguments.
274
+
275
+ Example:
276
+
277
+ ```python
278
+ >>> from transformers import (
279
+ ... SamVisionConfig,
280
+ ... SamPromptEncoderConfig,
281
+ ... SamMaskDecoderConfig,
282
+ ... SamModel,
283
+ ... )
284
+
285
+ >>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
286
+ >>> configuration = SamConfig()
287
+
288
+ >>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
289
+ >>> model = SamModel(configuration)
290
+
291
+ >>> # Accessing the model configuration
292
+ >>> configuration = model.config
293
+
294
+ >>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
295
+
296
+ >>> # Initializing SAM vision, SAM Q-Former and language model configurations
297
+ >>> vision_config = SamVisionConfig()
298
+ >>> prompt_encoder_config = SamPromptEncoderConfig()
299
+ >>> mask_decoder_config = SamMaskDecoderConfig()
300
+
301
+ >>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
302
+ ```"""
303
+
304
+ model_type = "sam"
305
+ sub_configs = {
306
+ "prompt_encoder_config": SamPromptEncoderConfig,
307
+ "mask_decoder_config": SamMaskDecoderConfig,
308
+ "vision_config": SamVisionConfig,
309
+ }
310
+
311
+ def __init__(
312
+ self,
313
+ vision_config=None,
314
+ prompt_encoder_config=None,
315
+ mask_decoder_config=None,
316
+ initializer_range=0.02,
317
+ **kwargs,
318
+ ):
319
+ super().__init__(**kwargs)
320
+ vision_config = vision_config if vision_config is not None else {}
321
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
322
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
323
+
324
+ if isinstance(vision_config, SamVisionConfig):
325
+ vision_config = vision_config.to_dict()
326
+ if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
327
+ prompt_encoder_config = prompt_encoder_config.to_dict()
328
+ if isinstance(mask_decoder_config, SamMaskDecoderConfig):
329
+ mask_decoder_config = mask_decoder_config.to_dict()
330
+
331
+ self.vision_config = SamVisionConfig(**vision_config)
332
+ self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
333
+ self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
334
+ self.initializer_range = initializer_range
335
+
336
+
337
+ __all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
docs/transformers/build/lib/transformers/models/sam/image_processing_sam.py ADDED
@@ -0,0 +1,1494 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Image processor class for SAM."""
+
+ import math
+ from copy import deepcopy
+ from itertools import product
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+ from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
+ from ...image_utils import (
+     IMAGENET_DEFAULT_MEAN,
+     IMAGENET_DEFAULT_STD,
+     ChannelDimension,
+     ImageInput,
+     PILImageResampling,
+     get_image_size,
+     infer_channel_dimension_format,
+     is_scaled_image,
+     make_list_of_images,
+     to_numpy_array,
+     valid_images,
+     validate_preprocess_arguments,
+ )
+ from ...utils import (
+     TensorType,
+     filter_out_non_signature_kwargs,
+     is_tf_available,
+     is_torch_available,
+     is_torchvision_available,
+     logging,
+     requires_backends,
+ )
+
+
+ if is_torch_available():
+     import torch
+     import torch.nn.functional as F
+
+ if is_torchvision_available():
+     from torchvision.ops.boxes import batched_nms
+
+ if is_tf_available():
+     import tensorflow as tf
+     from tensorflow.experimental import numpy as tnp
+
+     from ...tf_utils import flatten, shape_list
+
+ logger = logging.get_logger(__name__)
+
+
+ class SamImageProcessor(BaseImageProcessor):
+     r"""
+     Constructs a SAM image processor.
+
+     Args:
+         do_resize (`bool`, *optional*, defaults to `True`):
+             Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+             `do_resize` parameter in the `preprocess` method.
+         size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
+             Size of the output image after resizing. Resizes the longest edge of the image to match
+             `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the
+             `preprocess` method.
+         mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`):
+             Size of the output segmentation map after resizing. Resizes the longest edge of the image to match
+             `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size` parameter
+             in the `preprocess` method.
+         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+             Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+             `preprocess` method.
+         do_rescale (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+             `do_rescale` parameter in the `preprocess` method.
+         rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+             Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+             overridden by the `rescale_factor` parameter in the `preprocess` method.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+             method.
+         image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+             Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+             channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+         do_pad (`bool`, *optional*, defaults to `True`):
+             Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
+             `preprocess` method.
+         pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+             Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
+             method.
+         mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
+             Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in
+             the `preprocess` method.
+         do_convert_rgb (`bool`, *optional*, defaults to `True`):
+             Whether to convert the image to RGB.
+     """
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         do_resize: bool = True,
+         size: Optional[Dict[str, int]] = None,
+         mask_size: Optional[Dict[str, int]] = None,
+         resample: PILImageResampling = PILImageResampling.BILINEAR,
+         do_rescale: bool = True,
+         rescale_factor: Union[int, float] = 1 / 255,
+         do_normalize: bool = True,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_pad: bool = True,
+         pad_size: Optional[Dict[str, int]] = None,
+         mask_pad_size: Optional[Dict[str, int]] = None,
+         do_convert_rgb: bool = True,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         size = size if size is not None else {"longest_edge": 1024}
+         size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
+
+         pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
+         pad_size = get_size_dict(pad_size, default_to_square=True)
+
+         mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
+         mask_size = (
+             get_size_dict(max_size=mask_size, default_to_square=False)
+             if not isinstance(mask_size, dict)
+             else mask_size
+         )
+
+         mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
+         mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
+
+         self.do_resize = do_resize
+         self.size = size
+         self.mask_size = mask_size
+         self.resample = resample
+         self.do_rescale = do_rescale
+         self.rescale_factor = rescale_factor
+         self.do_normalize = do_normalize
+         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+         self.do_pad = do_pad
+         self.pad_size = pad_size
+         self.mask_pad_size = mask_pad_size
+         self.do_convert_rgb = do_convert_rgb
+
+     def pad_image(
+         self,
+         image: np.ndarray,
+         pad_size: Dict[str, int],
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+         **kwargs,
+     ) -> np.ndarray:
+         """
+         Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.
+
+         Args:
+             image (`np.ndarray`):
+                 Image to pad.
+             pad_size (`Dict[str, int]`):
+                 Size of the output image after padding.
+             data_format (`str` or `ChannelDimension`, *optional*):
+                 The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
+                 `data_format` of the `image` will be used.
+             input_data_format (`str` or `ChannelDimension`, *optional*):
+                 The channel dimension format of the input image. If not provided, it will be inferred.
+         """
+         output_height, output_width = pad_size["height"], pad_size["width"]
+         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+
+         pad_width = output_width - input_width
+         pad_height = output_height - input_height
+
+         padded_image = pad(
+             image,
+             ((0, pad_height), (0, pad_width)),
+             data_format=data_format,
+             input_data_format=input_data_format,
+             **kwargs,
+         )
+         return padded_image
+
+     def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
+         """
+         Compute the output size given input size and target long side length.
+         """
+         oldh, oldw = old_shape
+         scale = longest_edge * 1.0 / max(oldh, oldw)
+         newh, neww = oldh * scale, oldw * scale
+         newh = int(newh + 0.5)
+         neww = int(neww + 0.5)
+         return (newh, neww)
+
+     def resize(
+         self,
+         image: np.ndarray,
+         size: Dict[str, int],
+         resample: PILImageResampling = PILImageResampling.BICUBIC,
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+         **kwargs,
+     ) -> np.ndarray:
+         """
+         Resize an image so that its longest edge matches `size["longest_edge"]`, preserving the aspect ratio.
+
+         Args:
+             image (`np.ndarray`):
+                 Image to resize.
+             size (`Dict[str, int]`):
+                 Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
+                 edge of the image will be resized to the specified size, while the other edge will be resized to
+                 maintain the aspect ratio.
+             resample:
+                 `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+             data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the output image. If unset, the channel dimension format of the input
+                 image is used. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                 from the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+         Returns:
+             `np.ndarray`: The resized image.
+         """
+         size = get_size_dict(size)
+         if "longest_edge" not in size:
+             raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
+         input_size = get_image_size(image, channel_dim=input_data_format)
+         output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
+         return resize(
+             image,
+             size=(output_height, output_width),
+             resample=resample,
+             data_format=data_format,
+             input_data_format=input_data_format,
+             **kwargs,
+         )
+
+     def _preprocess(
+         self,
+         image: ImageInput,
+         do_resize: bool,
+         do_rescale: bool,
+         do_normalize: bool,
+         size: Optional[Dict[str, int]] = None,
+         resample: PILImageResampling = None,
+         rescale_factor: Optional[float] = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_pad: Optional[bool] = None,
+         pad_size: Optional[Dict[str, int]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ):
+         if do_resize:
+             image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+         reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
+
+         if do_rescale:
+             image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+         if do_normalize:
+             image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+         if do_pad:
+             image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
+
+         return image, reshaped_input_size
+
+     def _preprocess_image(
+         self,
+         image: ImageInput,
+         do_resize: Optional[bool] = None,
+         size: Optional[Dict[str, int]] = None,
+         resample: PILImageResampling = None,
+         do_rescale: Optional[bool] = None,
+         rescale_factor: Optional[float] = None,
+         do_normalize: Optional[bool] = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_pad: Optional[bool] = None,
+         pad_size: Optional[Dict[str, int]] = None,
+         do_convert_rgb: Optional[bool] = None,
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
+         # PIL RGBA images are converted to RGB
+         if do_convert_rgb:
+             image = convert_to_rgb(image)
+
+         # All transformations expect numpy arrays.
+         image = to_numpy_array(image)
+
+         if do_rescale and is_scaled_image(image):
+             logger.warning_once(
+                 "It looks like you are trying to rescale already rescaled images. If the input"
+                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+             )
+
+         if input_data_format is None:
+             input_data_format = infer_channel_dimension_format(image)
+
+         original_size = get_image_size(image, channel_dim=input_data_format)
+
+         image, reshaped_input_size = self._preprocess(
+             image=image,
+             do_resize=do_resize,
+             size=size,
+             resample=resample,
+             do_rescale=do_rescale,
+             rescale_factor=rescale_factor,
+             do_normalize=do_normalize,
+             image_mean=image_mean,
+             image_std=image_std,
+             do_pad=do_pad,
+             pad_size=pad_size,
+             input_data_format=input_data_format,
+         )
+
+         if data_format is not None:
+             image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+         return image, original_size, reshaped_input_size
+
+     def _preprocess_mask(
+         self,
+         segmentation_map: ImageInput,
+         do_resize: Optional[bool] = None,
+         mask_size: Optional[Dict[str, int]] = None,
+         do_pad: Optional[bool] = None,
+         mask_pad_size: Optional[Dict[str, int]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> np.ndarray:
+         segmentation_map = to_numpy_array(segmentation_map)
+
+         # Add channel dimension if missing - needed for certain transformations
+         if segmentation_map.ndim == 2:
+             added_channel_dim = True
+             segmentation_map = segmentation_map[None, ...]
+             input_data_format = ChannelDimension.FIRST
+         else:
+             added_channel_dim = False
+             if input_data_format is None:
+                 input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+
+         original_size = get_image_size(segmentation_map, channel_dim=input_data_format)
+
+         segmentation_map, _ = self._preprocess(
+             image=segmentation_map,
+             do_resize=do_resize,
+             size=mask_size,
+             resample=PILImageResampling.NEAREST,
+             do_rescale=False,
+             do_normalize=False,
+             do_pad=do_pad,
+             pad_size=mask_pad_size,
+             input_data_format=input_data_format,
+         )
+
+         # Remove extra channel dimension if added for processing
+         if added_channel_dim:
+             segmentation_map = segmentation_map.squeeze(0)
+         segmentation_map = segmentation_map.astype(np.int64)
+
+         return segmentation_map, original_size
+
+ @filter_out_non_signature_kwargs()
391
+ def preprocess(
392
+ self,
393
+ images: ImageInput,
394
+ segmentation_maps: Optional[ImageInput] = None,
395
+ do_resize: Optional[bool] = None,
396
+ size: Optional[Dict[str, int]] = None,
397
+ mask_size: Optional[Dict[str, int]] = None,
398
+ resample: Optional["PILImageResampling"] = None,
399
+ do_rescale: Optional[bool] = None,
400
+ rescale_factor: Optional[Union[int, float]] = None,
401
+ do_normalize: Optional[bool] = None,
402
+ image_mean: Optional[Union[float, List[float]]] = None,
403
+ image_std: Optional[Union[float, List[float]]] = None,
404
+ do_pad: Optional[bool] = None,
405
+ pad_size: Optional[Dict[str, int]] = None,
406
+ mask_pad_size: Optional[Dict[str, int]] = None,
407
+ do_convert_rgb: Optional[bool] = None,
408
+ return_tensors: Optional[Union[str, TensorType]] = None,
409
+ data_format: ChannelDimension = ChannelDimension.FIRST,
410
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
411
+ ):
412
+ """
413
+ Preprocess an image or batch of images.
414
+
415
+ Args:
416
+ images (`ImageInput`):
417
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
418
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
419
+ segmentation_maps (`ImageInput`, *optional*):
420
+ Segmentation map to preprocess.
421
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
422
+ Whether to resize the image.
423
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
424
+ Controls the size of the image after `resize`. The longest edge of the image is resized to
425
+ `size["longest_edge"]` whilst preserving the aspect ratio.
426
+ mask_size (`Dict[str, int]`, *optional*, defaults to `self.mask_size`):
427
+ Controls the size of the segmentation map after `resize`. The longest edge of the image is resized to
428
+ `size["longest_edge"]` whilst preserving the aspect ratio.
429
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
430
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
431
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
432
+ Whether to rescale the image pixel values by rescaling factor.
433
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
434
+ Rescale factor to apply to the image pixel values.
435
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
436
+ Whether to normalize the image.
437
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
438
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
439
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
440
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
441
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
442
+ Whether to pad the image.
443
+ pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
444
+ Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
445
+ `pad_size["width"]` if `do_pad` is set to `True`.
446
+ mask_pad_size (`Dict[str, int]`, *optional*, defaults to `self.mask_pad_size`):
447
+ Controls the size of the padding applied to the segmentation map. The image is padded to
448
+ `mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`.
449
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
450
+ Whether to convert the image to RGB.
451
+ return_tensors (`str` or `TensorType`, *optional*):
452
+ The type of tensors to return. Can be one of:
453
+ - Unset: Return a list of `np.ndarray`.
454
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
455
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
456
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
457
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
458
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
459
+ The channel dimension format for the output image. Can be one of:
460
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
461
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
462
+ - Unset: Use the channel dimension format of the input image.
463
+ input_data_format (`ChannelDimension` or `str`, *optional*):
464
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
465
+ from the input image. Can be one of:
466
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
467
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
468
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
469
+ """
470
+ do_resize = do_resize if do_resize is not None else self.do_resize
471
+ size = size if size is not None else self.size
472
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
473
+ mask_size = mask_size if mask_size is not None else self.mask_size
474
+ mask_size = (
475
+ get_size_dict(max_size=mask_size, default_to_square=False)
476
+ if not isinstance(mask_size, dict)
477
+ else mask_size
478
+ )
479
+ resample = resample if resample is not None else self.resample
480
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
481
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
482
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
483
+ image_mean = image_mean if image_mean is not None else self.image_mean
484
+ image_std = image_std if image_std is not None else self.image_std
485
+ do_pad = do_pad if do_pad is not None else self.do_pad
486
+ pad_size = pad_size if pad_size is not None else self.pad_size
487
+ pad_size = get_size_dict(pad_size, default_to_square=True)
488
+ mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size
489
+ mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
490
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
491
+
492
+ images = make_list_of_images(images)
493
+
494
+ if not valid_images(images):
495
+ raise ValueError(
496
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
497
+ "torch.Tensor, tf.Tensor or jax.ndarray."
498
+ )
499
+
500
+ if segmentation_maps is not None:
501
+ segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
502
+
503
+ if not valid_images(segmentation_maps):
504
+ raise ValueError(
505
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
506
+ "torch.Tensor, tf.Tensor or jax.ndarray."
507
+ )
508
+ validate_preprocess_arguments(
509
+ do_rescale=do_rescale,
510
+ rescale_factor=rescale_factor,
511
+ do_normalize=do_normalize,
512
+ image_mean=image_mean,
513
+ image_std=image_std,
514
+ do_pad=do_pad,
515
+ size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size.
516
+ do_resize=do_resize,
517
+ size=size,
518
+ resample=resample,
519
+ )
520
+
521
+ images, original_sizes, reshaped_input_sizes = zip(
522
+ *(
523
+ self._preprocess_image(
524
+ image=img,
525
+ do_resize=do_resize,
526
+ size=size,
527
+ resample=resample,
528
+ do_rescale=do_rescale,
529
+ rescale_factor=rescale_factor,
530
+ do_normalize=do_normalize,
531
+ image_mean=image_mean,
532
+ image_std=image_std,
533
+ do_pad=do_pad,
534
+ pad_size=pad_size,
535
+ do_convert_rgb=do_convert_rgb,
536
+ data_format=data_format,
537
+ input_data_format=input_data_format,
538
+ )
539
+ for img in images
540
+ )
541
+ )
542
+
543
+ data = {
544
+ "pixel_values": images,
545
+ "original_sizes": original_sizes,
546
+ "reshaped_input_sizes": reshaped_input_sizes,
547
+ }
548
+
549
+ if segmentation_maps is not None:
550
+ segmentation_maps, original_mask_sizes = zip(
551
+ *(
552
+ self._preprocess_mask(
553
+ segmentation_map=mask,
554
+ do_resize=do_resize,
555
+ mask_size=mask_size,
556
+ do_pad=do_pad,
557
+ mask_pad_size=mask_pad_size,
558
+ input_data_format=input_data_format,
559
+ )
560
+ for mask in segmentation_maps
561
+ )
562
+ )
563
+
564
+ # masks should start out the same size as input images
565
+ assert all(
566
+ original_im_size == original_mask_size
567
+ for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes)
568
+ ), "Segmentation maps should be the same size as input images."
569
+
570
+ data["labels"] = segmentation_maps
571
+
572
+ return BatchFeature(data=data, tensor_type=return_tensors)
573
+
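+ 
+ # A minimal usage sketch for `preprocess` (assumes the `facebook/sam-vit-huge` checkpoint named
+ # elsewhere in this model's files; "dog.jpg" is a placeholder path):
+ #
+ #     from transformers import SamImageProcessor
+ #     from PIL import Image
+ #
+ #     processor = SamImageProcessor.from_pretrained("facebook/sam-vit-huge")
+ #     inputs = processor.preprocess(Image.open("dog.jpg"), return_tensors="pt")
+ #     # BatchFeature with "pixel_values", "original_sizes" and "reshaped_input_sizes"
+ #     # (plus "labels" when segmentation_maps are passed).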
574
+ def post_process_masks(
575
+ self,
576
+ masks,
577
+ original_sizes,
578
+ reshaped_input_sizes,
579
+ mask_threshold=0.0,
580
+ binarize=True,
581
+ pad_size=None,
582
+ return_tensors="pt",
583
+ ):
584
+ """
585
+ Remove padding and upscale masks to the original image size.
586
+
587
+ Args:
588
+ masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
589
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
590
+ original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
591
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
592
+ width) format.
593
+ reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
594
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
595
+ mask_threshold (`float`, *optional*, defaults to 0.0):
596
+ The threshold to use for binarizing the masks.
597
+ binarize (`bool`, *optional*, defaults to `True`):
598
+ Whether to binarize the masks.
599
+ pad_size (`int`, *optional*, defaults to `self.pad_size`):
600
+ The target size the images were padded to before being passed to the model. If None, the target size is
601
+ assumed to be the processor's `pad_size`.
602
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
603
+ If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
604
+ Returns:
605
+ (`Union[torch.Tensor, tf.Tensor]`): Batched masks in (batch_size, num_channels, height, width) format, where
606
+ (height, width) is given by original_size.
607
+ """
608
+ if return_tensors == "pt":
609
+ return self._post_process_masks_pt(
610
+ masks=masks,
611
+ original_sizes=original_sizes,
612
+ reshaped_input_sizes=reshaped_input_sizes,
613
+ mask_threshold=mask_threshold,
614
+ binarize=binarize,
615
+ pad_size=pad_size,
616
+ )
617
+ elif return_tensors == "tf":
618
+ return self._post_process_masks_tf(
619
+ masks=masks,
620
+ original_sizes=original_sizes,
621
+ reshaped_input_sizes=reshaped_input_sizes,
622
+ mask_threshold=mask_threshold,
623
+ binarize=binarize,
624
+ pad_size=pad_size,
625
+ )
626
+ else:
627
+ raise ValueError("return_tensors must be either 'pt' or 'tf'")
628
+
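+ 
+ # Sketch of the intended call pattern (assumes `model` is a `SamModel` and `inputs` was produced
+ # by `preprocess` with `return_tensors="pt"`; variable names are illustrative):
+ #
+ #     outputs = model(pixel_values=inputs["pixel_values"])
+ #     masks = processor.post_process_masks(
+ #         outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+ #     )
+ #     # each entry of `masks` is a binarized mask upscaled back to its image's original size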
629
+ def _post_process_masks_pt(
630
+ self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
631
+ ):
632
+ """
633
+ Remove padding and upscale masks to the original image size.
634
+
635
+ Args:
636
+ masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
637
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
638
+ original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
639
+ The original sizes of each image before it was resized to the model's expected input shape, in (height,
640
+ width) format.
641
+ reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
642
+ The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
643
+ mask_threshold (`float`, *optional*, defaults to 0.0):
644
+ The threshold to use for binarizing the masks.
645
+ binarize (`bool`, *optional*, defaults to `True`):
646
+ Whether to binarize the masks.
647
+ pad_size (`int`, *optional*, defaults to `self.pad_size`):
648
+ The target size the images were padded to before being passed to the model. If None, the target size is
649
+ assumed to be the processor's `pad_size`.
650
+ Returns:
651
+ (`torch.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width)
652
+ is given by original_size.
653
+ """
654
+ requires_backends(self, ["torch"])
655
+ pad_size = self.pad_size if pad_size is None else pad_size
656
+ target_image_size = (pad_size["height"], pad_size["width"])
657
+ if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
658
+ original_sizes = original_sizes.tolist()
659
+ if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
660
+ reshaped_input_sizes = reshaped_input_sizes.tolist()
661
+ output_masks = []
662
+ for i, original_size in enumerate(original_sizes):
663
+ if isinstance(masks[i], np.ndarray):
664
+ masks[i] = torch.from_numpy(masks[i])
665
+ elif not isinstance(masks[i], torch.Tensor):
666
+ raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
667
+ interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
668
+ interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
669
+ interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
670
+ if binarize:
671
+ interpolated_mask = interpolated_mask > mask_threshold
672
+ output_masks.append(interpolated_mask)
673
+
674
+ return output_masks
675
+
676
+ def _post_process_masks_tf(
677
+ self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
678
+ ):
679
+ """
680
+ Remove padding and upscale masks to the original image size.
681
+
682
+ Args:
683
+ masks (`tf.Tensor`):
684
+ Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
685
+ original_sizes (`tf.Tensor`):
686
+ The original size of the images before resizing for input to the model, in (height, width) format.
687
+ reshaped_input_sizes (`tf.Tensor`):
688
+ The size of the image input to the model, in (height, width) format. Used to remove padding.
689
+ mask_threshold (`float`, *optional*, defaults to 0.0):
690
+ The threshold to use for binarizing the masks.
691
+ binarize (`bool`, *optional*, defaults to `True`):
692
+ Whether to binarize the masks.
693
+ pad_size (`int`, *optional*, defaults to `self.pad_size`):
694
+ The target size the images were padded to before being passed to the model. If None, the target size is
695
+ assumed to be the processor's `pad_size`.
696
+ Returns:
697
+ (`tf.Tensor`): Batched masks in (batch_size, num_channels, height, width) format, where (height, width) is
698
+ given by original_size.
699
+ """
700
+ requires_backends(self, ["tf"])
701
+ pad_size = self.pad_size if pad_size is None else pad_size
702
+ target_image_size = (pad_size["height"], pad_size["width"])
703
+
704
+ output_masks = []
705
+ for i, original_size in enumerate(original_sizes):
706
+ # tf.image expects NHWC, we transpose the NCHW inputs for it
707
+ mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
708
+ interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
709
+ interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
710
+ interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
711
+ if binarize:
712
+ interpolated_mask = interpolated_mask > mask_threshold
713
+ # And then we transpose them back at the end
714
+ output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
715
+
716
+ return output_masks
717
+
718
+ def post_process_for_mask_generation(
719
+ self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
720
+ ):
721
+ """
722
+ Post-processes the masks generated by the model by running the Non Maximum Suppression algorithm on the predicted masks.
723
+
724
+ Args:
725
+ all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
726
+ List of all predicted segmentation masks
727
+ all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
728
+ List of all predicted iou scores
729
+ all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
730
+ List of all bounding boxes of the predicted masks
731
+ crops_nms_thresh (`float`):
732
+ Threshold for NMS (Non Maximum Suppression) algorithm.
733
+ return_tensors (`str`, *optional*, defaults to `pt`):
734
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
735
+ """
736
+ if return_tensors == "pt":
737
+ return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
738
+ elif return_tensors == "tf":
739
+ return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
740
+
741
+ def generate_crop_boxes(
742
+ self,
743
+ image,
744
+ target_size,
745
+ crop_n_layers: int = 0,
746
+ overlap_ratio: float = 512 / 1500,
747
+ points_per_crop: Optional[int] = 32,
748
+ crop_n_points_downscale_factor: Optional[List[int]] = 1,
749
+ device: Optional["torch.device"] = None,
750
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
751
+ return_tensors: str = "pt",
752
+ ):
753
+ """
754
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
755
+
756
+ Args:
757
+ image (`np.array`):
758
+ Input original image
759
+ target_size (`int`):
760
+ Target size of the resized image
761
+ crop_n_layers (`int`, *optional*, defaults to 0):
762
+ If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
763
+ the i-th layer produces (2**i)**2 image crops.
764
+ overlap_ratio (`float`, *optional*, defaults to 512/1500):
765
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
766
+ the image length. Later layers with more crops scale down this overlap.
767
+ points_per_crop (`int`, *optional*, defaults to 32):
768
+ Number of points to sample from each crop.
769
+ crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
770
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
771
+ device (`torch.device`, *optional*, defaults to None):
772
+ Device to use for the computation. If None, cpu will be used.
773
+ input_data_format (`str` or `ChannelDimension`, *optional*):
774
+ The channel dimension format of the input image. If not provided, it will be inferred.
775
+ return_tensors (`str`, *optional*, defaults to `pt`):
776
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
777
+ """
778
+ crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
779
+ image,
780
+ target_size,
781
+ crop_n_layers,
782
+ overlap_ratio,
783
+ points_per_crop,
784
+ crop_n_points_downscale_factor,
785
+ input_data_format,
786
+ )
787
+ if return_tensors == "pt":
788
+ if device is None:
789
+ device = torch.device("cpu")
790
+ crop_boxes = torch.tensor(crop_boxes, device=device)
791
+ points_per_crop = torch.tensor(points_per_crop, device=device)
792
+ # cropped_images stays as np
793
+ input_labels = torch.tensor(input_labels, device=device)
794
+
795
+ elif return_tensors == "tf":
796
+ if device is not None:
797
+ raise ValueError("device is not a supported argument when return_tensors is tf!")
798
+ crop_boxes = tf.convert_to_tensor(crop_boxes)
799
+ points_per_crop = tf.convert_to_tensor(points_per_crop)
800
+ # cropped_images stays as np
801
+ input_labels = tf.convert_to_tensor(input_labels)
802
+ else:
803
+ raise ValueError("return_tensors must be either 'pt' or 'tf'.")
804
+ return crop_boxes, points_per_crop, cropped_images, input_labels
805
+
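+ 
+ # Illustrative sketch (assumes a single `np.ndarray` image and the 1024 longest-side target used
+ # by the released SAM checkpoints):
+ #
+ #     crop_boxes, points_per_crop, cropped_images, input_labels = processor.generate_crop_boxes(
+ #         image, target_size=1024, crop_n_layers=1
+ #     )
+ #     # crop_boxes holds the full image plus a 2x2 grid of overlapping crops (5 boxes in total)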
806
+ def filter_masks(
807
+ self,
808
+ masks,
809
+ iou_scores,
810
+ original_size,
811
+ cropped_box_image,
812
+ pred_iou_thresh=0.88,
813
+ stability_score_thresh=0.95,
814
+ mask_threshold=0,
815
+ stability_score_offset=1,
816
+ return_tensors="pt",
817
+ ):
818
+ """
819
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
820
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
821
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
822
+ bounding boxes and pads the predicted masks if necessary.
823
+
824
+ Args:
825
+ masks (`Union[torch.Tensor, tf.Tensor]`):
826
+ Input masks.
827
+ iou_scores (`Union[torch.Tensor, tf.Tensor]`):
828
+ List of IoU scores.
829
+ original_size (`Tuple[int,int]`):
830
+ Size of the original image.
831
+ cropped_box_image (`np.array`):
832
+ The cropped image.
833
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
834
+ The threshold for the iou scores.
835
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
836
+ The threshold for the stability score.
837
+ mask_threshold (`float`, *optional*, defaults to 0):
838
+ The threshold for the predicted masks.
839
+ stability_score_offset (`float`, *optional*, defaults to 1):
840
+ The offset for the stability score used in the `_compute_stability_score` method.
841
+ return_tensors (`str`, *optional*, defaults to `pt`):
842
+ If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
843
+ """
844
+ if return_tensors == "pt":
845
+ return self._filter_masks_pt(
846
+ masks=masks,
847
+ iou_scores=iou_scores,
848
+ original_size=original_size,
849
+ cropped_box_image=cropped_box_image,
850
+ pred_iou_thresh=pred_iou_thresh,
851
+ stability_score_thresh=stability_score_thresh,
852
+ mask_threshold=mask_threshold,
853
+ stability_score_offset=stability_score_offset,
854
+ )
855
+ elif return_tensors == "tf":
856
+ return self._filter_masks_tf(
857
+ masks=masks,
858
+ iou_scores=iou_scores,
859
+ original_size=original_size,
860
+ cropped_box_image=cropped_box_image,
861
+ pred_iou_thresh=pred_iou_thresh,
862
+ stability_score_thresh=stability_score_thresh,
863
+ mask_threshold=mask_threshold,
864
+ stability_score_offset=stability_score_offset,
865
+ )
866
+
867
+ def _filter_masks_pt(
868
+ self,
869
+ masks,
870
+ iou_scores,
871
+ original_size,
872
+ cropped_box_image,
873
+ pred_iou_thresh=0.88,
874
+ stability_score_thresh=0.95,
875
+ mask_threshold=0,
876
+ stability_score_offset=1,
877
+ ):
878
+ """
879
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
880
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
881
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
882
+ bounding boxes and pads the predicted masks if necessary.
883
+
884
+ Args:
885
+ masks (`torch.Tensor`):
886
+ Input masks.
887
+ iou_scores (`torch.Tensor`):
888
+ List of IoU scores.
889
+ original_size (`Tuple[int,int]`):
890
+ Size of the original image.
891
+ cropped_box_image (`np.array`):
892
+ The cropped image.
893
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
894
+ The threshold for the iou scores.
895
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
896
+ The threshold for the stability score.
897
+ mask_threshold (`float`, *optional*, defaults to 0):
898
+ The threshold for the predicted masks.
899
+ stability_score_offset (`float`, *optional*, defaults to 1):
900
+ The offset for the stability score used in the `_compute_stability_score` method.
901
+
902
+ """
903
+ requires_backends(self, ["torch"])
904
+ original_height, original_width = original_size
905
+ iou_scores = iou_scores.flatten(0, 1)
906
+ masks = masks.flatten(0, 1)
907
+
908
+ if masks.shape[0] != iou_scores.shape[0]:
909
+ raise ValueError("masks and iou_scores must have the same batch size.")
910
+
911
+ if masks.device != iou_scores.device:
912
+ iou_scores = iou_scores.to(masks.device)
913
+
914
+ batch_size = masks.shape[0]
915
+
916
+ keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
917
+
918
+ if pred_iou_thresh > 0.0:
919
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
920
+
921
+ # compute stability score
922
+ if stability_score_thresh > 0.0:
923
+ stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
924
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
925
+
926
+ scores = iou_scores[keep_mask]
927
+ masks = masks[keep_mask]
928
+
929
+ # binarize masks
930
+ masks = masks > mask_threshold
931
+ converted_boxes = _batched_mask_to_box(masks)
932
+
933
+ keep_mask = ~_is_box_near_crop_edge(
934
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
935
+ )
936
+
937
+ scores = scores[keep_mask]
938
+ masks = masks[keep_mask]
939
+ converted_boxes = converted_boxes[keep_mask]
940
+
941
+ masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
942
+ # conversion to rle is necessary to run non-maximum suppression
943
+ masks = _mask_to_rle_pytorch(masks)
944
+
945
+ return masks, scores, converted_boxes
946
+
947
+ def _filter_masks_tf(
948
+ self,
949
+ masks,
950
+ iou_scores,
951
+ original_size,
952
+ cropped_box_image,
953
+ pred_iou_thresh=0.88,
954
+ stability_score_thresh=0.95,
955
+ mask_threshold=0,
956
+ stability_score_offset=1,
957
+ ):
958
+ """
959
+ Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
960
+ that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
961
+ score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
962
+ bounding boxes and pads the predicted masks if necessary.
963
+
964
+ Args:
965
+ masks (`tf.Tensor`):
966
+ Input masks.
967
+ iou_scores (`tf.Tensor`):
968
+ List of IoU scores.
969
+ original_size (`Tuple[int,int]`):
970
+ Size of the original image.
971
+ cropped_box_image (`np.array`):
972
+ The cropped image.
973
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
974
+ The threshold for the iou scores.
975
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
976
+ The threshold for the stability score.
977
+ mask_threshold (`float`, *optional*, defaults to 0):
978
+ The threshold for the predicted masks.
979
+ stability_score_offset (`float`, *optional*, defaults to 1):
980
+ The offset for the stability score used in the `_compute_stability_score` method.
981
+
982
+ """
983
+ requires_backends(self, ["tf"])
984
+ original_height, original_width = original_size
985
+ # Merge the batch and point-batch dimensions, mirroring `flatten(0, 1)` in the PyTorch path;
+ # the shape argument must be a flat list, so the trailing dimensions are unpacked with `*`.
+ iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], *iou_scores.shape[2:]])
986
+ masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], *masks.shape[2:]])
987
+
988
+ if masks.shape[0] != iou_scores.shape[0]:
989
+ raise ValueError("masks and iou_scores must have the same batch size.")
990
+
991
+ batch_size = masks.shape[0]
992
+
993
+ keep_mask = tf.ones(batch_size, dtype=tf.bool)
994
+
995
+ if pred_iou_thresh > 0.0:
996
+ keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
997
+
998
+ # compute stability score
999
+ if stability_score_thresh > 0.0:
1000
+ stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
1001
+ keep_mask = keep_mask & (stability_scores > stability_score_thresh)
1002
+
1003
+ scores = tf.boolean_mask(iou_scores, keep_mask)  # TF tensors do not support boolean fancy indexing
1004
+ masks = tf.boolean_mask(masks, keep_mask)
1005
+
1006
+ # binarize masks
1007
+ masks = masks > mask_threshold
1008
+ converted_boxes = _batched_mask_to_box_tf(masks)
1009
+
1010
+ keep_mask = ~_is_box_near_crop_edge_tf(
1011
+ converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
1012
+ )
1013
+
1014
+ scores = tf.boolean_mask(scores, keep_mask)
1015
+ masks = tf.boolean_mask(masks, keep_mask)
1016
+ converted_boxes = tf.boolean_mask(converted_boxes, keep_mask)
1017
+
1018
+ masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
1019
+ # conversion to rle is necessary to run non-maximum suppression
1020
+ masks = _mask_to_rle_tf(masks)
1021
+
1022
+ return masks, scores, converted_boxes
1023
+
1024
+
1025
+ def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
1026
+ # One mask is always contained inside the other.
1027
+ # Save memory by preventing unnecessary cast to torch.int64
1028
+ intersections = (
1029
+ (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
1030
+ )
1031
+ unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
1032
+ stability_scores = intersections / unions
1033
+ return stability_scores
1034
+
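+ 
+ # Worked example (hand-computed, for illustration): the stability score is the IoU between the
+ # mask binarized at a high threshold and at a low threshold.
+ #
+ #     logits = torch.tensor([[[2.5, 0.5], [-0.5, -2.5]]])
+ #     _compute_stability_score_pt(logits, mask_threshold=0.0, stability_score_offset=1)
+ #     # intersection: pixels > 1.0 -> 1; union: pixels > -1.0 -> 3; score = 1/3 ≈ 0.33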
1035
+
1036
+ def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
1037
+ # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
1038
+ # we get the right division results.
1039
+ intersections = tf.math.count_nonzero(
1040
+ masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
1041
+ )
1042
+ unions = tf.math.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
1043
+ stability_scores = intersections / unions
1044
+ return stability_scores
1045
+
1046
+
1047
+ def _build_point_grid(n_per_side: int) -> np.ndarray:
1048
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
1049
+ offset = 1 / (2 * n_per_side)
1050
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
1051
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
1052
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
1053
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
1054
+ return points
1055
+
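+ 
+ # Worked example: for n_per_side=2 the offset is 1 / (2 * 2) = 0.25, so
+ #
+ #     _build_point_grid(2)
+ #     # array([[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]])
+ #
+ # i.e. the points keep a half-cell margin from the borders of [0, 1] x [0, 1].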
1056
+
1057
+ def _normalize_coordinates(
1058
+ target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
1059
+ ) -> np.ndarray:
1060
+ """
1061
+ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
1062
+ format.
1063
+ """
1064
+ old_height, old_width = original_size
1065
+
1066
+ scale = target_size * 1.0 / max(old_height, old_width)
1067
+ new_height, new_width = old_height * scale, old_width * scale
1068
+ new_width = int(new_width + 0.5)
1069
+ new_height = int(new_height + 0.5)
1070
+
1071
+ coords = deepcopy(coords).astype(float)
1072
+
1073
+ if is_bounding_box:
1074
+ coords = coords.reshape(-1, 2, 2)
1075
+
1076
+ coords[..., 0] = coords[..., 0] * (new_width / old_width)
1077
+ coords[..., 1] = coords[..., 1] * (new_height / old_height)
1078
+
1079
+ if is_bounding_box:
1080
+ coords = coords.reshape(-1, 4)
1081
+
1082
+ return coords
1083
+
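+ 
+ # Worked example (hand-computed): for an (height, width) = (600, 800) image and target_size=1024,
+ # the scale is 1024 / 800 = 1.28 and the resized size is (768, 1024), so
+ #
+ #     _normalize_coordinates(1024, np.array([[400.0, 300.0]]), original_size=(600, 800))
+ #     # array([[512., 384.]])   # x * (1024 / 800), y * (768 / 600)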
1084
+
1085
+ def _generate_crop_boxes(
1086
+ image,
1087
+ target_size: int,  # longest edge of the resized image
1088
+ crop_n_layers: int = 0,
1089
+ overlap_ratio: float = 512 / 1500,
1090
+ points_per_crop: Optional[int] = 32,
1091
+ crop_n_points_downscale_factor: Optional[List[int]] = 1,
1092
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1093
+ ) -> Tuple[List[List[int]], List[int]]:
1094
+ """
1095
+ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
1096
+
1097
+ Args:
1098
+ image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
1099
+ Image to generate crops for.
1100
+ target_size (`int`):
1101
+ Size of the smallest crop.
1102
+ crop_n_layers (`int`, *optional*):
1103
+ If `crop_n_layers > 0`, mask prediction will be run again on crops of the image. Sets the number of layers
1104
+ to run, where the i-th layer produces (2**i)**2 image crops.
1105
+ overlap_ratio (`int`, *optional*):
1106
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
1107
+ image length. Later layers with more crops scale down this overlap.
1108
+ points_per_crop (`int`, *optional*):
1109
+ Number of points to sample per crop.
1110
+ crop_n_points_downscale_factor (`int`, *optional*):
1111
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
1112
+ input_data_format (`str` or `ChannelDimension`, *optional*):
1113
+ The channel dimension format of the input image. If not provided, it will be inferred.
1114
+ """
1115
+
1116
+ if isinstance(image, list):
1117
+ raise ValueError("Only one image is allowed for crop generation.")
1118
+ image = to_numpy_array(image)
1119
+ original_size = get_image_size(image, input_data_format)
1120
+
1121
+ points_grid = []
1122
+ for i in range(crop_n_layers + 1):
1123
+ n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
1124
+ points_grid.append(_build_point_grid(n_points))
1125
+
1126
+ crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
1127
+
1128
+ cropped_images, point_grid_per_crop = _generate_crop_images(
1129
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
1130
+ )
1131
+ crop_boxes = np.array(crop_boxes)
1132
+ crop_boxes = crop_boxes.astype(np.float32)
1133
+ points_per_crop = np.array([point_grid_per_crop])
1134
+ points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
1135
+
1136
+ input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
1137
+
1138
+ return crop_boxes, points_per_crop, cropped_images, input_labels
1139
+
1140
+
1141
+ def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
1142
+ """
1143
+ Generates (2**(layer_idx + 1))**2 crops for each of the `crop_n_layers` layers, plus the full image. Each crop
1144
+ box is stored as [left, top, right, bottom] coordinates, the XYXY convention later unpacked by `_pad_masks`:
1145
+ - LEFT: X coordinate of the top left of the bounding box
1146
+ - TOP: Y coordinate of the top left of the bounding box
1147
+ - RIGHT: X coordinate of the bottom right of the bounding box, clipped to the image width
1148
+ - BOTTOM: Y coordinate of the bottom right of the bounding box, clipped to the image height
1149
+ """
1150
+ crop_boxes, layer_idxs = [], []
1151
+ im_height, im_width = original_size
1152
+ short_side = min(im_height, im_width)
1153
+
1154
+ # Original image
1155
+ crop_boxes.append([0, 0, im_width, im_height])
1156
+ layer_idxs.append(0)
1157
+ for i_layer in range(crop_n_layers):
1158
+ n_crops_per_side = 2 ** (i_layer + 1)
1159
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
1160
+
1161
+ crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
1162
+ crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
1163
+
1164
+ crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
1165
+ crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
1166
+
1167
+ for left, top in product(crop_box_x0, crop_box_y0):
1168
+ box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
1169
+ crop_boxes.append(box)
1170
+ layer_idxs.append(i_layer + 1)
1171
+
1172
+ return crop_boxes, layer_idxs
1173
+
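+ 
+ # Worked example (hand-computed with the default overlap_ratio of 512 / 1500): for a 100x100
+ # image and crop_n_layers=1, overlap = int(0.341... * 100) = 34, crop size = ceil(134 / 2) = 67,
+ # and the crop origins are [0, 33] on each axis, giving the full image plus a 2x2 grid:
+ #
+ #     _generate_per_layer_crops(1, 512 / 1500, (100, 100))
+ #     # ([[0, 0, 100, 100], [0, 0, 67, 67], [0, 33, 67, 100], [33, 0, 100, 67], [33, 33, 100, 100]],
+ #     #  [0, 1, 1, 1, 1])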
1174
+
1175
+ def _generate_crop_images(
1176
+ crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
1177
+ ):
1178
+ """
1179
+ Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points
1180
+ are also computed and returned.
1181
+ """
1182
+ cropped_images = []
1183
+ total_points_per_crop = []
1184
+ for i, crop_box in enumerate(crop_boxes):
1185
+ left, top, right, bottom = crop_box
1186
+
1187
+ channel_dim = infer_channel_dimension_format(image, input_data_format)
1188
+ if channel_dim == ChannelDimension.LAST:
1189
+ cropped_im = image[top:bottom, left:right, :]
1190
+ else:
1191
+ cropped_im = image[:, top:bottom, left:right]
1192
+
1193
+ cropped_images.append(cropped_im)
1194
+
1195
+ cropped_im_size = get_image_size(cropped_im, channel_dim)
1196
+ points_scale = np.array(cropped_im_size)[None, ::-1]
1197
+
1198
+ points = points_grid[layer_idxs[i]] * points_scale
1199
+ normalized_points = _normalize_coordinates(target_size, points, original_size)
1200
+ total_points_per_crop.append(normalized_points)
1201
+
1202
+ return cropped_images, total_points_per_crop
1203
+
1204
+
1205
+ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
1206
+ left, top, right, bottom = crop_box
1207
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
1208
+ return masks
1209
+ # Coordinate transform masks
1210
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
1211
+ pad = (left, pad_x - left, top, pad_y - top)
1212
+ return torch.nn.functional.pad(masks, pad, value=0)
1213
+
1214
+
1215
+ def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
1216
+ left, top, right, bottom = crop_box
1217
+ if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
1218
+ return masks
1219
+ # Coordinate transform masks
1220
+ pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
1221
+ pad = [[0, 0]] * (len(masks.shape) - 2) + [[top, pad_y - top], [left, pad_x - left]]  # tf.pad expects per-axis [before, after] pairs
1222
+ return tf.pad(masks, pad, constant_values=0)
1223
+
1224
+
1225
+ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
1226
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
1227
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
1228
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
1229
+
1230
+ left, top, _, _ = crop_box
1231
+ offset = torch.tensor([[left, top, left, top]], device=boxes.device)
1232
+ # Check if boxes has a channel dimension
1233
+ if len(boxes.shape) == 3:
1234
+ offset = offset.unsqueeze(1)
1235
+ boxes = (boxes + offset).float()
1236
+
1237
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
1238
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
1239
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
1240
+ return torch.any(near_crop_edge, dim=1)
1241
+
1242
+
1243
+ def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
1244
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
1245
+ crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
1246
+ orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
1247
+
1248
+ left, top, _, _ = crop_box
1249
+ offset = tf.convert_to_tensor([[left, top, left, top]])
1250
+ # Check if boxes has a channel dimension
1251
+ if len(boxes.shape) == 3:
1252
+ offset = tf.expand_dims(offset, 1)
1253
+ boxes = tf.cast(boxes + offset, tf.float32)
1254
+
1255
+ near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
1256
+ near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
1257
+ near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
1258
+ return tf.reduce_any(near_crop_edge, axis=1)
1259
+
1260
+
1261
+ def _batched_mask_to_box(masks: "torch.Tensor"):
1262
+ """
1263
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
1264
+ corresponds to the following required indices:
1265
+ - LEFT: left hand side of the bounding box
1266
+ - TOP: top of the bounding box
1267
+ - RIGHT: right of the bounding box
1268
+ - BOTTOM: bottom of the bounding box
1269
+
1270
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
1271
+ is channel_1 x channel_2 x ... x 4.
1272
+
1273
+ Args:
1274
+ - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
1275
+ """
1276
+ # torch.max below raises an error on empty inputs, just skip in this case
1277
+
1278
+ if torch.numel(masks) == 0:
1279
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
1280
+
1281
+ # Normalize shape to Cxheightxwidth
1282
+ shape = masks.shape
1283
+ height, width = shape[-2:]
1284
+
1285
+ # Get top and bottom edges
1286
+ in_height, _ = torch.max(masks, dim=-1)
1287
+ in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
1288
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
1289
+ in_height_coords = in_height_coords + height * (~in_height)
1290
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
1291
+
1292
+ # Get left and right edges
1293
+ in_width, _ = torch.max(masks, dim=-2)
1294
+ in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
1295
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
1296
+ in_width_coords = in_width_coords + width * (~in_width)
1297
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
1298
+
1299
+ # If the mask is empty the right edge will be to the left of the left edge.
1300
+ # Replace these boxes with [0, 0, 0, 0]
1301
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
1302
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
1303
+ out = out * (~empty_filter).unsqueeze(-1)
1304
+
1305
+ # Return to original shape
1306
+ out = out.reshape(*shape[:-2], 4)
1307
+ return out
1308
+
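+ 
+ # Worked example (hand-computed): a single 3x3 boolean mask whose True pixels cover rows 1-2 and
+ # columns 0-1 yields the box [left, top, right, bottom] = [0, 1, 1, 2]:
+ #
+ #     m = torch.zeros(1, 1, 3, 3, dtype=torch.bool)
+ #     m[..., 1:3, 0:2] = True
+ #     _batched_mask_to_box(m)
+ #     # tensor([[[0, 1, 1, 2]]])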
1309
+
1310
+ def _batched_mask_to_box_tf(masks: "tf.Tensor"):
1311
+ """
1312
+ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
1313
+ corresponds to the following required indices:
1314
+ - LEFT: left hand side of the bounding box
1315
+ - TOP: top of the bounding box
1316
+ - RIGHT: right of the bounding box
1317
+ - BOTTOM: bottom of the bounding box
1318
+
1319
+ Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
1320
+ is channel_1 x channel_2 x ... x 4.
1321
+
1322
+ Args:
1323
+ - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
1324
+ """
1325
+
1326
+ if tf.size(masks) == 0:
1327
+ return tf.zeros([*masks.shape[:-2], 4])
1328
+
1329
+ # Normalize shape to Cxheightxwidth
1330
+ shape = shape_list(masks)
1331
+ height, width = shape[-2:]
1332
+
1333
+ # Get top and bottom edges
1334
+ in_height = tf.reduce_max(masks, axis=-1)
1335
+ in_height_coords = in_height * tf.range(height)[None, :]
1336
+ bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
1337
+ in_height_coords = in_height_coords + height * (~in_height)
1338
+ top_edges = tf.reduce_min(in_height_coords, axis=-1)
1339
+
1340
+ # Get left and right edges
1341
+ in_width = tf.reduce_max(masks, axis=-2)  # tf.reduce_max returns a single tensor, unlike torch.max
1342
+ in_width_coords = in_width * tf.range(width)[None, :]
1343
+ right_edges = tf.reduce_max(in_width_coords, axis=-1)
1344
+ in_width_coords = in_width_coords + width * (~in_width)
1345
+ left_edges = tf.reduce_min(in_width_coords, axis=-1)
1346
+
1347
+ # If the mask is empty the right edge will be to the left of the left edge.
1348
+ # Replace these boxes with [0, 0, 0, 0]
1349
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
1350
+ out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
1351
+ out = out * tf.expand_dims(~empty_filter, -1)
1352
+
1353
+ # Return to original shape
1354
+ out = tf.reshape(out, [*shape[:-2], 4])  # tf.reshape takes the target shape as a single list
1355
+ return out
1356
+
1357
+
1358
+ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
1359
+ """
1360
+ Encodes masks to run-length encoding (RLE), in the format expected by pycocotools.
1361
+ """
1362
+ # Put in fortran order and flatten height and width
1363
+ batch_size, height, width = input_mask.shape
1364
+ input_mask = input_mask.permute(0, 2, 1).flatten(1)
1365
+
1366
+ # Compute change indices
1367
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
1368
+ change_indices = diff.nonzero()
1369
+
1370
+ # Encode run length
1371
+ out = []
1372
+ for i in range(batch_size):
1373
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
1374
+ if len(cur_idxs) == 0:
1375
+ # No changes => either all 0 or all 1
1376
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
1377
+ if input_mask[i, 0] == 0:
1378
+ out.append({"size": [height, width], "counts": [height * width]})
1379
+ else:
1380
+ out.append({"size": [height, width], "counts": [0, height * width]})
1381
+ continue
1382
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
1383
+ counts = [] if input_mask[i, 0] == 0 else [0]
1384
+ counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
1385
+ out.append({"size": [height, width], "counts": counts})
1386
+ return out
1387
+
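+ 
+ # Worked example (hand-computed): the mask is flattened in Fortran (column-major) order before
+ # encoding, and `counts` always starts with the run of zeros (possibly of length 0):
+ #
+ #     m = torch.tensor([[[1, 0], [1, 0]]], dtype=torch.bool)  # first column all ones
+ #     _mask_to_rle_pytorch(m)
+ #     # [{"size": [2, 2], "counts": [0, 2, 2]}]  # 0 zeros, then 2 ones, then 2 zeros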
1388
+
1389
+ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
1390
+ """
1391
+ Encodes masks to run-length encoding (RLE), in the format expected by pycocotools.
1392
+ """
1393
+ # Put in fortran order and flatten height and width
1394
+ batch_size, height, width = input_mask.shape
1395
+ input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
1396
+
1397
+ # Compute change indices
1398
+ diff = input_mask[:, 1:] ^ input_mask[:, :-1]
1399
+ change_indices = tf.where(diff)
1400
+
1401
+ # Encode run length
1402
+ out = []
1403
+ for i in range(batch_size):
1404
+ cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
1405
+ if len(cur_idxs) == 0:
1406
+ # No changes => either all 0 or all 1
1407
+ # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
1408
+ if input_mask[i, 0] == 0:
1409
+ out.append({"size": [height, width], "counts": [height * width]})
1410
+ else:
1411
+ out.append({"size": [height, width], "counts": [0, height * width]})
1412
+ continue
1413
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
1414
+ counts = [] if input_mask[i, 0] == 0 else [0]
1415
+ counts += (
1416
+ [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
1417
+ )
1418
+ out.append({"size": [height, width], "counts": counts})
1419
+ return out
1420
+
1421
+
1422
+ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
1423
+ """Compute a binary mask from an uncompressed RLE."""
1424
+ height, width = rle["size"]
1425
+ mask = np.empty(height * width, dtype=bool)
1426
+ idx = 0
1427
+ parity = False
1428
+ for count in rle["counts"]:
1429
+ mask[idx : idx + count] = parity
1430
+ idx += count
1431
+ parity = not parity
1432
+ mask = mask.reshape(width, height)
1433
+ return mask.transpose() # Reshape to original shape
1434
+
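+ 
+ # Worked example: decoding the RLE produced in the `_mask_to_rle_pytorch` example above
+ # reconstructs the same column-major runs:
+ #
+ #     _rle_to_mask({"size": [2, 2], "counts": [0, 2, 2]})
+ #     # array([[ True, False],
+ #     #        [ True, False]])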
1435
+
1436
+ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
1437
+ """
1438
+ Perform NMS (Non Maximum Suppression) on the outputs.
1439
+
1440
+ Args:
1441
+ rle_masks (`List[dict]`):
1442
+ binary masks in the RLE format
1443
+ iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
1444
+ iou_scores predicted by the model
1445
+ mask_boxes (`torch.Tensor`):
1446
+ The bounding boxes corresponding to segmentation masks
1447
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
1448
+ NMS threshold.
1449
+ """
1450
+ keep_by_nms = batched_nms(
1451
+ boxes=mask_boxes.float(),
1452
+ scores=iou_scores,
1453
+ idxs=torch.zeros(mask_boxes.shape[0]),
1454
+ iou_threshold=amg_crops_nms_thresh,
1455
+ )
1456
+
1457
+ iou_scores = iou_scores[keep_by_nms]
1458
+ rle_masks = [rle_masks[i] for i in keep_by_nms]
1459
+ mask_boxes = mask_boxes[keep_by_nms]
1460
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
1461
+
1462
+ return masks, iou_scores, rle_masks, mask_boxes
1463
+
1464
+
1465
+ def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
1466
+ """
1467
+ Perform NMS (Non Maximum Suppression) on the outputs.
1468
+
1469
+ Args:
1470
+ rle_masks (`List[dict]`):
1471
+ binary masks in the RLE format
1472
+ iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
1473
+ iou_scores predicted by the model
1474
+ mask_boxes (`tf.Tensor`):
1475
+ The bounding boxes corresponding to segmentation masks
1476
+ amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
1477
+ NMS threshold.
1478
+ """
1479
+ # All masks here belong to a single class, so plain `tf.image.non_max_suppression` is the
+ # TensorFlow counterpart of the `batched_nms` call used in the PyTorch path above.
+ keep_by_nms = tf.image.non_max_suppression(
+ boxes=tf.cast(mask_boxes, tf.float32),
+ scores=tf.reshape(iou_scores, [-1]),
+ max_output_size=tf.shape(mask_boxes)[0],
+ iou_threshold=amg_crops_nms_thresh,
+ )
+ 
+ iou_scores = tf.gather(iou_scores, keep_by_nms)
+ rle_masks = [rle_masks[int(i)] for i in keep_by_nms]
+ mask_boxes = tf.gather(mask_boxes, keep_by_nms)
1489
+ masks = [_rle_to_mask(rle) for rle in rle_masks]
1490
+
1491
+ return masks, iou_scores, rle_masks, mask_boxes
1492
+
1493
+
1494
+ __all__ = ["SamImageProcessor"]
docs/transformers/build/lib/transformers/models/sam/modeling_sam.py ADDED
@@ -0,0 +1,1579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch SAM model."""
16
+
17
+ import collections
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import Tensor, nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...modeling_outputs import BaseModelOutput
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import (
31
+ ModelOutput,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ can_return_tuple,
35
+ logging,
36
+ replace_return_docstrings,
37
+ )
38
+ from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CONFIG_FOR_DOC = "SamConfig"
44
+ _CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
45
+
46
+
47
+ @dataclass
48
+ class SamVisionEncoderOutput(ModelOutput):
49
+ """
50
+ Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the projection
51
+ layer to the pooler_output.
52
+
53
+ Args:
54
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
55
+ The image embeddings obtained by applying the projection layer to the pooler_output.
56
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
57
+ Sequence of hidden-states at the output of the last layer of the model.
58
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
59
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
60
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
61
+
62
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
63
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
64
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
65
+ sequence_length)`.
66
+
67
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
68
+ heads.
69
+ """
70
+
71
+ image_embeds: Optional[torch.FloatTensor] = None
72
+ last_hidden_state: Optional[torch.FloatTensor] = None
73
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
74
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
75
+
76
+
77
+ @dataclass
78
+ class SamImageSegmentationOutput(ModelOutput):
79
+ """
80
+ Base class for the Segment-Anything model's output.
81
+
82
+ Args:
83
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
84
+ The iou scores of the predicted masks.
85
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
86
+ The predicted low-resolution masks. They need to be post-processed by the processor.
87
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90
+
91
+ Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
92
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94
+ sequence_length)`.
95
+
96
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
97
+ heads.
98
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
99
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
100
+ sequence_length)`.
101
+
102
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
103
+ heads.
104
+ """
105
+
106
+ iou_scores: Optional[torch.FloatTensor] = None
107
+ pred_masks: Optional[torch.FloatTensor] = None
108
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
109
+ vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
110
+ mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
111
+
112
+
113
+ class SamPatchEmbeddings(nn.Module):
114
+ """
115
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
116
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
117
+ Transformer.
118
+ """
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ image_size, patch_size = config.image_size, config.patch_size
123
+ num_channels, hidden_size = config.num_channels, config.hidden_size
124
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
125
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
126
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
127
+ self.image_size = image_size
128
+ self.patch_size = patch_size
129
+ self.num_channels = num_channels
130
+ self.num_patches = num_patches
131
+
132
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
133
+
134
+ def forward(self, pixel_values):
135
+ batch_size, num_channels, height, width = pixel_values.shape
136
+ if num_channels != self.num_channels:
137
+ raise ValueError(
138
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
139
+ )
140
+ if height != self.image_size[0] or width != self.image_size[1]:
141
+ raise ValueError(
142
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
143
+ )
144
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
145
+ return embeddings
146
+
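+ 
+ # Shape sketch (illustrative values matching the released SAM vision configs: 1024x1024 inputs,
+ # 16x16 patches):
+ #
+ #     pixel_values:                  (batch, 3, 1024, 1024)
+ #     self.projection(pixel_values): (batch, hidden_size, 64, 64)
+ #     returned embeddings:           (batch, 64, 64, hidden_size)
+ #
+ # i.e. the 64x64 grid of patch tokens is kept in its 2D layout rather than flattened.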
147
+
148
+ class SamMLPBlock(nn.Module):
149
+ def __init__(self, config):
150
+ super().__init__()
151
+ self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
152
+ self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
153
+ self.act = ACT2FN[config.hidden_act]
154
+
155
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
156
+ hidden_states = self.lin1(hidden_states)
157
+ hidden_states = self.act(hidden_states)
158
+ hidden_states = self.lin2(hidden_states)
159
+ return hidden_states
160
+
161
+
162
+ # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
163
+ class SamLayerNorm(nn.Module):
164
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
165
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
166
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
167
+ """
168
+
169
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
170
+ super().__init__()
171
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
172
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
173
+ self.eps = eps
174
+ self.data_format = data_format
175
+ if self.data_format not in ["channels_last", "channels_first"]:
176
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
177
+ self.normalized_shape = (normalized_shape,)
178
+
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ if self.data_format == "channels_last":
181
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
182
+ elif self.data_format == "channels_first":
183
+ input_dtype = x.dtype
184
+ x = x.float()
185
+ u = x.mean(1, keepdim=True)
186
+ s = (x - u).pow(2).mean(1, keepdim=True)
187
+ x = (x - u) / torch.sqrt(s + self.eps)
188
+ x = x.to(dtype=input_dtype)
189
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
190
+ return x
191
+
192
+
193
+ class SamAttention(nn.Module):
194
+ """
195
+ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
196
+ values.
197
+ """
198
+
199
+ def __init__(self, config, downsample_rate=None):
200
+ super().__init__()
201
+ self.hidden_size = config.hidden_size
202
+
203
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
204
+
205
+ self.internal_dim = config.hidden_size // downsample_rate
206
+ self.num_attention_heads = config.num_attention_heads
207
+ if self.internal_dim % config.num_attention_heads != 0:
208
+ raise ValueError("num_attention_heads must divide hidden_size.")
209
+
210
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
211
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
212
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
213
+ self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
214
+
215
+ def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
216
+ batch, point_batch_size, n_tokens, channel = hidden_states.shape
217
+ c_per_head = channel // num_attention_heads
218
+ hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
219
+ return hidden_states.transpose(1, 2)
220
+
221
+ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
222
+ batch, n_heads, n_tokens, c_per_head = hidden_states.shape
223
+ hidden_states = hidden_states.transpose(1, 2)
224
+ return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
225
+
226
+ def forward(
227
+ self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
228
+ ) -> Tensor:
229
+ # Input projections
230
+ query = self.q_proj(query)
231
+ key = self.k_proj(key)
232
+ value = self.v_proj(value)
233
+
234
+ point_batch_size = query.shape[1]
235
+ # Separate into heads
236
+ query = self._separate_heads(query, self.num_attention_heads)
237
+ key = self._separate_heads(key, self.num_attention_heads)
238
+ value = self._separate_heads(value, self.num_attention_heads)
239
+
240
+ # SamAttention
241
+ _, _, _, c_per_head = query.shape
242
+ attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
243
+ attn = attn / (c_per_head**0.5)
244
+ attn = torch.softmax(attn, dim=-1)
245
+
246
+ if attention_similarity is not None:
247
+ attn = attn + attention_similarity
248
+ attn = torch.softmax(attn, dim=-1)
249
+
250
+ # Get output
251
+ out = attn @ value
252
+ out = self._recombine_heads(out, point_batch_size)
253
+ out = self.out_proj(out)
254
+
255
+ return out
256
+
257
+
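A standalone shape walk-through of the head split and merge above (dimensions are assumed for illustration): `SamAttention` works on 4D prompt tensors `(batch, point_batch, n_tokens, channel)` and folds the point batch into the batch axis while heads are split out.

```python
import torch

batch, point_batch, n_tokens, channel, heads = 2, 3, 5, 8, 4
x = torch.randn(batch, point_batch, n_tokens, channel)

# _separate_heads: (B, P, N, C) -> (B*P, heads, N, C // heads)
separated = x.reshape(batch * point_batch, n_tokens, heads, channel // heads).transpose(1, 2)
print(separated.shape)  # torch.Size([6, 4, 5, 2])

# _recombine_heads: inverse trip back to (B, P, N, C)
recombined = separated.transpose(1, 2).reshape(batch, point_batch, n_tokens, channel)
print(torch.equal(x, recombined))  # True
```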
+ class SamSdpaAttention(SamAttention):
+     """
+     SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+     values. Using SDPA instead of the default attention.
+     """
+
+     def __init__(self, config, downsample_rate=None):
+         super().__init__(config, downsample_rate)
+
+     def forward(
+         self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
+     ) -> Tensor:
+         # Input projections
+         query = self.q_proj(query)
+         key = self.k_proj(key)
+         value = self.v_proj(value)
+
+         point_batch_size = query.shape[1]
+         # Separate into heads
+         query = self._separate_heads(query, self.num_attention_heads)
+         key = self._separate_heads(key, self.num_attention_heads)
+         value = self._separate_heads(value, self.num_attention_heads)
+
+         # Scaled dot product attention
+         attn_mask = None
+         if attention_similarity is not None:
+             attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
+
+         out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
+
+         # Get output
+         out = self._recombine_heads(out, point_batch_size)
+         out = self.out_proj(out)
+
+         return out
+
+
+ SAM_ATTENTION_CLASSES = {
+     "eager": SamAttention,
+     "sdpa": SamSdpaAttention,
+ }
+
+
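The two-way transformer below picks its attention class from this table via `config._attn_implementation`. How that is typically exercised from user code (standard `transformers` loading API, shown here as an assumed usage rather than part of the diff):

```python
from transformers import SamModel

# "sdpa" routes the mask-decoder attention through SamSdpaAttention above;
# "eager" keeps the hand-written softmax attention.
model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")
```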
+ class SamTwoWayAttentionBlock(nn.Module):
+     def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
+         """
+         A transformer block with four layers:
+         (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+         sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+         Arguments:
+             config (`SamMaskDecoderConfig`):
+                 The configuration file used to instantiate the block
+             attention_downsample_rate (int, *optional*, defaults to 2):
+                 The downsample ratio of the block used to reduce the inner dim of the attention.
+             skip_first_layer_pe (bool, *optional*, defaults to `False`):
+                 Whether or not to skip the addition of the query_point_embedding on the first layer.
+         """
+         super().__init__()
+
+         self.hidden_size = config.hidden_size
+         self.layer_norm_eps = config.layer_norm_eps
+
+         self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
+         self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+         self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
+             config, downsample_rate=attention_downsample_rate
+         )
+         self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+         self.mlp = SamMLPBlock(config)
+         self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+         self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+         self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
+             config, downsample_rate=attention_downsample_rate
+         )
+         self.skip_first_layer_pe = skip_first_layer_pe
+
+     def forward(
+         self,
+         queries: Tensor,
+         keys: Tensor,
+         query_point_embedding: Tensor,
+         key_point_embedding: Tensor,
+         attention_similarity: Tensor,
+         output_attentions: bool = False,
+     ):
+         # Self attention block
+         if self.skip_first_layer_pe:
+             queries = self.self_attn(query=queries, key=queries, value=queries)
+         else:
+             query = queries + query_point_embedding
+             attn_out = self.self_attn(query=query, key=query, value=queries)
+             queries = queries + attn_out
+         queries = self.layer_norm1(queries)
+
+         # Cross attention block, tokens attending to image embedding
+         query = queries + query_point_embedding
+         key = keys + key_point_embedding
+
+         attn_out = self.cross_attn_token_to_image(
+             query=query, key=key, value=keys, attention_similarity=attention_similarity
+         )
+         queries = queries + attn_out
+
+         queries = self.layer_norm2(queries)
+
+         # MLP block
+         mlp_out = self.mlp(queries)
+         queries = queries + mlp_out
+         queries = self.layer_norm3(queries)
+
+         # Cross attention block, image embedding attending to tokens
+         query = queries + query_point_embedding
+         key = keys + key_point_embedding
+
+         attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+         keys = keys + attn_out
+
+         keys = self.layer_norm4(keys)
+
+         outputs = (queries, keys)
+
+         if output_attentions:
+             outputs = outputs + (attn_out,)
+         else:
+             outputs = outputs + (None,)
+
+         return outputs
+
+
+ class SamTwoWayTransformer(nn.Module):
+     def __init__(self, config: SamMaskDecoderConfig):
+         super().__init__()
+         self.config = config
+
+         self.num_hidden_layers = config.num_hidden_layers
+         self.layers = nn.ModuleList()
+
+         for i in range(self.num_hidden_layers):
+             self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
+
+         self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
+         self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
+
+     def forward(
+         self,
+         point_embeddings: Tensor,
+         image_embeddings: Tensor,
+         image_positional_embeddings: Tensor,
+         attention_similarity: Tensor,
+         target_embedding=None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+
+         all_attentions = ()
+
+         if image_embeddings is None:
+             raise ValueError("You have to specify an image_embedding")
+
+         image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+         image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+         # Prepare queries
+         queries = point_embeddings
+         keys = image_embeddings
+
+         # Apply transformer blocks and final layernorm
+         for layer in self.layers:
+             if target_embedding is not None:
+                 queries += target_embedding
+
+             queries, keys, attention_outputs = layer(
+                 queries=queries,
+                 keys=keys,
+                 query_point_embedding=point_embeddings,
+                 key_point_embedding=image_positional_embeddings,
+                 attention_similarity=attention_similarity,
+                 output_attentions=output_attentions,
+             )
+
+             if output_attentions:
+                 all_attentions = all_attentions + (attention_outputs,)
+
+         # Apply the final attention layer from the points to the image
+         query = queries + point_embeddings
+         key = keys + image_positional_embeddings
+
+         attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+         queries = queries + attn_out
+         queries = self.layer_norm_final_attn(queries)
+         return queries, keys, all_attentions
+
+
+ class SamFeedForward(nn.Module):
+     def __init__(
+         self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
+     ):
+         super().__init__()
+         self.num_layers = num_layers
+         self.activation = nn.ReLU()
+         self.proj_in = nn.Linear(input_dim, hidden_dim)
+         self.proj_out = nn.Linear(hidden_dim, output_dim)
+         self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+         self.sigmoid_output = sigmoid_output
+
+     def forward(self, hidden_states):
+         hidden_states = self.proj_in(hidden_states)
+         hidden_states = self.activation(hidden_states)
+         for layer in self.layers:
+             hidden_states = self.activation(layer(hidden_states))
+
+         hidden_states = self.proj_out(hidden_states)
+         if self.sigmoid_output:
+             hidden_states = F.sigmoid(hidden_states)
+         return hidden_states
+
+
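A quick sanity sketch (dimensions assumed): `num_layers` counts `proj_in` and `proj_out` too, so `num_layers - 2` hidden linears sit between the two projections and `num_layers=3` leaves exactly one.

```python
import torch.nn as nn

num_layers, hidden_dim = 3, 256
hidden = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
print(len(hidden))  # 1 -> proj_in + 1 hidden + proj_out = 3 Linear layers in total
```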
484
+ class SamMaskDecoder(nn.Module):
485
+ def __init__(self, config: SamMaskDecoderConfig):
486
+ super().__init__()
487
+ self.config = config
488
+ self.hidden_size = config.hidden_size
489
+
490
+ self.num_multimask_outputs = config.num_multimask_outputs
491
+ self.num_mask_tokens = config.num_multimask_outputs + 1
492
+
493
+ self.iou_token = nn.Embedding(1, self.hidden_size)
494
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
495
+
496
+ self.transformer = SamTwoWayTransformer(config)
497
+
498
+ # should we create a new class for this?
499
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
500
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
501
+ self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
502
+ self.activation = nn.GELU()
503
+
504
+ mlps_list = []
505
+ for _ in range(self.num_mask_tokens):
506
+ mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
507
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
508
+
509
+ self.iou_prediction_head = SamFeedForward(
510
+ self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
511
+ )
512
+
513
+ def forward(
514
+ self,
515
+ image_embeddings: torch.Tensor,
516
+ image_positional_embeddings: torch.Tensor,
517
+ sparse_prompt_embeddings: torch.Tensor,
518
+ dense_prompt_embeddings: torch.Tensor,
519
+ multimask_output: bool,
520
+ output_attentions: Optional[bool] = None,
521
+ attention_similarity: Optional[torch.Tensor] = None,
522
+ target_embedding: Optional[torch.Tensor] = None,
523
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
524
+ """
525
+ Predict masks given image and prompt embeddings.
526
+
527
+ Args:
528
+ image_embeddings (`torch.Tensor`):
529
+ the embeddings from the image encoder
530
+ image_positional_embedding (`torch.Tensor`):
531
+ positional encoding with the shape of image_embeddings
532
+ sparse_prompt_embeddings (`torch.Tensor`):
533
+ The embeddings of the points and boxes
534
+ dense_prompt_embeddings (`torch.Tensor`):
535
+ the embeddings of the mask inputs
536
+ multimask_output (bool):
537
+ Whether to return multiple masks or a single mask.
538
+ output_attentions (bool, *optional*):
539
+ Whether or not to return the attentions tensors of all attention layers.
540
+ """
541
+ batch_size, num_channels, height, width = image_embeddings.shape
542
+ point_batch_size = sparse_prompt_embeddings.shape[1]
543
+ # Concatenate output tokens
544
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
545
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
546
+
547
+ if sparse_prompt_embeddings.sum().item() != 0:
548
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
549
+ else:
550
+ tokens = output_tokens
551
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
552
+
553
+ # Expand per-image data in batch direction to be per-point
554
+ image_embeddings = image_embeddings + dense_prompt_embeddings
555
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
556
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
557
+
558
+ # Run the transformer, image_positional_embedding are consumed
559
+ point_embedding, image_embeddings, attentions = self.transformer(
560
+ point_embeddings=point_embeddings,
561
+ image_embeddings=image_embeddings,
562
+ image_positional_embeddings=image_positional_embeddings,
563
+ attention_similarity=attention_similarity,
564
+ target_embedding=target_embedding,
565
+ output_attentions=output_attentions,
566
+ )
567
+ iou_token_out = point_embedding[:, :, 0, :]
568
+ mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
569
+
570
+ # Upscale mask embeddings and predict masks using the mask tokens
571
+ image_embeddings = image_embeddings.transpose(2, 3).reshape(
572
+ batch_size * point_batch_size, num_channels, height, width
573
+ )
574
+
575
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
576
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
577
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
578
+
579
+ hyper_in_list = []
580
+ for i in range(self.num_mask_tokens):
581
+ current_mlp = self.output_hypernetworks_mlps[i]
582
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
583
+ hyper_in = torch.stack(hyper_in_list, dim=2)
584
+
585
+ _, num_channels, height, width = upscaled_embedding.shape
586
+ upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
587
+ masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
588
+
589
+ # Generate mask quality predictions
590
+ iou_pred = self.iou_prediction_head(iou_token_out)
591
+
592
+ # Select the correct mask or masks for output
593
+ if multimask_output:
594
+ mask_slice = slice(1, None)
595
+ else:
596
+ mask_slice = slice(0, 1)
597
+ masks = masks[:, :, mask_slice, :, :]
598
+ iou_pred = iou_pred[:, :, mask_slice]
599
+
600
+ outputs = (masks, iou_pred)
601
+
602
+ if output_attentions:
603
+ outputs = outputs + (attentions,)
604
+ else:
605
+ outputs = outputs + (None,)
606
+
607
+ return outputs
608
+
609
+
610
+ class SamPositionalEmbedding(nn.Module):
611
+ def __init__(self, config):
612
+ super().__init__()
613
+ self.scale = config.hidden_size // 2
614
+ self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
615
+
616
+ def forward(self, input_coords, input_shape=None):
617
+ """Positionally encode points that are normalized to [0,1]."""
618
+ coordinates = input_coords.clone()
619
+
620
+ if input_shape is not None:
621
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
622
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
623
+
624
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
625
+ coordinates = 2 * coordinates - 1
626
+ coordinates = coordinates.to(self.positional_embedding.dtype)
627
+ coordinates = coordinates @ self.positional_embedding
628
+ coordinates = 2 * np.pi * coordinates
629
+ # outputs d_1 x ... x d_n x channel shape
630
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
631
+
632
+
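A minimal sketch of the random Fourier-feature encoding above (dimensions are assumed for illustration): normalized `(x, y)` coordinates are projected by a fixed random `2 x num_pos_feats` matrix, then expanded into sin/cos pairs, doubling the channel count.

```python
import numpy as np
import torch

num_pos_feats = 64
positional_matrix = torch.randn(2, num_pos_feats)  # plays the role of the frozen buffer

coords = torch.rand(1, 1, 3, 2)              # three points in [0, 1]^2
coords = 2 * coords - 1                      # map to [-1, 1]
proj = 2 * np.pi * (coords @ positional_matrix)
encoding = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
print(encoding.shape)  # torch.Size([1, 1, 3, 128]) -> 2 * num_pos_feats channels
```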
633
+ class SamMaskEmbedding(nn.Module):
634
+ def __init__(self, config: SamPromptEncoderConfig):
635
+ super().__init__()
636
+ self.mask_input_channels = config.mask_input_channels // 4
637
+ self.activation = ACT2FN[config.hidden_act]
638
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
639
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
640
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
641
+ self.layer_norm1 = SamLayerNorm(
642
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
643
+ )
644
+ self.layer_norm2 = SamLayerNorm(
645
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
646
+ )
647
+
648
+ def forward(self, masks):
649
+ hidden_states = self.conv1(masks)
650
+ hidden_states = self.layer_norm1(hidden_states)
651
+ hidden_states = self.activation(hidden_states)
652
+
653
+ hidden_states = self.conv2(hidden_states)
654
+ hidden_states = self.layer_norm2(hidden_states)
655
+ hidden_states = self.activation(hidden_states)
656
+ dense_embeddings = self.conv3(hidden_states)
657
+ return dense_embeddings
658
+
659
+
660
+ class SamPromptEncoder(nn.Module):
661
+ def __init__(self, config: SamPromptEncoderConfig):
662
+ super().__init__()
663
+ self.shared_embedding = SamPositionalEmbedding(config.vision_config)
664
+ config = config.prompt_encoder_config
665
+ self.mask_embed = SamMaskEmbedding(config)
666
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
667
+
668
+ self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
669
+ self.input_image_size = config.image_size
670
+
671
+ self.point_embed = nn.ModuleList(
672
+ [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
673
+ )
674
+ self.hidden_size = config.hidden_size
675
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
676
+
677
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
678
+ """Embeds point prompts."""
679
+ points = points + 0.5 # Shift to center of pixel
680
+ if pad:
681
+ target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
682
+ target_labels_shape = (points.shape[0], points.shape[1], 1)
683
+ padding_point = torch.zeros(target_point_shape, device=points.device)
684
+ padding_label = -torch.ones(target_labels_shape, device=labels.device)
685
+ points = torch.cat([points, padding_point], dim=2)
686
+ labels = torch.cat([labels, padding_label], dim=2)
687
+ input_shape = (self.input_image_size, self.input_image_size)
688
+ point_embedding = self.shared_embedding(points, input_shape)
689
+
690
+ # torch.where and expanding the labels tensor is required by the ONNX export
691
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
692
+
693
+ # This is required for the ONNX export. The dtype, device need to be explicitely
694
+ # specificed as otherwise torch.onnx.export interprets as double
695
+ point_embedding = torch.where(
696
+ labels[..., None] != -10,
697
+ point_embedding,
698
+ torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
699
+ )
700
+
701
+ point_embedding = torch.where(
702
+ (labels == 0)[:, :, :, None],
703
+ point_embedding + self.point_embed[0].weight[None, None, :, :],
704
+ point_embedding,
705
+ )
706
+
707
+ point_embedding = torch.where(
708
+ (labels == 1)[:, :, :, None],
709
+ point_embedding + self.point_embed[1].weight[None, None, :, :],
710
+ point_embedding,
711
+ )
712
+
713
+ return point_embedding
714
+
715
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
716
+ """Embeds box prompts."""
717
+ boxes = boxes + 0.5 # Shift to center of pixel
718
+ batch_size, nb_boxes = boxes.shape[:2]
719
+ coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
720
+ input_shape = (self.input_image_size, self.input_image_size)
721
+ corner_embedding = self.shared_embedding(coords, input_shape)
722
+ corner_embedding[:, :, 0, :] += self.point_embed[2].weight
723
+ corner_embedding[:, :, 1, :] += self.point_embed[3].weight
724
+ return corner_embedding
725
+
726
+ def forward(
727
+ self,
728
+ input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
729
+ input_labels: Optional[torch.Tensor],
730
+ input_boxes: Optional[torch.Tensor],
731
+ input_masks: Optional[torch.Tensor],
732
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
733
+ """
734
+ Embeds different types of prompts, returning both sparse and dense embeddings.
735
+
736
+ Args:
737
+ points (`torch.Tensor`, *optional*):
738
+ point coordinates and labels to embed.
739
+ boxes (`torch.Tensor`, *optional*):
740
+ boxes to embed
741
+ masks (`torch.Tensor`, *optional*):
742
+ masks to embed
743
+ """
744
+ sparse_embeddings = None
745
+ batch_size = 1
746
+ target_device = self.shared_embedding.positional_embedding.device
747
+ if input_points is not None:
748
+ batch_size, point_batch_size = input_points.shape[:2]
749
+ if input_labels is None:
750
+ raise ValueError("If points are provided, labels must also be provided.")
751
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
752
+ sparse_embeddings = point_embeddings
753
+ if input_boxes is not None:
754
+ batch_size = input_boxes.shape[0]
755
+ box_embeddings = self._embed_boxes(input_boxes)
756
+ if sparse_embeddings is None:
757
+ sparse_embeddings = box_embeddings
758
+ else:
759
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
760
+ if input_masks is not None:
761
+ dense_embeddings = self.mask_embed(input_masks)
762
+ else:
763
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
764
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
765
+ )
766
+
767
+ if sparse_embeddings is None:
768
+ sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
769
+
770
+ return sparse_embeddings, dense_embeddings
771
+
772
+
+ class SamVisionAttention(nn.Module):
+     """Multi-head Attention block with relative position embeddings."""
+
+     def __init__(self, config, window_size):
+         super().__init__()
+         input_size = (
+             (config.image_size // config.patch_size, config.image_size // config.patch_size)
+             if window_size == 0
+             else (window_size, window_size)
+         )
+
+         self.num_attention_heads = config.num_attention_heads
+         head_dim = config.hidden_size // config.num_attention_heads
+         self.scale = head_dim**-0.5
+         self.dropout = config.attention_dropout
+
+         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
+         self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+         self.use_rel_pos = config.use_rel_pos
+         if self.use_rel_pos:
+             if input_size is None:
+                 raise ValueError("Input size must be provided if using relative positional encoding.")
+
+             # initialize relative positional embeddings
+             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+     def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+         """
+         Get relative positional embeddings according to the relative positions of
+         query and key sizes.
+
+         Args:
+             q_size (int):
+                 size of the query.
+             k_size (int):
+                 size of key k.
+             rel_pos (`torch.Tensor`):
+                 relative position embeddings (L, channel).
+
+         Returns:
+             Extracted positional embeddings according to relative positions.
+         """
+         max_rel_dist = int(2 * max(q_size, k_size) - 1)
+         # Interpolate rel pos.
+         rel_pos_resized = F.interpolate(
+             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+             size=max_rel_dist,
+             mode="linear",
+         )
+         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+
+         # Scale the coords with short length if shapes for q and k are different.
+         q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+         k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+         relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+         return rel_pos_resized[relative_coords.long()]
+
+     def get_decomposed_rel_pos(
+         self,
+         query: torch.Tensor,
+         rel_pos_h: torch.Tensor,
+         rel_pos_w: torch.Tensor,
+         q_size: Tuple[int, int],
+         k_size: Tuple[int, int],
+     ) -> torch.Tensor:
+         """
+         Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+         Args:
+             query (`torch.Tensor`):
+                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
+             rel_pos_h (`torch.Tensor`):
+                 relative position embeddings (Lh, channel) for height axis.
+             rel_pos_w (`torch.Tensor`):
+                 relative position embeddings (Lw, channel) for width axis.
+             q_size (tuple):
+                 spatial sequence size of query q with (query_height, query_width).
+             k_size (tuple):
+                 spatial sequence size of key k with (key_height, key_width).
+
+         Returns:
+             decomposed_rel_pos (`torch.Tensor`):
+                 decomposed relative position embeddings.
+         """
+         query_height, query_width = q_size
+         key_height, key_width = k_size
+         relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
+         relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
+
+         batch_size, _, dim = query.shape
+         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
+         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+         rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+
+         decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+         return decomposed_rel_pos
+
+     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+         batch_size, height, width, _ = hidden_states.shape
+         # qkv with shape (3, batch_size, nHead, height * width, channel)
+         qkv = (
+             self.qkv(hidden_states)
+             .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+             .permute(2, 0, 3, 1, 4)
+         )
+         # q, k, v with shape (batch_size * nHead, height * width, channel)
+         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
+
+         if self.use_rel_pos:
+             decomposed_rel_pos = self.get_decomposed_rel_pos(
+                 query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+             )
+             decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+             attn_weights = attn_weights + decomposed_rel_pos
+
+         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
+
+         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+         attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
+         attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
+
+         attn_output = self.proj(attn_output)
+
+         if output_attentions:
+             outputs = (attn_output, attn_weights)
+         else:
+             outputs = (attn_output, None)
+
+         return outputs
+
+
+ class SamVisionSdpaAttention(SamVisionAttention):
+     """
+     Multi-head Attention block with relative position embeddings.
+     Using SDPA instead of the default attention.
+     """
+
+     def __init__(self, config, window_size):
+         super().__init__(config, window_size)
+
+     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+         if output_attentions:
+             logger.warning_once(
+                 "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                 "`output_attentions=True`. Falling back to the manual attention implementation, but "
+                 "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+             )
+             return super().forward(
+                 hidden_states=hidden_states,
+                 output_attentions=output_attentions,
+             )
+
+         batch_size, height, width, _ = hidden_states.shape
+         # qkv with shape (3, B, nHead, H * W, C)
+         qkv = (
+             self.qkv(hidden_states)
+             .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+             .permute(2, 0, 3, 1, 4)
+         )
+         # q, k, v with shape (B * nHead, H * W, C)
+         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+         attn_bias = None
+         if self.use_rel_pos:
+             decomposed_rel_pos = self.get_decomposed_rel_pos(
+                 query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+             )
+             decomposed_rel_pos = decomposed_rel_pos.reshape(
+                 batch_size, self.num_attention_heads, height * width, height * width
+             )
+             attn_bias = decomposed_rel_pos
+
+         query = query.view(batch_size, self.num_attention_heads, height * width, -1)
+         key = key.view(batch_size, self.num_attention_heads, height * width, -1)
+         value = value.view(batch_size, self.num_attention_heads, height * width, -1)
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
+
+         attn_output = (
+             attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
+             .permute(0, 2, 3, 1, 4)
+             .reshape(batch_size, height, width, -1)
+         )
+
+         attn_output = self.proj(attn_output)
+
+         return attn_output, None
+
+
+ SAM_VISION_ATTENTION_CLASSES = {
+     "eager": SamVisionAttention,
+     "sdpa": SamVisionSdpaAttention,
+ }
+
+
+ class SamVisionLayer(nn.Module):
+     def __init__(self, config, window_size):
+         super().__init__()
+         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
+         self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.mlp = SamMLPBlock(config)
+         self.window_size = window_size
+
+     def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+         """
+         Partition into non-overlapping windows with padding if needed.
+
+         Args:
+             hidden_states (tensor):
+                 input tokens with [batch_size, height, width, channel].
+             window_size (int):
+                 window size.
+
+         Returns:
+             windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
+             (pad_height, pad_width): padded height and width before partition
+         """
+         batch_size, height, width, channel = hidden_states.shape
+
+         pad_h = (window_size - height % window_size) % window_size
+         pad_w = (window_size - width % window_size) % window_size
+         hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
+         pad_height, pad_width = height + pad_h, width + pad_w
+
+         hidden_states = hidden_states.reshape(
+             batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
+         )
+         windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
+         return windows, (pad_height, pad_width)
+
+     def window_unpartition(
+         self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
+     ) -> torch.Tensor:
+         """
+         Window unpartition into original sequences and removing padding.
+
+         Args:
+             windows (tensor):
+                 input tokens with [batch_size * num_windows, window_size, window_size, channel].
+             window_size (int):
+                 window size.
+             padding_shape (Tuple):
+                 padded height and width (pad_height, pad_width).
+             original_shape (Tuple):
+                 original height and width (height, width) before padding.
+
+         Returns:
+             hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
+         """
+         pad_height, pad_width = padding_shape
+         height, width = original_shape
+         batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
+         hidden_states = windows.reshape(
+             batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
+         )
+         hidden_states = (
+             hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
+         )
+
+         hidden_states = hidden_states[:, :height, :width, :].contiguous()
+         return hidden_states
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor]:
+         residual = hidden_states
+
+         hidden_states = self.layer_norm1(hidden_states)
+         # Window partition
+         if self.window_size > 0:
+             height, width = hidden_states.shape[1], hidden_states.shape[2]
+             hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
+
+         hidden_states, attn_weights = self.attn(
+             hidden_states=hidden_states,
+             output_attentions=output_attentions,
+         )
+         # Reverse window partition
+         if self.window_size > 0:
+             hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
+
+         hidden_states = residual + hidden_states
+         layernorm_output = self.layer_norm2(hidden_states)
+         hidden_states = hidden_states + self.mlp(layernorm_output)
+
+         outputs = (hidden_states,)
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+
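A standalone round-trip sketch of the windowing above (sizes are assumed): `window_partition` pads height and width up to a multiple of `window_size`, and `window_unpartition` reassembles and crops the padding back off, so the pair is lossless.

```python
import torch
import torch.nn.functional as F

batch, height, width, channel, window = 1, 10, 10, 4, 7
x = torch.randn(batch, height, width, channel)

pad_h = (window - height % window) % window  # 4
pad_w = (window - width % window) % window   # 4
padded = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
ph, pw = height + pad_h, width + pad_w

windows = (
    padded.reshape(batch, ph // window, window, pw // window, window, channel)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window, window, channel)
)
print(windows.shape)  # torch.Size([4, 7, 7, 4]) -> a 2x2 grid of 7x7 windows

# inverse: reassemble the grid and crop the padding away
restored = (
    windows.reshape(batch, ph // window, pw // window, window, window, channel)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(batch, ph, pw, channel)[:, :height, :width, :]
)
print(torch.equal(x, restored))  # True
```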
1072
+ class SamVisionNeck(nn.Module):
1073
+ def __init__(self, config: SamVisionConfig):
1074
+ super().__init__()
1075
+ self.config = config
1076
+
1077
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
1078
+ self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
1079
+ self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
1080
+ self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
1081
+
1082
+ def forward(self, hidden_states):
1083
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
1084
+ hidden_states = self.conv1(hidden_states)
1085
+ hidden_states = self.layer_norm1(hidden_states)
1086
+
1087
+ hidden_states = self.conv2(hidden_states)
1088
+ hidden_states = self.layer_norm2(hidden_states)
1089
+ return hidden_states
1090
+
1091
+
1092
+ class SamVisionEncoder(nn.Module):
1093
+ def __init__(self, config: SamVisionConfig):
1094
+ super().__init__()
1095
+ self.config = config
1096
+ self.image_size = config.image_size
1097
+
1098
+ self.patch_embed = SamPatchEmbeddings(config)
1099
+
1100
+ self.pos_embed = None
1101
+ if config.use_abs_pos:
1102
+ # Initialize absolute positional embedding with pretrain image size.
1103
+ self.pos_embed = nn.Parameter(
1104
+ torch.zeros(
1105
+ 1,
1106
+ config.image_size // config.patch_size,
1107
+ config.image_size // config.patch_size,
1108
+ config.hidden_size,
1109
+ )
1110
+ )
1111
+
1112
+ self.layers = nn.ModuleList()
1113
+ for i in range(config.num_hidden_layers):
1114
+ layer = SamVisionLayer(
1115
+ config,
1116
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
1117
+ )
1118
+ self.layers.append(layer)
1119
+
1120
+ self.neck = SamVisionNeck(config)
1121
+
1122
+ self.gradient_checkpointing = False
1123
+
1124
+ def get_input_embeddings(self):
1125
+ return self.patch_embed
1126
+
1127
+ @can_return_tuple
1128
+ def forward(
1129
+ self,
1130
+ pixel_values: Optional[torch.FloatTensor] = None,
1131
+ output_attentions: Optional[bool] = None,
1132
+ output_hidden_states: Optional[bool] = None,
1133
+ ) -> SamVisionEncoderOutput:
1134
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1135
+ output_hidden_states = (
1136
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1137
+ )
1138
+
1139
+ if pixel_values is None:
1140
+ raise ValueError("You have to specify pixel_values")
1141
+
1142
+ hidden_states = self.patch_embed(pixel_values)
1143
+ if self.pos_embed is not None:
1144
+ hidden_states = hidden_states + self.pos_embed
1145
+
1146
+ all_hidden_states = () if output_hidden_states else None
1147
+ all_self_attentions = () if output_attentions else None
1148
+
1149
+ for i, layer_module in enumerate(self.layers):
1150
+ if output_hidden_states:
1151
+ all_hidden_states = all_hidden_states + (hidden_states,)
1152
+
1153
+ if self.gradient_checkpointing and self.training:
1154
+ layer_outputs = self._gradient_checkpointing_func(
1155
+ layer_module.__call__,
1156
+ hidden_states,
1157
+ )
1158
+ else:
1159
+ layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
1160
+
1161
+ hidden_states = layer_outputs[0]
1162
+
1163
+ if output_attentions:
1164
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1165
+
1166
+ if output_hidden_states:
1167
+ all_hidden_states = all_hidden_states + (hidden_states,)
1168
+
1169
+ hidden_states = self.neck(hidden_states)
1170
+
1171
+ return SamVisionEncoderOutput(
1172
+ last_hidden_state=hidden_states,
1173
+ hidden_states=all_hidden_states,
1174
+ attentions=all_self_attentions,
1175
+ )
1176
+
1177
+
+ class SamPreTrainedModel(PreTrainedModel):
+     config_class = SamConfig
+     base_model_prefix = "sam"
+     main_input_name = "pixel_values"
+     _no_split_modules = ["SamVisionAttention"]
+     supports_gradient_checkpointing = True
+     _supports_sdpa = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, (SamLayerNorm, nn.LayerNorm)):
+             module.weight.data.fill_(1.0)
+             module.bias.data.zero_()
+         elif isinstance(module, SamVisionAttention):
+             if module.use_rel_pos:
+                 module.rel_pos_h.data.zero_()
+                 module.rel_pos_w.data.zero_()
+
+
+ SAM_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`SamConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ SAM_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+             details.
+         input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+             Input 2D spatial points, used by the prompt encoder to encode the prompt. Generally yields much
+             better results. The points can be obtained by passing a list of list of list to the processor that will
+             create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+             second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+             per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+             multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+             coordinates of the point. If a different number of points is passed either for each image, or for each
+             mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+             computation of the embedding will be skipped for these points using the labels.
+         input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+             Input labels for the points, used by the prompt encoder to encode the prompt. According to the
+             official implementation, there are 3 types of labels
+
+             - `1`: the point is a point that contains the object of interest
+             - `0`: the point is a point that does not contain the object of interest
+             - `-1`: the point corresponds to the background
+
+             We added the label:
+
+             - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+             The padding labels should be automatically done by the processor.
+         input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+             Input boxes for the points, used by the prompt encoder to encode the prompt. Generally yields
+             much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+             that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+             size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+             In the order (`x1`, `y1`, `x2`, `y2`):
+
+             - `x1`: the x coordinate of the top left point of the input box
+             - `y1`: the y coordinate of the top left point of the input box
+             - `x2`: the x coordinate of the bottom right point of the input box
+             - `y2`: the y coordinate of the bottom right point of the input box
+
+         input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+             SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+             generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
+             manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+
+         image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+             Image embeddings, used by the mask decoder to generate masks and iou scores. For more memory
+             efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+             method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+         multimask_output (`bool`, *optional*):
+             In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+             bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+             "best" mask, by specifying `multimask_output=False`.
+         attention_similarity (`torch.FloatTensor`, *optional*):
+             Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+             model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+         target_embedding (`torch.FloatTensor`, *optional*):
+             Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+             the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ SAM_VISION_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+             details.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     """The vision model from Sam without any head or projection on top.""",
+     SAM_START_DOCSTRING,
+ )
+ class SamVisionModel(SamPreTrainedModel):
+     config_class = SamVisionConfig
+     main_input_name = "pixel_values"
+
+     def __init__(self, config: SamVisionConfig):
+         super().__init__(config)
+         self.vision_encoder = SamVisionEncoder(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.vision_encoder.patch_embed
+
+     @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SamVisionEncoderOutput]:
+         r"""
+         Returns:
+
+         """
+         return self.vision_encoder(
+             pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+
+ @add_start_docstrings(
1347
+ "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
1348
+ " optional 2D location and bounding boxes.",
1349
+ SAM_START_DOCSTRING,
1350
+ )
1351
+ class SamModel(SamPreTrainedModel):
1352
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
1353
+ # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
1354
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
1355
+
1356
+ def __init__(self, config):
1357
+ super().__init__(config)
1358
+ self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
1359
+
1360
+ self.vision_encoder = SamVisionEncoder(config.vision_config)
1361
+ self.prompt_encoder = SamPromptEncoder(config)
1362
+ self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
1363
+
1364
+ self.post_init()
1365
+
1366
+ def _tie_weights(self):
1367
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
1368
+ self.shared_image_embedding.positional_embedding.data
1369
+ )
1370
+
1371
+ def get_input_embeddings(self):
1372
+ return self.vision_encoder.get_input_embeddings()
1373
+
1374
+ def get_image_wide_positional_embeddings(self):
1375
+ size = self.config.prompt_encoder_config.image_embedding_size
1376
+ target_device = self.shared_image_embedding.positional_embedding.device
1377
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
1378
+ grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
1379
+ y_embed = grid.cumsum(dim=0) - 0.5
1380
+ x_embed = grid.cumsum(dim=1) - 0.5
1381
+ y_embed = y_embed / size
1382
+ x_embed = x_embed / size
1383
+
1384
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
1385
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
1386
+
1387
+ @torch.no_grad()
1388
+ def get_image_embeddings(
1389
+ self,
1390
+ pixel_values,
1391
+ output_attentions: Optional[bool] = None,
1392
+ output_hidden_states: Optional[bool] = None,
1393
+ ):
1394
+ r"""
1395
+ Returns the image embeddings by passing the pixel values through the vision encoder.
1396
+
1397
+ Args:
1398
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1399
+ Input pixel values
1400
+ output_attentions (`bool`, *optional*):
1401
+ Whether or not to return the attentions tensors of all attention layers.
1402
+ output_hidden_states (`bool`, *optional*):
1403
+ Whether or not to return the hidden states of all layers.
1404
+ """
1405
+ vision_output = self.vision_encoder(
1406
+ pixel_values,
1407
+ output_attentions=output_attentions,
1408
+ output_hidden_states=output_hidden_states,
1409
+ )
1410
+ image_embeddings = vision_output[0]
1411
+ return image_embeddings
1412
+
1413
+ @torch.no_grad()
1414
+ def get_prompt_embeddings(
1415
+ self,
1416
+ input_points: Optional[torch.FloatTensor] = None,
1417
+ input_labels: Optional[torch.LongTensor] = None,
1418
+ input_boxes: Optional[torch.FloatTensor] = None,
1419
+ input_masks: Optional[torch.LongTensor] = None,
1420
+ ):
1421
+ r"""
1422
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
1423
+
1424
+ Args:
1425
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
1426
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
1427
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
1428
+ point. The model will output `point_batch_size` times 3 masks in total.
1429
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
1430
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
1431
+ processor, or can be fed by the user.
1432
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
1433
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
1434
+ processor. users can also pass manually the input boxes.
1435
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
1436
+ Optional input masks for the prompt encoder.
1437
+ """
1438
+ prompt_output = self.prompt_encoder(
1439
+ input_points=input_points,
1440
+ input_labels=input_labels,
1441
+ input_boxes=input_boxes,
1442
+ input_masks=input_masks,
1443
+ )
1444
+ return prompt_output
1445
+
1446
+ @can_return_tuple
1447
+ @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
1448
+ def forward(
1449
+ self,
1450
+ pixel_values: Optional[torch.FloatTensor] = None,
1451
+ input_points: Optional[torch.FloatTensor] = None,
1452
+ input_labels: Optional[torch.LongTensor] = None,
1453
+ input_boxes: Optional[torch.FloatTensor] = None,
1454
+ input_masks: Optional[torch.LongTensor] = None,
1455
+ image_embeddings: Optional[torch.FloatTensor] = None,
1456
+ multimask_output: bool = True,
1457
+ attention_similarity: Optional[torch.FloatTensor] = None,
1458
+ target_embedding: Optional[torch.FloatTensor] = None,
1459
+ output_attentions: Optional[bool] = None,
1460
+ output_hidden_states: Optional[bool] = None,
1461
+ **kwargs,
1462
+ ) -> SamImageSegmentationOutput:
1463
+ r"""
1464
+ Example:
1465
+
1466
+ ```python
1467
+ >>> from PIL import Image
1468
+ >>> import requests
1469
+ >>> from transformers import AutoModel, AutoProcessor
1470
+
1471
+ >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
1472
+ >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
1473
+
1474
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
1475
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
1476
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
1477
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
1478
+
1479
+ >>> # Get segmentation mask
1480
+ >>> outputs = model(**inputs)
1481
+
1482
+ >>> # Postprocess masks
1483
+ >>> masks = processor.post_process_masks(
1484
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
1485
+ ... )
1486
+ ```
1487
+ """
1488
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1489
+ output_hidden_states = (
1490
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1491
+ )
1492
+
1493
+ if pixel_values is None and image_embeddings is None:
1494
+ raise ValueError("Either pixel_values or image_embeddings must be provided.")
1495
+
1496
+ if pixel_values is not None and image_embeddings is not None:
1497
+ raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
1498
+
1499
+ if input_points is not None and len(input_points.shape) != 4:
1500
+ raise ValueError(
1501
+ "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
1502
+ " got {}.".format(input_points.shape),
1503
+ )
1504
+ if input_boxes is not None and len(input_boxes.shape) != 3:
1505
+ raise ValueError(
1506
+ "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
1507
+ " got {}.".format(input_boxes.shape),
1508
+ )
1509
+ if input_points is not None and input_boxes is not None:
1510
+ point_batch_size = input_points.shape[1]
1511
+ box_batch_size = input_boxes.shape[1]
1512
+ if point_batch_size != box_batch_size:
1513
+ raise ValueError(
1514
+ "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
1515
+ point_batch_size, box_batch_size
1516
+ )
1517
+ )
1518
+
1519
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
1520
+ # repeat with batch size
1521
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
1522
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
1523
+
1524
+ vision_attentions = None
1525
+ vision_hidden_states = None
1526
+
1527
+ if pixel_values is not None:
1528
+ vision_outputs: SamVisionEncoderOutput = self.vision_encoder(
1529
+ pixel_values,
1530
+ output_attentions=output_attentions,
1531
+ output_hidden_states=output_hidden_states,
1532
+ )
1533
+ image_embeddings = vision_outputs.last_hidden_state
1534
+
1535
+ if output_hidden_states:
1536
+ vision_hidden_states = vision_outputs.hidden_states
1537
+ if output_attentions:
1538
+ vision_attentions = vision_outputs.attentions
1539
+
1540
+ if input_points is not None and input_labels is None:
1541
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
1542
+
1543
+ if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
1544
+ raise ValueError(
1545
+ "The batch size of the image embeddings and the input points must be the same. "
1546
+ "Got {} and {} respectively. If you want to pass multiple points for the same image, "
1547
+ "make sure that you passed input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) "
1548
+ "and input_labels of shape (batch_size, point_batch_size, num_points_per_image)"
1549
+ .format(image_embeddings.shape[0], input_points.shape[0])
1550
+ )
1551
+
1552
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1553
+ input_points=input_points,
1554
+ input_labels=input_labels,
1555
+ input_boxes=input_boxes,
1556
+ input_masks=input_masks,
1557
+ )
1558
+
1559
+ low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
1560
+ image_embeddings=image_embeddings,
1561
+ image_positional_embeddings=image_positional_embeddings,
1562
+ sparse_prompt_embeddings=sparse_embeddings,
1563
+ dense_prompt_embeddings=dense_embeddings,
1564
+ multimask_output=multimask_output,
1565
+ attention_similarity=attention_similarity,
1566
+ target_embedding=target_embedding,
1567
+ output_attentions=output_attentions,
1568
+ )
1569
+
1570
+ return SamImageSegmentationOutput(
1571
+ iou_scores=iou_predictions,
1572
+ pred_masks=low_res_masks,
1573
+ vision_hidden_states=vision_hidden_states,
1574
+ vision_attentions=vision_attentions,
1575
+ mask_decoder_attentions=mask_decoder_attentions,
1576
+ )
1577
+
1578
+
1579
+ __all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
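+
+ # A minimal reuse sketch (a sketch, not part of the upstream file): the vision encoder
+ # dominates runtime, so image embeddings can be computed once with
+ # `SamModel.get_image_embeddings` and passed back through the `image_embeddings` argument
+ # when prompting the same image repeatedly (only one of pixel_values / image_embeddings
+ # may be given, per the check above).
+ #
+ # >>> image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+ # >>> outputs = model(input_points=inputs["input_points"], image_embeddings=image_embeddings)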
docs/transformers/build/lib/transformers/models/sam/modeling_tf_sam.py ADDED
@@ -0,0 +1,1726 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
17
+ discrepancy, the original file should be regarded as the 'reference' version.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import collections
23
+ from dataclasses import dataclass
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+ import tensorflow as tf
28
+
29
+ from ...activations_tf import ACT2FN
30
+ from ...modeling_tf_outputs import TFBaseModelOutput
31
+ from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
32
+ from ...tf_utils import flatten, functional_layernorm
33
+ from ...utils import (
34
+ ModelOutput,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CONFIG_FOR_DOC = "SamConfig"
46
+ _CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
47
+
48
+
49
+ @dataclass
50
+ class TFSamVisionEncoderOutput(ModelOutput):
51
+ """
52
+ Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the projection
53
+ layer to the pooler_output.
54
+
55
+ Args:
56
+ image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
57
+ The image embeddings obtained by applying the projection layer to the pooler_output.
58
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
59
+ Sequence of hidden-states at the output of the last layer of the model.
60
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
61
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
62
+ the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
63
+
64
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
65
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
66
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
67
+ sequence_length)`.
68
+
69
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
70
+ heads.
71
+ """
72
+
73
+ image_embeds: tf.Tensor | None = None
74
+ last_hidden_state: Optional[tf.Tensor] = None
75
+ hidden_states: Tuple[tf.Tensor, ...] | None = None
76
+ attentions: Tuple[tf.Tensor, ...] | None = None
77
+
78
+
79
+ @dataclass
80
+ class TFSamImageSegmentationOutput(ModelOutput):
81
+ """
82
+ Base class for the Segment-Anything model's output.
83
+
84
+ Args:
85
+ iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
86
+ The iou scores of the predicted masks.
87
+ pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
88
+ The predicted low-resolution masks, which need to be post-processed by the processor.
89
+ vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
91
+ the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
92
+
93
+ Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
94
+ vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
96
+ sequence_length)`.
97
+
98
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99
+ heads.
100
+ mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
101
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
102
+ sequence_length)`.
103
+
104
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
105
+ heads.
106
+ """
107
+
108
+ iou_scores: Optional[tf.Tensor] = None
109
+ pred_masks: Optional[tf.Tensor] = None
110
+ vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
111
+ vision_attentions: Tuple[tf.Tensor, ...] | None = None
112
+ mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None
113
+
114
+
115
+ class TFSamPatchEmbeddings(keras.layers.Layer):
116
+ """
117
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
118
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
119
+ Transformer.
120
+ """
121
+
122
+ def __init__(self, config, **kwargs):
123
+ super().__init__(**kwargs)
124
+ image_size, patch_size = config.image_size, config.patch_size
125
+ num_channels, hidden_size = config.num_channels, config.hidden_size
126
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
127
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
128
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.num_patches = num_patches
133
+
134
+ self.projection = keras.layers.Conv2D(
135
+ hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
136
+ )
137
+
138
+ def call(self, pixel_values):
139
+ batch_size, num_channels, height, width = shape_list(pixel_values)
140
+ if num_channels != self.num_channels:
141
+ raise ValueError(
142
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
143
+ )
144
+ if height != self.image_size[0] or width != self.image_size[1]:
145
+ raise ValueError(
146
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
147
+ )
148
+ embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
149
+ return embeddings
150
+
151
+ def build(self, input_shape=None):
152
+ if self.built:
153
+ return
154
+ self.built = True
155
+ if getattr(self, "projection", None) is not None:
156
+ with tf.name_scope(self.projection.name):
157
+ self.projection.build([None, None, None, self.num_channels])
158
+
159
+
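+ # Shape sketch for TFSamPatchEmbeddings (assuming the SAM defaults image_size=1024,
+ # patch_size=16): a channels-first (batch, 3, 1024, 1024) input is transposed to
+ # channels-last and projected by a stride-16 conv, giving (batch, 64, 64, hidden_size),
+ # i.e. num_patches = (1024 // 16) ** 2 = 4096 per image.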
160
+ class TFSamMLPBlock(keras.layers.Layer):
161
+ def __init__(self, config, **kwargs):
162
+ super().__init__(**kwargs)
163
+ self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
164
+ self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
165
+ self.act = ACT2FN[config.hidden_act]
166
+ self.config = config
167
+
168
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
169
+ hidden_states = self.lin1(hidden_states)
170
+ hidden_states = self.act(hidden_states)
171
+ hidden_states = self.lin2(hidden_states)
172
+ return hidden_states
173
+
174
+ def build(self, input_shape=None):
175
+ if self.built:
176
+ return
177
+ self.built = True
178
+ if getattr(self, "lin1", None) is not None:
179
+ with tf.name_scope(self.lin1.name):
180
+ self.lin1.build([None, None, self.config.hidden_size])
181
+ if getattr(self, "lin2", None) is not None:
182
+ with tf.name_scope(self.lin2.name):
183
+ self.lin2.build([None, None, self.config.mlp_dim])
184
+
185
+
186
+ class TFSamLayerNorm(keras.layers.Layer):
187
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
188
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
189
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
190
+ """
191
+
192
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
193
+ super().__init__(**kwargs)
194
+ self.eps = eps
195
+ self.data_format = data_format
196
+ self.normalized_shape = normalized_shape
197
+ if self.data_format not in ["channels_last", "channels_first"]:
198
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
199
+
200
+ def build(self, input_shape):
201
+ self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
202
+ self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
203
+ super().build(input_shape)
204
+
205
+ def call(self, x: tf.Tensor) -> tf.Tensor:
206
+ if self.data_format == "channels_last":
207
+ x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
208
+ elif self.data_format == "channels_first":
209
+ x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
210
+ return x
211
+
212
+
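+ # Usage sketch for TFSamLayerNorm (illustrative values only): the layer normalizes axis -1
+ # for channels_last inputs and axis 1 for channels_first inputs.
+ #
+ # >>> ln = TFSamLayerNorm((64,), data_format="channels_first")
+ # >>> y = ln(tf.random.normal((2, 64, 8, 8)))  # (batch, channels, height, width)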
213
+ class TFSamAttention(keras.layers.Layer):
214
+ """
215
+ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
216
+ values.
217
+ """
218
+
219
+ def __init__(self, config, downsample_rate=None, **kwargs):
220
+ super().__init__(**kwargs)
221
+ self.hidden_size = config.hidden_size
222
+
223
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
224
+
225
+ self.internal_dim = config.hidden_size // downsample_rate
226
+ self.num_attention_heads = config.num_attention_heads
227
+ if self.internal_dim % config.num_attention_heads != 0:
228
+ raise ValueError("num_attention_heads must divide hidden_size.")
229
+
230
+ self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
231
+ self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
232
+ self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
233
+ self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")
234
+
235
+ def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
236
+ batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
237
+ c_per_head = channel // num_attention_heads
238
+ hidden_states = tf.reshape(
239
+ hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
240
+ )
241
+ return tf.transpose(hidden_states, perm=[0, 2, 1, 3])
242
+
243
+ def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
244
+ batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
245
+ hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
246
+ return tf.reshape(
247
+ hidden_states,
248
+ (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
249
+ )
250
+
251
+ def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
252
+ # Input projections
253
+ query = self.q_proj(query)
254
+ key = self.k_proj(key)
255
+ value = self.v_proj(value)
256
+
257
+ point_batch_size = shape_list(query)[1]
258
+ # Separate into heads
259
+ query = self._separate_heads(query, self.num_attention_heads)
260
+ key = self._separate_heads(key, self.num_attention_heads)
261
+ value = self._separate_heads(value, self.num_attention_heads)
262
+
263
+ # SamAttention
264
+ _, _, _, c_per_head = shape_list(query)
265
+ attn = tf.matmul(
266
+ query, tf.transpose(key, perm=[0, 1, 3, 2])
267
+ ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
268
+ attn = attn / tf.math.sqrt(float(c_per_head))
269
+ attn = tf.nn.softmax(attn, axis=-1)
270
+
271
+ # Get output
272
+ out = tf.matmul(attn, value)
273
+ out = self._recombine_heads(out, point_batch_size)
274
+ out = self.out_proj(out)
275
+
276
+ return out
277
+
278
+ def build(self, input_shape=None):
279
+ if self.built:
280
+ return
281
+ self.built = True
282
+ if getattr(self, "q_proj", None) is not None:
283
+ with tf.name_scope(self.q_proj.name):
284
+ self.q_proj.build([None, None, self.hidden_size])
285
+ if getattr(self, "k_proj", None) is not None:
286
+ with tf.name_scope(self.k_proj.name):
287
+ self.k_proj.build([None, None, self.hidden_size])
288
+ if getattr(self, "v_proj", None) is not None:
289
+ with tf.name_scope(self.v_proj.name):
290
+ self.v_proj.build([None, None, self.hidden_size])
291
+ if getattr(self, "out_proj", None) is not None:
292
+ with tf.name_scope(self.out_proj.name):
293
+ self.out_proj.build([None, None, self.internal_dim])
294
+
295
+
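+ # Shape sketch for TFSamAttention (downsample_rate=2 is the SAM default): inputs of shape
+ # (batch, point_batch_size, n_tokens, hidden_size) are projected to internal_dim, split by
+ # _separate_heads into (batch * point_batch_size, num_heads, n_tokens, internal_dim // num_heads),
+ # attended, then restored by _recombine_heads and projected back to hidden_size.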
296
+ class TFSamTwoWayAttentionBlock(keras.layers.Layer):
297
+ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
298
+ """
299
+ A transformer block with four layers:
300
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
301
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
302
+
303
+ Arguments:
304
+ config (`SamMaskDecoderConfig`):
305
+ The configuration file used to instantiate the block
306
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
307
+ The downsample ratio of the block used to reduce the inner dim of the attention.
308
+ skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
309
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
310
+ """
311
+ super().__init__(**kwargs)
312
+
313
+ self.hidden_size = config.hidden_size
314
+ self.layer_norm_eps = config.layer_norm_eps
315
+
316
+ self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
317
+ self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
318
+
319
+ self.cross_attn_token_to_image = TFSamAttention(
320
+ config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
321
+ )
322
+ self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
323
+
324
+ self.mlp = TFSamMLPBlock(config, name="mlp")
325
+ self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
326
+
327
+ self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
328
+ self.cross_attn_image_to_token = TFSamAttention(
329
+ config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
330
+ )
331
+
332
+ self.skip_first_layer_pe = skip_first_layer_pe
333
+
334
+ def call(
335
+ self,
336
+ queries: tf.Tensor,
337
+ keys: tf.Tensor,
338
+ query_point_embedding: tf.Tensor,
339
+ key_point_embedding: tf.Tensor,
340
+ output_attentions: bool = False,
341
+ ):
342
+ # Self attention block
343
+ if self.skip_first_layer_pe:
344
+ queries = self.self_attn(query=queries, key=queries, value=queries)
345
+ else:
346
+ query = queries + query_point_embedding
347
+ attn_out = self.self_attn(query=query, key=query, value=queries)
348
+ queries = queries + attn_out
349
+ queries = self.layer_norm1(queries)
350
+
351
+ # Cross attention block, tokens attending to image embedding
352
+ query = queries + query_point_embedding
353
+ key = keys + key_point_embedding
354
+
355
+ attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
356
+ queries = queries + attn_out
357
+
358
+ queries = self.layer_norm2(queries)
359
+
360
+ # MLP block
361
+ mlp_out = self.mlp(queries)
362
+ queries = queries + mlp_out
363
+ queries = self.layer_norm3(queries)
364
+
365
+ # Cross attention block, image embedding attending to tokens
366
+ query = queries + query_point_embedding
367
+ key = keys + key_point_embedding
368
+
369
+ attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
370
+ keys = keys + attn_out
371
+
372
+ keys = self.layer_norm4(keys)
373
+
374
+ outputs = (queries, keys)
375
+
376
+ if output_attentions:
377
+ outputs = outputs + (attn_out,)
378
+ else:
379
+ outputs = outputs + (None,)
380
+
381
+ return outputs
382
+
383
+ def build(self, input_shape=None):
384
+ if self.built:
385
+ return
386
+ self.built = True
387
+ if getattr(self, "self_attn", None) is not None:
388
+ with tf.name_scope(self.self_attn.name):
389
+ self.self_attn.build(None)
390
+ if getattr(self, "layer_norm1", None) is not None:
391
+ with tf.name_scope(self.layer_norm1.name):
392
+ self.layer_norm1.build([None, None, None, self.hidden_size])
393
+ if getattr(self, "cross_attn_token_to_image", None) is not None:
394
+ with tf.name_scope(self.cross_attn_token_to_image.name):
395
+ self.cross_attn_token_to_image.build(None)
396
+ if getattr(self, "layer_norm2", None) is not None:
397
+ with tf.name_scope(self.layer_norm2.name):
398
+ self.layer_norm2.build([None, None, None, self.hidden_size])
399
+ if getattr(self, "mlp", None) is not None:
400
+ with tf.name_scope(self.mlp.name):
401
+ self.mlp.build(None)
402
+ if getattr(self, "layer_norm3", None) is not None:
403
+ with tf.name_scope(self.layer_norm3.name):
404
+ self.layer_norm3.build([None, None, None, self.hidden_size])
405
+ if getattr(self, "layer_norm4", None) is not None:
406
+ with tf.name_scope(self.layer_norm4.name):
407
+ self.layer_norm4.build([None, None, None, self.hidden_size])
408
+ if getattr(self, "cross_attn_image_to_token", None) is not None:
409
+ with tf.name_scope(self.cross_attn_image_to_token.name):
410
+ self.cross_attn_image_to_token.build(None)
411
+
412
+
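+ # Control-flow sketch for one two-way block (a sketch of the call above): per call,
+ # (1) self-attention over the prompt queries, (2) the queries cross-attend to the image keys,
+ # (3) an MLP on the queries, (4) the image keys cross-attend back to the queries, with a
+ # LayerNorm after each step; positional embeddings are re-added before every attention.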
413
+ class TFSamTwoWayTransformer(keras.layers.Layer):
414
+ def __init__(self, config: SamMaskDecoderConfig, **kwargs):
415
+ super().__init__(**kwargs)
416
+ self.config = config
417
+
418
+ self.num_hidden_layers = config.num_hidden_layers
419
+ self.layers = []
420
+
421
+ for i in range(self.num_hidden_layers):
422
+ self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
423
+
424
+ self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
425
+ self.layer_norm_final_attn = keras.layers.LayerNormalization(
426
+ epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
427
+ )
428
+
429
+ def call(
430
+ self,
431
+ point_embeddings: tf.Tensor,
432
+ image_embeddings: tf.Tensor,
433
+ image_positional_embeddings: tf.Tensor,
434
+ output_attentions: Optional[bool] = None,
435
+ output_hidden_states: Optional[bool] = None,
436
+ return_dict: Optional[bool] = None,
437
+ ) -> Union[Tuple, TFBaseModelOutput]:
438
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
439
+ output_hidden_states = (
440
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
441
+ )
442
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
443
+
444
+ all_attentions = ()
445
+
446
+ if image_embeddings is None:
447
+ raise ValueError("You have to specify an image_embedding")
448
+
449
+ image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
450
+ image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
451
+
452
+ # Prepare queries
453
+ queries = point_embeddings
454
+ keys = image_embeddings
455
+
456
+ # Apply transformer blocks and final layernorm
457
+ for layer in self.layers:
458
+ queries, keys, attention_outputs = layer(
459
+ queries=queries,
460
+ keys=keys,
461
+ query_point_embedding=point_embeddings,
462
+ key_point_embedding=image_positional_embeddings,
463
+ output_attentions=output_attentions,
464
+ )
465
+
466
+ if output_attentions:
467
+ all_attentions = all_attentions + (attention_outputs,)
468
+
469
+ # Apply the final attention layer from the points to the image
470
+ query = queries + point_embeddings
471
+ key = keys + image_positional_embeddings
472
+
473
+ attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
474
+
475
+ queries = queries + attn_out
476
+ queries = self.layer_norm_final_attn(queries)
477
+ return queries, keys, all_attentions
478
+
479
+ def build(self, input_shape=None):
480
+ if self.built:
481
+ return
482
+ self.built = True
483
+ if getattr(self, "final_attn_token_to_image", None) is not None:
484
+ with tf.name_scope(self.final_attn_token_to_image.name):
485
+ self.final_attn_token_to_image.build(None)
486
+ if getattr(self, "layer_norm_final_attn", None) is not None:
487
+ with tf.name_scope(self.layer_norm_final_attn.name):
488
+ self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
489
+ for layer in self.layers:
490
+ with tf.name_scope(layer.name):
491
+ layer.build(None)
492
+
493
+
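+ # Shape note for TFSamTwoWayTransformer (assuming the default 64x64 embedding grid):
+ # (batch, channel, 64, 64) image embeddings are flattened to (batch, 1, 4096, channel) so
+ # the point tokens can attend over all spatial positions, and a final token-to-image
+ # attention plus LayerNorm refines the queries before they are returned.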
494
+ class TFSamFeedForward(keras.layers.Layer):
495
+ def __init__(
496
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
497
+ ):
498
+ super().__init__(**kwargs)
499
+ self.num_layers = num_layers
500
+ self.activation = keras.layers.ReLU()
501
+ self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
502
+ self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
503
+ self.layers = [
504
+ keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
505
+ for i in range(num_layers - 2)
506
+ ]
507
+ self.sigmoid_output = sigmoid_output
508
+ self.hidden_dim = hidden_dim
509
+ self.input_dim = input_dim
510
+
511
+ def call(self, hidden_states):
512
+ hidden_states = self.proj_in(hidden_states)
513
+ hidden_states = self.activation(hidden_states)
514
+ for layer in self.layers:
515
+ hidden_states = self.activation(layer(hidden_states))
516
+
517
+ hidden_states = self.proj_out(hidden_states)
518
+ if self.sigmoid_output:
519
+ hidden_states = tf.sigmoid(hidden_states)
520
+ return hidden_states
521
+
522
+ def build(self, input_shape=None):
523
+ if self.built:
524
+ return
525
+ self.built = True
526
+ if getattr(self, "proj_in", None) is not None:
527
+ with tf.name_scope(self.proj_in.name):
528
+ self.proj_in.build([None, None, self.input_dim])
529
+ if getattr(self, "proj_out", None) is not None:
530
+ with tf.name_scope(self.proj_out.name):
531
+ self.proj_out.build([None, None, self.hidden_dim])
532
+ if getattr(self, "layers", None) is not None:
533
+ for layer in self.layers:
534
+ with tf.name_scope(layer.name):
535
+ layer.build([None, None, self.hidden_dim])
536
+
537
+
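+ # Layer-count sketch for TFSamFeedForward: num_layers counts proj_in and proj_out, so the
+ # loop adds num_layers - 2 hidden Dense layers. For example (hypothetical sizes):
+ # >>> mlp = TFSamFeedForward(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
+ # builds proj_in -> one hidden layer -> proj_out, with ReLU between the Dense layers.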
538
+ class TFSamMaskDecoder(keras.layers.Layer):
539
+ def __init__(self, config: SamMaskDecoderConfig, **kwargs):
540
+ super().__init__(**kwargs)
541
+
542
+ self.hidden_size = config.hidden_size
543
+
544
+ self.num_multimask_outputs = config.num_multimask_outputs
545
+ self.num_mask_tokens = config.num_multimask_outputs + 1
546
+
547
+ self.transformer = TFSamTwoWayTransformer(config, name="transformer")
548
+
549
+ self.upscale_conv1 = keras.layers.Conv2DTranspose(
550
+ self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
551
+ )
552
+ self.upscale_conv2 = keras.layers.Conv2DTranspose(
553
+ self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
554
+ )
555
+ self.upscale_layer_norm = TFSamLayerNorm(
556
+ self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
557
+ )
558
+ self.activation = tf.nn.gelu
559
+
560
+ mlps_list = []
561
+ for i in range(self.num_mask_tokens):
562
+ mlps_list += [
563
+ TFSamFeedForward(
564
+ self.hidden_size,
565
+ self.hidden_size,
566
+ self.hidden_size // 8,
567
+ 3,
568
+ name=f"output_hypernetworks_mlps_._{i}",
569
+ )
570
+ ]
571
+ self.output_hypernetworks_mlps = mlps_list
572
+
573
+ self.iou_prediction_head = TFSamFeedForward(
574
+ self.hidden_size,
575
+ config.iou_head_hidden_dim,
576
+ self.num_mask_tokens,
577
+ config.iou_head_depth,
578
+ name="iou_prediction_head",
579
+ )
580
+
581
+ def build(self, input_shape=None):
582
+ if self.built:
583
+ return
584
+ self.built = True
585
+ self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
586
+ self.mask_tokens = self.add_weight(
587
+ shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
588
+ )
589
+
590
+ if getattr(self, "transformer", None) is not None:
591
+ with tf.name_scope(self.transformer.name):
592
+ self.transformer.build(None)
593
+ if getattr(self, "upscale_conv1", None) is not None:
594
+ with tf.name_scope(self.upscale_conv1.name):
595
+ self.upscale_conv1.build([None, self.hidden_size, None, None])
596
+ if getattr(self, "upscale_conv2", None) is not None:
597
+ with tf.name_scope(self.upscale_conv2.name):
598
+ self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
599
+ if getattr(self, "upscale_layer_norm", None) is not None:
600
+ with tf.name_scope(self.upscale_layer_norm.name):
601
+ self.upscale_layer_norm.build(None)
602
+ if getattr(self, "iou_prediction_head", None) is not None:
603
+ with tf.name_scope(self.iou_prediction_head.name):
604
+ self.iou_prediction_head.build(None)
605
+ for mlp in self.output_hypernetworks_mlps:
606
+ with tf.name_scope(mlp.name):
607
+ mlp.build(None)
608
+
609
+ def call(
610
+ self,
611
+ image_embeddings: tf.Tensor,
612
+ image_positional_embeddings: tf.Tensor,
613
+ sparse_prompt_embeddings: tf.Tensor,
614
+ dense_prompt_embeddings: tf.Tensor,
615
+ multimask_output: bool,
616
+ output_attentions: Optional[bool] = None,
617
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
618
+ batch_size, num_channels, height, width = shape_list(image_embeddings)
619
+ point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])
620
+
621
+ output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 256) + (4, 256) = (5, 256) with the default config
622
+ output_tokens = tf.tile(
623
+ output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
624
+ ) # Should be (batch_size, point_batch_size, 5, 256) with the default config
625
+
626
+ # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
627
+ # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
628
+ # it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
629
+ if shape_list(sparse_prompt_embeddings)[1] != 0:
630
+ tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
631
+ else:
632
+ tokens = output_tokens
633
+ point_embeddings = tf.cast(tokens, self.iou_token.dtype)
634
+
635
+ image_embeddings = image_embeddings + dense_prompt_embeddings
636
+ image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
637
+ image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
638
+
639
+ point_embedding, image_embeddings, attentions = self.transformer(
640
+ point_embeddings=point_embeddings,
641
+ image_embeddings=image_embeddings,
642
+ image_positional_embeddings=image_positional_embeddings,
643
+ output_attentions=output_attentions,
644
+ )
645
+ iou_token_out = point_embedding[:, :, 0, :]
646
+ mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
647
+
648
+ image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
649
+ image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])
650
+
651
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
652
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
653
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
654
+
655
+ hyper_in_list = []
656
+ for i in range(self.num_mask_tokens):
657
+ current_mlp = self.output_hypernetworks_mlps[i]
658
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
659
+ hyper_in = tf.stack(hyper_in_list, axis=2)
660
+
661
+ _, num_channels, height, width = shape_list(upscaled_embedding)
662
+ upscaled_embedding = tf.reshape(
663
+ upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
664
+ )
665
+ masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])
666
+
667
+ iou_pred = self.iou_prediction_head(iou_token_out)
668
+
669
+ if multimask_output:
670
+ mask_slice = slice(1, None)
671
+ else:
672
+ mask_slice = slice(0, 1)
673
+ masks = masks[:, :, mask_slice, :, :]
674
+ iou_pred = iou_pred[:, :, mask_slice]
675
+
676
+ outputs = (masks, iou_pred)
677
+
678
+ if output_attentions:
679
+ outputs = outputs + (attentions,)
680
+ else:
681
+ outputs = outputs + (None,)
682
+
683
+ return outputs
684
+
685
+
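+ # Token-layout sketch for TFSamMaskDecoder (default config: num_multimask_outputs=3, so
+ # num_mask_tokens=4): the transformer consumes [iou_token | mask tokens | sparse prompts];
+ # afterwards token 0 feeds the IoU head and tokens 1..num_mask_tokens feed the per-mask
+ # hypernetworks. multimask_output=True keeps masks 1..3, otherwise only mask 0 is returned.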
686
+ class TFSamPositionalEmbedding(keras.layers.Layer):
687
+ def __init__(self, config, **kwargs):
688
+ super().__init__(**kwargs)
689
+ self.scale = config.hidden_size // 2
690
+ self.config = config
691
+
692
+ def build(self, input_shape):
693
+ # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
694
+ self.positional_embedding = self.add_weight(
695
+ name="positional_embedding",
696
+ shape=(2, self.config.num_pos_feats),
697
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
698
+ trainable=False,
699
+ )
700
+ super().build(input_shape)
701
+
702
+ def call(self, input_coords, input_shape=None):
703
+ """Positionally encode points that are normalized to [0,1]."""
704
+ coordinates = tf.identity(input_coords)
705
+
706
+ if input_shape is not None:
707
+ coordinates = tf.stack(
708
+ [
709
+ tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
710
+ tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
711
+ ],
712
+ axis=-1,
713
+ )
714
+
715
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
716
+ coordinates = 2 * coordinates - 1
717
+ coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
718
+ coordinates = tf.matmul(coordinates, self.positional_embedding)
719
+ coordinates = 2 * np.pi * coordinates
720
+ # outputs d_1 x ... x d_n x channel shape
721
+ return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
722
+
723
+
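+ # Encoding sketch for TFSamPositionalEmbedding (random Fourier features): a point (x, y)
+ # normalized to [0, 1] is mapped to f = 2 * pi * ((2 * (x, y) - 1) @ W), with W a fixed
+ # random (2, num_pos_feats) matrix, and encoded as concat(sin(f), cos(f)), giving
+ # 2 * num_pos_feats channels per point.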
724
+ class TFSamMaskEmbedding(keras.layers.Layer):
725
+ def __init__(self, config: SamPromptEncoderConfig, **kwargs):
726
+ super().__init__(**kwargs)
727
+ self.mask_input_channels = config.mask_input_channels // 4
728
+ self.activation = ACT2FN[config.hidden_act]
729
+ self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
730
+ self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
731
+ self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
732
+ self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
733
+ self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
734
+ self.config = config
735
+
736
+ def call(self, masks):
737
+ masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last
738
+ hidden_states = self.conv1(masks)
739
+ hidden_states = self.layer_norm1(hidden_states)
740
+ hidden_states = self.activation(hidden_states)
741
+
742
+ hidden_states = self.conv2(hidden_states)
743
+ hidden_states = self.layer_norm2(hidden_states)
744
+ hidden_states = self.activation(hidden_states)
745
+ dense_embeddings = self.conv3(hidden_states)
746
+ dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first
747
+ return dense_embeddings
748
+
749
+ def build(self, input_shape=None):
750
+ # This class needs an explicit build method because it isn't called with the standard dummy inputs
751
+ if self.built:
752
+ return
753
+ self.built = True
754
+ with tf.name_scope("conv1"):
755
+ self.conv1.build([None, None, None, 1])
756
+ with tf.name_scope("conv2"):
757
+ self.conv2.build([None, None, None, self.mask_input_channels])
758
+ with tf.name_scope("conv3"):
759
+ self.conv3.build([None, None, None, self.mask_input_channels * 4])
760
+ with tf.name_scope("layer_norm1"):
761
+ self.layer_norm1.build([None, None, None, self.mask_input_channels])
762
+ with tf.name_scope("layer_norm2"):
763
+ self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
764
+
765
+
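+ # Shape sketch for TFSamMaskEmbedding (assuming the default 256x256 mask input and
+ # mask_input_channels=16): conv1 (stride 2) -> (batch, 128, 128, 4), conv2 (stride 2) ->
+ # (batch, 64, 64, 16), conv3 (1x1) -> (batch, 64, 64, hidden_size), matching the 64x64
+ # image-embedding grid before the transpose back to channels-first.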
766
+ class TFSamPromptEncoder(keras.layers.Layer):
767
+ def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
768
+ super().__init__(**kwargs)
769
+ self.shared_embedding = shared_patch_embedding
770
+ self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
771
+ self.no_mask_embed = None
772
+
773
+ self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
774
+ self.input_image_size = config.image_size
775
+
776
+ self.point_embed = []
777
+ self.hidden_size = config.hidden_size
778
+ self.not_a_point_embed = None
779
+ self.config = config
780
+
781
+ def build(self, input_shape=None):
782
+ self.no_mask_embed = self.add_weight(
783
+ name="no_mask_embed.weight",
784
+ shape=(1, self.hidden_size),
785
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
786
+ trainable=True,
787
+ )
788
+ self.point_embed = [
789
+ self.add_weight(
790
+ name=f"point_embed_._{i}.weight",
791
+ shape=(1, self.hidden_size),
792
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
793
+ trainable=True,
794
+ )
795
+ for i in range(self.config.num_point_embeddings)
796
+ ]
797
+ self.not_a_point_embed = self.add_weight(
798
+ name="not_a_point_embed.weight",
799
+ shape=(1, self.hidden_size),
800
+ initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
801
+ trainable=True,
802
+ )
803
+ with tf.name_scope("mask_embed"):
804
+ # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
805
+ self.mask_embed.build(
806
+ (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
807
+ )
808
+
809
+ if self.built:
810
+ return
811
+ self.built = True
812
+ if getattr(self, "mask_embed", None) is not None:
813
+ with tf.name_scope(self.mask_embed.name):
814
+ self.mask_embed.build(None)
815
+
816
+ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
817
+ """Embeds point prompts."""
818
+ points = points + 0.5 # Shift to center of pixel
819
+ if pad:
820
+ target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
821
+ target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
822
+ padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
823
+ padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
824
+ points = tf.concat([points, padding_point], axis=2)
825
+ labels = tf.concat([labels, padding_label], axis=2)
826
+ input_shape = (self.input_image_size, self.input_image_size)
827
+ point_embedding = self.shared_embedding(points, input_shape)
828
+
829
+ point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
830
+
831
+ point_embedding = tf.where(
832
+ labels[..., None] != -10,
833
+ point_embedding,
834
+ tf.zeros_like(point_embedding),
835
+ )
836
+ point_embedding = tf.where(
837
+ (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
838
+ )
839
+ point_embedding = tf.where(
840
+ (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
841
+ )
842
+ return point_embedding
843
+
844
+ def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
845
+ """Embeds box prompts."""
846
+ boxes = boxes + 0.5 # Shift to center of pixel
847
+ batch_size, nb_boxes = shape_list(boxes)[:2]
848
+ coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
849
+ input_shape = (self.input_image_size, self.input_image_size)
850
+ corner_embedding = self.shared_embedding(coords, input_shape)
851
+ corner_embedding += tf.where(
852
+ tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
853
+ self.point_embed[2][0],
854
+ self.point_embed[3][0],
855
+ )
856
+ return corner_embedding
857
+
858
+ def call(
859
+ self,
860
+ batch_size: Optional[int],
861
+ input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
862
+ input_labels: tf.Tensor | None,
863
+ input_boxes: tf.Tensor | None,
864
+ input_masks: tf.Tensor | None,
865
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
866
+ """
867
+ Embeds different types of prompts, returning both sparse and dense embeddings.
868
+
869
+ Args:
870
+ input_points (`tf.Tensor`, *optional*):
871
+ Point coordinates and labels to embed.
872
+ input_boxes (`tf.Tensor`, *optional*):
873
+ Boxes to embed.
874
+ input_masks (`tf.Tensor`, *optional*):
875
+ Masks to embed.
876
+ """
877
+ sparse_embeddings = None
878
+ if input_points is not None:
879
+ batch_size, point_batch_size = shape_list(input_points)[:2]
880
+ if input_labels is None:
881
+ raise ValueError("If points are provided, labels must also be provided.")
882
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
883
+ sparse_embeddings = tf.zeros(
884
+ (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
885
+ )
886
+ sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
887
+ if input_boxes is not None:
888
+ batch_size = shape_list(input_boxes)[0]
889
+ box_embeddings = self._embed_boxes(input_boxes)
890
+ if sparse_embeddings is None:
891
+ sparse_embeddings = box_embeddings
892
+ else:
893
+ sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
894
+ if input_masks is not None:
895
+ dense_embeddings = self.mask_embed(input_masks)
896
+ else:
897
+ dense_embeddings = self.no_mask_embed[0]
898
+ dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
899
+ dense_embeddings = tf.tile(
900
+ dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
901
+ )
902
+ if sparse_embeddings is None:
903
+ sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
904
+
905
+ return sparse_embeddings, dense_embeddings
906
+
907
+
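+ # Output sketch for TFSamPromptEncoder: sparse_embeddings stacks point and box tokens along
+ # axis 2 as (batch, point_batch_size, n_tokens, hidden_size); dense_embeddings is either the
+ # embedded input mask or the learned no_mask_embed tiled to (batch, hidden_size, 64, 64),
+ # the image_embedding_size grid consumed by the mask decoder.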
908
+ class TFSamVisionAttention(keras.layers.Layer):
909
+ """Multi-head Attention block with relative position embeddings."""
910
+
911
+ def __init__(self, config, window_size, **kwargs):
912
+ super().__init__(**kwargs)
913
+ input_size = (
914
+ (config.image_size // config.patch_size, config.image_size // config.patch_size)
915
+ if window_size == 0
916
+ else (window_size, window_size)
917
+ )
918
+ self.input_size = input_size
919
+
920
+ self.num_attention_heads = config.num_attention_heads
921
+ head_dim = config.hidden_size // config.num_attention_heads
922
+ self.head_dim = head_dim
923
+ self.scale = head_dim**-0.5
924
+ self.dropout = config.attention_dropout
925
+
926
+ self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
927
+ self.proj = keras.layers.Dense(config.hidden_size, name="proj")
928
+
929
+ self.use_rel_pos = config.use_rel_pos
930
+ if self.use_rel_pos:
931
+ if input_size is None:
932
+ raise ValueError("Input size must be provided if using relative positional encoding.")
933
+ self.config = config
934
+
935
+ def build(self, input_shape=None):
936
+ if self.input_size is not None:
937
+ # initialize relative positional embeddings
938
+ self.rel_pos_h = self.add_weight(
939
+ shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
940
+ )
941
+ self.rel_pos_w = self.add_weight(
942
+ shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
943
+ )
944
+
945
+ if self.built:
946
+ return
947
+ self.built = True
948
+ if getattr(self, "qkv", None) is not None:
949
+ with tf.name_scope(self.qkv.name):
950
+ self.qkv.build([None, None, self.config.hidden_size])
951
+ if getattr(self, "proj", None) is not None:
952
+ with tf.name_scope(self.proj.name):
953
+ self.proj.build([None, None, self.config.hidden_size])
954
+
955
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
956
+ """
957
+ Get relative positional embeddings according to the relative positions of
958
+ query and key sizes.
959
+
960
+ Args:
961
+ q_size (int):
962
+ size of the query.
963
+ k_size (int):
964
+ size of key k.
965
+ rel_pos (`tf.Tensor`):
966
+ relative position embeddings (L, channel).
967
+
968
+ Returns:
969
+ Extracted positional embeddings according to relative positions.
970
+ """
971
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
972
+ # Interpolate rel pos if needed.
973
+ if rel_pos.shape[0] != max_rel_dist:
974
+ # Interpolate rel pos.
975
+ rel_pos_resized = tf.image.resize(
976
+ tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
977
+ size=(max_rel_dist, rel_pos.shape[1]),
978
+ method="bilinear",
979
+ )
980
+ rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
981
+ else:
982
+ rel_pos_resized = rel_pos
983
+
984
+ # Scale the coords with short length if shapes for q and k are different.
985
+ q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
986
+ k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
987
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
988
+
989
+ return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
990
+
991
+ def get_decomposed_rel_pos(
992
+ self,
993
+ query: tf.Tensor,
994
+ rel_pos_h: tf.Tensor,
995
+ rel_pos_w: tf.Tensor,
996
+ q_size: Tuple[int, int],
997
+ k_size: Tuple[int, int],
998
+ ) -> tf.Tensor:
999
+ """
1000
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
1001
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
1002
+
1003
+ Args:
1004
+ query (`tf.Tensor`):
1005
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
1006
+ rel_pos_h (`tf.Tensor`):
1007
+ relative position embeddings (Lh, channel) for height axis.
1008
+ rel_pos_w (`tf.Tensor`):
1009
+ relative position embeddings (Lw, channel) for width axis.
1010
+ q_size (tuple):
1011
+ spatial sequence size of query q with (query_height, query_width).
1012
+ k_size (tuple):
1013
+ spatial sequence size of key k with (key_height, key_width).
1014
+
1015
+ Returns:
1016
+ decomposed_rel_pos (`tf.Tensor`):
1017
+ decomposed relative position embeddings.
1018
+ """
1019
+ query_height, query_width = q_size
1020
+ key_height, key_width = k_size
1021
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
1022
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
1023
+
1024
+ batch_size, _, dim = shape_list(query)
1025
+ reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
1026
+ rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
1027
+ rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
1028
+
1029
+ rel_h = tf.expand_dims(rel_h, axis=-1)
1030
+ rel_w = tf.expand_dims(rel_w, axis=-2)
1031
+ decomposed_rel_pos = rel_h + rel_w
1032
+
1033
+ return decomposed_rel_pos
1034
+
1035
+ def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
1036
+ batch_size, height, width, _ = shape_list(hidden_states)
1037
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
1038
+ qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
1039
+ qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
1040
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
1041
+ query, key, value = tf.unstack(
1042
+ tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
1043
+ )
1044
+ attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
1045
+
1046
+ if self.use_rel_pos:
1047
+ decomposed_rel_pos = self.get_decomposed_rel_pos(
1048
+ query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
1049
+ )
1050
+ decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
1051
+ attn_weights = attn_weights + decomposed_rel_pos
1052
+
1053
+ attn_weights = tf.nn.softmax(attn_weights, axis=-1)
1054
+
1055
+ if training:
1056
+ attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
1057
+ else:
1058
+ attn_probs = attn_weights
1059
+
1060
+ attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
1061
+ attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
1062
+ attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))
1063
+
1064
+ attn_output = self.proj(attn_output)
1065
+
1066
+ if output_attentions:
1067
+ outputs = (attn_output, attn_weights)
1068
+ else:
1069
+ outputs = (attn_output, None)
1070
+
1071
+ return outputs
1072
+
1073
+
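+ # Shape sketch for the decomposed relative-position bias above (per MViTv2): with a
+ # (height, width) query grid, rel_h is (B, height, width, key_height) and rel_w is
+ # (B, height, width, key_width); broadcasting their sum gives
+ # (B, height, width, key_height, key_width), which is reshaped onto the attention logits.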
1074
+ class TFSamVisionLayer(keras.layers.Layer):
1075
+ def __init__(self, config, window_size, **kwargs):
1076
+ super().__init__(**kwargs)
1077
+ self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
1078
+ self.attn = TFSamVisionAttention(config, window_size, name="attn")
1079
+ self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
1080
+ self.mlp = TFSamMLPBlock(config, name="mlp")
1081
+ self.window_size = window_size
1082
+ self.config = config
1083
+
1084
+ def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
1085
+ batch_size, height, width, channel = shape_list(hidden_states)
1086
+
1087
+ pad_h = (window_size - height % window_size) % window_size
1088
+ pad_w = (window_size - width % window_size) % window_size
1089
+ if pad_h > 0 or pad_w > 0:
1090
+ hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
1091
+ pad_height, pad_width = height + pad_h, width + pad_w
1092
+
1093
+ hidden_states = tf.reshape(
1094
+ hidden_states,
1095
+ [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
1096
+ )
1097
+ windows = tf.reshape(
1098
+ tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
1099
+ )
1100
+ return windows, (pad_height, pad_width)
1101
+
1102
+ def window_unpartition(
1103
+ self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
1104
+ ) -> tf.Tensor:
1105
+ pad_height, pad_width = padding_shape
1106
+ height, width = original_shape
1107
+ batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
1108
+ hidden_states = tf.reshape(
1109
+ windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
1110
+ )
1111
+ hidden_states = tf.reshape(
1112
+ tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
1113
+ )
1114
+
1115
+ if pad_height > height or pad_width > width:
1116
+ hidden_states = hidden_states[:, :height, :width, :]
1117
+ return hidden_states
1118
+
1119
+ def call(
1120
+ self,
1121
+ hidden_states: tf.Tensor,
1122
+ output_attentions: Optional[bool] = False,
1123
+ training: Optional[bool] = False,
1124
+ ) -> Tuple[tf.Tensor]:
1125
+ residual = hidden_states
1126
+
1127
+ hidden_states = self.layer_norm1(hidden_states)
1128
+ if self.window_size > 0:
1129
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
1130
+ hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
1131
+
1132
+ hidden_states, attn_weights = self.attn(
1133
+ hidden_states=hidden_states,
1134
+ output_attentions=output_attentions,
1135
+ training=training,
1136
+ )
1137
+ if self.window_size > 0:
1138
+ hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
1139
+
1140
+ hidden_states = residual + hidden_states
1141
+ layernorm_output = self.layer_norm2(hidden_states)
1142
+ hidden_states = hidden_states + self.mlp(layernorm_output)
1143
+
1144
+ outputs = (hidden_states,)
1145
+ if output_attentions:
1146
+ outputs += (attn_weights,)
1147
+
1148
+ return outputs
1149
+
1150
+ def build(self, input_shape=None):
1151
+ if self.built:
1152
+ return
1153
+ self.built = True
1154
+ if getattr(self, "layer_norm1", None) is not None:
1155
+ with tf.name_scope(self.layer_norm1.name):
1156
+ self.layer_norm1.build([None, None, None, self.config.hidden_size])
1157
+ if getattr(self, "attn", None) is not None:
1158
+ with tf.name_scope(self.attn.name):
1159
+ self.attn.build(None)
1160
+ if getattr(self, "layer_norm2", None) is not None:
1161
+ with tf.name_scope(self.layer_norm2.name):
1162
+ self.layer_norm2.build([None, None, None, self.config.hidden_size])
1163
+ if getattr(self, "mlp", None) is not None:
1164
+ with tf.name_scope(self.mlp.name):
1165
+ self.mlp.build(None)
1166
+
1167
+
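+ # Round-trip sketch for window_partition / window_unpartition (window_size=14, the SAM
+ # default, on the 64x64 grid): the map is padded to 70x70, split into (70 / 14) ** 2 = 25
+ # local windows of shape (14, 14, channel) per image, attended over, then unpartitioned
+ # and cropped back to the original (batch, 64, 64, channel).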
1168
+ class TFSamVisionNeck(keras.layers.Layer):
1169
+ def __init__(self, config: SamVisionConfig, **kwargs):
1170
+ super().__init__(**kwargs)
1171
+ self.config = config
1172
+
1173
+ self.conv1 = keras.layers.Conv2D(
1174
+ config.output_channels,
1175
+ kernel_size=1,
1176
+ use_bias=False,
1177
+ name="conv1",
1178
+ )
1179
+ self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
1180
+ self.conv2 = keras.layers.Conv2D(
1181
+ config.output_channels,
1182
+ kernel_size=3,
1183
+ padding="same",
1184
+ use_bias=False,
1185
+ name="conv2",
1186
+ )
1187
+ self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
1188
+
1189
+ def call(self, hidden_states):
1190
+ hidden_states = self.conv1(hidden_states)
1191
+ hidden_states = self.layer_norm1(hidden_states)
1192
+
1193
+ hidden_states = self.conv2(hidden_states)
1194
+ hidden_states = self.layer_norm2(hidden_states)
1195
+ hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
1196
+ return hidden_states
1197
+
1198
+ def build(self, input_shape=None):
1199
+ if self.built:
1200
+ return
1201
+ self.built = True
1202
+ if getattr(self, "conv1", None) is not None:
1203
+ with tf.name_scope(self.conv1.name):
1204
+ self.conv1.build([None, None, None, self.config.hidden_size])
1205
+ if getattr(self, "layer_norm1", None) is not None:
1206
+ with tf.name_scope(self.layer_norm1.name):
1207
+ self.layer_norm1.build(None)
1208
+ if getattr(self, "conv2", None) is not None:
1209
+ with tf.name_scope(self.conv2.name):
1210
+ self.conv2.build([None, None, None, self.config.output_channels])
1211
+ if getattr(self, "layer_norm2", None) is not None:
1212
+ with tf.name_scope(self.layer_norm2.name):
1213
+ self.layer_norm2.build(None)
1214
+
1215
+
1216
+ class TFSamVisionEncoder(keras.layers.Layer):
1217
+ def __init__(self, config: SamVisionConfig, **kwargs):
1218
+ super().__init__(**kwargs)
1219
+ self.config = config
1220
+ self.image_size = config.image_size
1221
+
1222
+ self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")
1223
+
1224
+ self.pos_embed = None
1225
+
1226
+ self.layers = []
1227
+ for i in range(config.num_hidden_layers):
1228
+ layer = TFSamVisionLayer(
1229
+ config,
1230
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
1231
+ name=f"layers_._{i}",
1232
+ )
1233
+ self.layers.append(layer)
1234
+
1235
+ self.neck = TFSamVisionNeck(config, name="neck")
1236
+
1237
+ def build(self, input_shape=None):
1238
+ if self.built:
1239
+ return
1240
+ self.built = True
1241
+ if self.config.use_abs_pos:
1242
+ # Initialize absolute positional embedding with pretrain image size.
1243
+ self.pos_embed = self.add_weight(
1244
+ shape=[
1245
+ 1,
1246
+ self.config.image_size // self.config.patch_size,
1247
+ self.config.image_size // self.config.patch_size,
1248
+ self.config.hidden_size,
1249
+ ],
1250
+ initializer="zeros",
1251
+ trainable=True,
1252
+ name="pos_embed",
1253
+ )
1254
+
1255
+ if getattr(self, "patch_embed", None) is not None:
1256
+ with tf.name_scope(self.patch_embed.name):
1257
+ self.patch_embed.build(None)
1258
+ if getattr(self, "neck", None) is not None:
1259
+ with tf.name_scope(self.neck.name):
1260
+ self.neck.build(None)
1261
+ for layer in self.layers:
1262
+ with tf.name_scope(layer.name):
1263
+ layer.build(None)
1264
+
1265
+ def get_input_embeddings(self):
1266
+ return self.patch_embed
1267
+
1268
+ def call(
1269
+ self,
1270
+ pixel_values: tf.Tensor | None = None,
1271
+ output_attentions: Optional[bool] = None,
1272
+ output_hidden_states: Optional[bool] = None,
1273
+ return_dict: Optional[bool] = None,
1274
+ training: Optional[bool] = False,
1275
+ ) -> Union[Tuple, TFSamVisionEncoderOutput]:
1276
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1277
+ output_hidden_states = (
1278
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1279
+ )
1280
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1281
+
1282
+ if pixel_values is None:
1283
+ raise ValueError("You have to specify pixel_values")
1284
+
1285
+ hidden_states = self.patch_embed(pixel_values)
1286
+ if self.pos_embed is not None:
1287
+ hidden_states = hidden_states + self.pos_embed
1288
+
1289
+ all_hidden_states = () if output_hidden_states else None
1290
+ all_self_attentions = () if output_attentions else None
1291
+
1292
+ for i, layer_module in enumerate(self.layers):
1293
+ if output_hidden_states:
1294
+ all_hidden_states = all_hidden_states + (hidden_states,)
1295
+
1296
+ layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
1297
+
1298
+ hidden_states = layer_outputs[0]
1299
+
1300
+ if output_attentions:
1301
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1302
+
1303
+ if output_hidden_states:
1304
+ all_hidden_states = all_hidden_states + (hidden_states,)
1305
+
1306
+ hidden_states = self.neck(hidden_states)
1307
+
1308
+ if not return_dict:
1309
+ outputs = (hidden_states,)
1310
+ if output_hidden_states:
1311
+ outputs = outputs + (all_hidden_states,)
1312
+ if output_attentions:
1313
+ outputs = outputs + (all_self_attentions,)
1314
+ return outputs
1315
+
1316
+ return TFSamVisionEncoderOutput(
1317
+ last_hidden_state=hidden_states,
1318
+ hidden_states=all_hidden_states,
1319
+ attentions=all_self_attentions,
1320
+ )
1321
+
1322
+
1323
+ class TFSamPreTrainedModel(TFPreTrainedModel):
+     config_class = SamConfig
+     base_model_prefix = "sam"
+     main_input_name = "pixel_values"
+
+
+ SAM_START_DOCSTRING = r"""
+     This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
+     subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matters related to
+     general usage and behavior.
+
+     Parameters:
+         config ([`SamConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ SAM_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+             details.
+         input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
+             Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
+             better results. The points can be obtained by passing a list of list of list to the processor that will
+             create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
+             dimension is the point batch size (i.e. how many segmentation masks we want the model to predict per
+             input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+             multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+             coordinates of the point. If a different number of points is passed either for each image, or for each
+             mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+             computation of the embedding will be skipped for these points using the labels.
+         input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
+             Input labels for the points; these are used by the prompt encoder to encode the prompt. According to the
+             official implementation, there are 3 types of labels
+
+             - `1`: the point is a point that contains the object of interest
+             - `0`: the point is a point that does not contain the object of interest
+             - `-1`: the point corresponds to the background
+
+             We added the label:
+
+             - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+             The padding labels should be automatically done by the processor.
+         input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
+             Input boxes for the points; these are used by the prompt encoder to encode the prompt. Generally yields
+             much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
+             that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
+             the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the
+             order (`x1`, `y1`, `x2`, `y2`):
+
+             - `x1`: the x coordinate of the top left point of the input box
+             - `y1`: the y coordinate of the top left point of the input box
+             - `x2`: the x coordinate of the bottom right point of the input box
+             - `y2`: the y coordinate of the bottom right point of the input box
+
+         input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
+             SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+             generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
+             manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+
+         image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+             Image embeddings; these are used by the mask decoder to generate masks and iou scores. For more memory
+             efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+             method, and then feed them to the `call` method instead of feeding the `pixel_values`.
+         multimask_output (`bool`, *optional*):
+             In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+             bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+             "best" mask, by specifying `multimask_output=False`.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ SAM_VISION_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+             details.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     """The vision model from Sam without any head or projection on top.""",
+     SAM_START_DOCSTRING,
+ )
+ class TFSamVisionModel(TFSamPreTrainedModel):
+     config_class = SamVisionConfig
+     main_input_name = "pixel_values"
+
+     def __init__(self, config: SamVisionConfig, **kwargs):
+         super().__init__(config, **kwargs)
+         self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
+
+     def build(self, input_shape=None):
+         if self.built:
+             return
+         self.built = True
+         if getattr(self, "vision_encoder", None) is not None:
+             with tf.name_scope(self.vision_encoder.name):
+                 self.vision_encoder.build(None)
+
+     def get_input_embeddings(self):
+         return self.vision_encoder.patch_embed
+
+     @unpack_inputs
+     @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
+     def call(
+         self,
+         pixel_values: TFModelInputType | None = None,
+         output_attentions: bool | None = None,
+         output_hidden_states: bool | None = None,
+         return_dict: bool | None = None,
+         training: bool = False,
+         **kwargs,
+     ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
+         r"""
+         Returns:
+
+         """
+         return self.vision_encoder(
+             pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             training=training,
+         )
+
+
+ @add_start_docstrings(
+     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and",
+     " optional 2D location and bounding boxes.",
+     SAM_START_DOCSTRING,
+ )
+ class TFSamModel(TFSamPreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
+
+     def __init__(self, config, **kwargs):
+         super().__init__(config, **kwargs)
+         self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")
+
+         self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
+         self.prompt_encoder = TFSamPromptEncoder(
+             config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
+         )
+         self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
+         self.config = config
+
+     def get_input_embeddings(self):
+         return self.vision_encoder.get_input_embeddings()
+
+     def get_image_wide_positional_embeddings(self):
+         size = self.config.prompt_encoder_config.image_embedding_size
+         grid = tf.ones((size, size))
+         y_embed = tf.math.cumsum(grid, axis=0) - 0.5
+         x_embed = tf.math.cumsum(grid, axis=1) - 0.5
+         y_embed = y_embed / size
+         x_embed = x_embed / size
+
+         positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
+         return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0)  # channel x height x width
+
+     def get_image_embeddings(
+         self,
+         pixel_values,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         r"""
+         Returns the image embeddings by passing the pixel values through the vision encoder.
+
+         Args:
+             pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+                 Input pixel values
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.
+
+         """
+         vision_output = self.vision_encoder(
+             pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         image_embeddings = vision_output[0]
+         return image_embeddings
+
+     def get_prompt_embeddings(
+         self,
+         input_points: tf.Tensor | None = None,
+         input_labels: tf.Tensor | None = None,
+         input_boxes: tf.Tensor | None = None,
+         input_masks: tf.Tensor | None = None,
+     ):
+         r"""
+         Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+         Args:
+             input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+                 Optional input points for the prompt encoder. The padding of the point is automatically done by the
+                 processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+                 point. The model will output `point_batch_size` times 3 masks in total.
+             input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+                 Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+                 processor, or can be fed by the user.
+             input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+                 Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+                 processor. Users can also pass the input boxes manually.
+             input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
+                 Optional input masks for the prompt encoder.
+         """
+         prompt_output = self.prompt_encoder(
+             input_points=input_points,
+             input_labels=input_labels,
+             input_boxes=input_boxes,
+             input_masks=input_masks,
+         )
+         return prompt_output
+
+     @unpack_inputs
+     @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
+     def call(
+         self,
+         pixel_values: TFModelInputType | None = None,
+         input_points: tf.Tensor | None = None,
+         input_labels: tf.Tensor | None = None,
+         input_boxes: tf.Tensor | None = None,
+         input_masks: tf.Tensor | None = None,
+         image_embeddings: tf.Tensor | None = None,
+         multimask_output: bool = True,
+         output_attentions: bool | None = None,
+         output_hidden_states: bool | None = None,
+         return_dict: bool | None = None,
+         training: bool = False,
+         **kwargs,
+     ) -> TFSamImageSegmentationOutput | Tuple[tf.Tensor]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if pixel_values is None and image_embeddings is None:
+             raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+         if pixel_values is not None and image_embeddings is not None:
+             raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+         if input_points is not None and len(input_points.shape) != 4:
+             raise ValueError(
+                 "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
+                 " got {}.".format(input_points.shape),
+             )
+         if input_boxes is not None and len(input_boxes.shape) != 3:
+             raise ValueError(
+                 "The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
+                 " got {}.".format(input_boxes.shape),
+             )
+         if input_points is not None and input_boxes is not None:
+             point_batch_size = shape_list(input_points)[1]
+             box_batch_size = shape_list(input_boxes)[1]
+             if point_batch_size != box_batch_size:
+                 raise ValueError(
+                     "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
+                         point_batch_size, box_batch_size
+                     )
+                 )
+         if pixel_values is not None:
+             # Ensures that later checks pass even with an all-None shape from the serving signature
+             pixel_values = tf.ensure_shape(
+                 pixel_values,
+                 [
+                     None,
+                     self.config.vision_config.num_channels,
+                     self.config.vision_config.image_size,
+                     self.config.vision_config.image_size,
+                 ],
+             )
+         image_positional_embeddings = self.get_image_wide_positional_embeddings()
+         # repeat with batch size
+         batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
+         image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)
+
+         vision_attentions = None
+         vision_hidden_states = None
+
+         if pixel_values is not None:
+             vision_outputs = self.vision_encoder(
+                 pixel_values,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=True,
+                 training=training,
+             )
+             image_embeddings = vision_outputs["last_hidden_state"]
+
+             if output_hidden_states:
+                 vision_hidden_states = vision_outputs["hidden_states"]
+             if output_attentions:
+                 vision_attentions = vision_outputs["attentions"]
+
+         if input_points is not None and input_labels is None:
+             input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)
+
+         if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
+             raise ValueError(
+                 "The batch size of the image embeddings and the input points must be the same. ",
+                 "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
+                 " if you want to pass multiple points for the same image, make sure that you passed ",
+                 " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and ",
+                 " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
+             )
+
+         sparse_embeddings, dense_embeddings = self.prompt_encoder(
+             batch_size=shape_list(image_embeddings)[0],
+             input_points=input_points,
+             input_labels=input_labels,
+             input_boxes=input_boxes,
+             input_masks=input_masks,
+         )
+
+         low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
+             image_embeddings=image_embeddings,
+             image_positional_embeddings=image_positional_embeddings,
+             sparse_prompt_embeddings=sparse_embeddings,
+             dense_prompt_embeddings=dense_embeddings,
+             multimask_output=multimask_output,
+             output_attentions=output_attentions,
+         )
+
+         if not return_dict:
+             output = (iou_predictions, low_res_masks)
+             if output_hidden_states:
+                 output = output + (vision_hidden_states,)
+
+             if output_attentions:
+                 output = output + (vision_attentions, mask_decoder_attentions)
+             return output
+
+         return TFSamImageSegmentationOutput(
+             iou_scores=iou_predictions,
+             pred_masks=low_res_masks,
+             vision_hidden_states=vision_hidden_states,
+             vision_attentions=vision_attentions,
+             mask_decoder_attentions=mask_decoder_attentions,
+         )
+
+     def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
+         hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
+         attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None
+
+         return TFSamImageSegmentationOutput(
+             iou_scores=output.iou_scores,
+             pred_masks=output.pred_masks,
+             vision_hidden_states=hs if self.config.output_hidden_states else None,
+             vision_attentions=attns if self.config.output_attentions else None,
+             mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
+         )
+
+     def build(self, input_shape=None):
+         if self.built:
+             return
+         self.built = True
+         if getattr(self, "shared_image_embedding", None) is not None:
+             with tf.name_scope(self.shared_image_embedding.name):
+                 self.shared_image_embedding.build(None)
+         if getattr(self, "vision_encoder", None) is not None:
+             with tf.name_scope(self.vision_encoder.name):
+                 self.vision_encoder.build(None)
+         if getattr(self, "prompt_encoder", None) is not None:
+             with tf.name_scope(self.prompt_encoder.name):
+                 self.prompt_encoder.build(None)
+         if getattr(self, "mask_decoder", None) is not None:
+             with tf.name_scope(self.mask_decoder.name):
+                 self.mask_decoder.build(None)
+
+
+ __all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
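For reference, a minimal usage sketch of the TF SAM classes above, following the pattern used in the library's documentation (the checkpoint name and point coordinates are illustrative):

```python
from PIL import Image
import requests
from transformers import SamProcessor, TFSamModel

# Illustrative checkpoint; any SAM checkpoint with TF weights works the same way.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = TFSamModel.from_pretrained("facebook/sam-vit-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # one image, one mask request, one (x, y) point

inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)  # TFSamImageSegmentationOutput: iou_scores, pred_masks, ...

# Upscale the low-resolution masks back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
)
```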
docs/transformers/build/lib/transformers/models/seamless_m4t/__init__.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_seamless_m4t import *
+     from .feature_extraction_seamless_m4t import *
+     from .modeling_seamless_m4t import *
+     from .processing_seamless_m4t import *
+     from .tokenization_seamless_m4t import *
+     from .tokenization_seamless_m4t_fast import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
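A quick sketch of how the `_LazyModule` indirection above is expected to behave: nothing heavy is imported until one of the package's exported names is first accessed.

```python
import transformers.models.seamless_m4t as seamless_m4t

# The package object is a _LazyModule shim; the configuration/modeling/tokenization
# submodules have not been imported yet at this point.
config_cls = seamless_m4t.SeamlessM4TConfig  # first attribute access triggers the real import
print(config_cls.model_type)  # -> "seamless_m4t"
```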
docs/transformers/build/lib/transformers/models/seamless_m4t/configuration_seamless_m4t.py ADDED
@@ -0,0 +1,416 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """SeamlessM4T model configuration"""
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class SeamlessM4TConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`~SeamlessM4TModel`]. It is used to instantiate a
+     SeamlessM4T model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the SeamlessM4T
+     [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 256102):
+             Vocabulary size of the SeamlessM4T model. Defines the number of different tokens that can be represented by
+             the `inputs_ids` passed when calling [`~SeamlessM4TModel`], [`~SeamlessM4TForTextToSpeech`] or
+             [`~SeamlessM4TForTextToText`].
+         t2u_vocab_size (`int`, *optional*, defaults to 10082):
+             Unit vocabulary size of the SeamlessM4T model. Defines the number of different unit tokens that can be
+             represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4TModel`],
+             [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`].
+
+         > Parameters shared across sub-models
+
+         hidden_size (`int`, *optional*, defaults to 1024):
+             Dimensionality of the "intermediate" layers in the architecture.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the layer normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+         max_position_embeddings (`int`, *optional*, defaults to 1024):
+             The maximum sequence length that this model's text encoder and decoder might ever be used with. Typically
+             set this to something large just in case (e.g., 512 or 1024 or 2048).
+         is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+             Whether the model is used as an encoder/decoder or not.
+         encoder_layerdrop (`float`, *optional*, defaults to 0.05):
+             The LayerDrop probability for the encoders. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+             for more details.
+         decoder_layerdrop (`float`, *optional*, defaults to 0.05):
+             The LayerDrop probability for the decoders. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+             for more details.
+         activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+             The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
+             `"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+         dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
+         attention_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all attention layers.
+         activation_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability for all activation layers in the model.
+         scale_embedding (`bool`, *optional*, defaults to `True`):
+             Scale embeddings by dividing by sqrt(d_model).
+
+         > Text encoder and text decoder specific parameters
+
+         encoder_layers (`int`, *optional*, defaults to 24):
+             Number of hidden layers in the Transformer text encoder.
+         encoder_ffn_dim (`int`, *optional*, defaults to 8192):
+             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
+         encoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer text encoder.
+         decoder_layers (`int`, *optional*, defaults to 24):
+             Number of hidden layers in the Transformer text decoder.
+         decoder_ffn_dim (`int`, *optional*, defaults to 8192):
+             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder.
+         decoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer text decoder.
+         decoder_start_token_id (`int`, *optional*, defaults to 3):
+             If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
+             applied in the text decoder.
+         max_new_tokens (`int`, *optional*, defaults to 256):
+             The maximum number of text tokens to generate, ignoring the number of tokens in the prompt.
+         pad_token_id (`int`, *optional*, defaults to 0):
+             The id of the _padding_ text token. Only applied to the text-decoder model.
+         bos_token_id (`int`, *optional*, defaults to 2):
+             The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model.
+         eos_token_id (`int`, *optional*, defaults to 3):
+             The id of the _end-of-stream_ text token. Only applied to the text-decoder model.
+
+         > Speech encoder specific parameters
+
+         speech_encoder_layers (`int`, *optional*, defaults to 24):
+             Number of hidden layers in the Transformer speech encoder.
+         speech_encoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer speech encoder.
+         speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096):
+             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder.
+         speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
+             The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`,
+             `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+         speech_encoder_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability for all layers in the speech encoder.
+         add_adapter (`bool`, *optional*, defaults to `True`):
+             Add an adapter layer on top of the speech encoder.
+         speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1):
+             The LayerDrop probability for the speech encoder. See the
+             [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
+         feature_projection_input_dim (`int`, *optional*, defaults to 160):
+             Input dimension of the input feature projection of the speech encoder, i.e. the dimension after processing
+             input audios with [`SeamlessM4TFeatureExtractor`].
+         num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+             Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
+             embeddings layer of the speech encoder.
+         num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+             Number of groups of the 1D convolutional positional embeddings layer of the speech encoder.
+         adaptor_kernel_size (`int`, *optional*, defaults to 8):
+             Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+         adaptor_stride (`int`, *optional*, defaults to 8):
+             Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+         adaptor_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all layers in the speech adapter.
+         num_adapter_layers (`int`, *optional*, defaults to 1):
+             Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+             True`.
+         position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
+             Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
+             `None` no relative position embedding is applied. Only applied to the speech encoder.
+         rotary_embedding_base (`int`, *optional*, defaults to 10000):
+             If `"rotary"` position embeddings are used, defines the size of the embedding base. Only applied to the
+             speech encoder.
+         max_source_positions (`int`, *optional*, defaults to 4096):
+             If `"relative"` position embeddings are used, defines the maximum source input positions. Only applied to
+             the speech encoder.
+         conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
+             Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder.
+
+         > Text-To-Unit (t2u) model specific parameters
+
+         t2u_bos_token_id (`int`, *optional*, defaults to 0):
+             The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
+         t2u_pad_token_id (`int`, *optional*, defaults to 1):
+             The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model.
+         t2u_eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
+         t2u_decoder_start_token_id (`int`, *optional*, defaults to 2):
+             If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
+             applied to the text-to-unit seq2seq model.
+         t2u_max_new_tokens (`int`, *optional*, defaults to 1024):
+             The maximum number of unit tokens to generate, ignoring the number of tokens in the prompt. Only applied
+             to the text-to-unit seq2seq model.
+         t2u_encoder_layers (`int`, *optional*, defaults to 6):
+             Number of hidden layers in the Transformer text-to-unit encoder.
+         t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192):
+             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder.
+         t2u_encoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer text-to-unit encoder.
+         t2u_decoder_layers (`int`, *optional*, defaults to 6):
+             Number of hidden layers in the Transformer text-to-unit decoder.
+         t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192):
+             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder.
+         t2u_decoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer text-to-unit decoder.
+         t2u_max_position_embeddings (`int`, *optional*, defaults to 2048):
+             The maximum sequence length that this model's text-to-unit component might ever be used with. Typically set
+             this to something large just in case (e.g., 512 or 1024 or 2048).
+
+         > Hifi-Gan Vocoder specific parameters
+
+         sampling_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate at which the output audio will be generated, expressed in hertz (Hz).
+         upsample_initial_channel (`int`, *optional*, defaults to 512):
+             The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only.
+         upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`):
+             A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network.
+             The length of *upsample_rates* defines the number of convolutional layers and has to match the length of
+             *upsample_kernel_sizes*. Applies to the vocoder only.
+         upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`):
+             A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling
+             network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match
+             the length of *upsample_rates*. Applies to the vocoder only.
+         resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
+             A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive
+             field fusion (MRF) module. Applies to the vocoder only.
+         resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
+             A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in
+             the multi-receptive field fusion (MRF) module. Applies to the vocoder only.
+         leaky_relu_slope (`float`, *optional*, defaults to 0.1):
+             The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder
+             only.
+         unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000):
+             Vocabulary size of the SeamlessM4T vocoder. Defines the number of different unit tokens that can be
+             represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4TModel`],
+             [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`].
+         unit_embed_dim (`int`, *optional*, defaults to 1280):
+             The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only.
+         lang_embed_dim (`int`, *optional*, defaults to 256):
+             The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only.
+         spkr_embed_dim (`int`, *optional*, defaults to 256):
+             The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only.
+         vocoder_num_langs (`int`, *optional*, defaults to 36):
+             Number of languages supported by the vocoder. Might be different from `t2u_num_langs`.
+         vocoder_num_spkrs (`int`, *optional*, defaults to 200):
+             Number of speakers supported by the vocoder.
+         variance_predictor_kernel_size (`int`, *optional*, defaults to 3):
+             Kernel size of the duration predictor. Applies to the vocoder only.
+         var_pred_dropout (`float`, *optional*, defaults to 0.5):
+             The dropout probability of the duration predictor. Applies to the vocoder only.
+         vocoder_offset (`int`, *optional*, defaults to 4):
+             Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only.
+
+     ```python
+     >>> from transformers import SeamlessM4TModel, SeamlessM4TConfig
+
+     >>> # Initializing a SeamlessM4T "facebook/hf-seamless-m4t-medium" style configuration
+     >>> configuration = SeamlessM4TConfig()
+
+     >>> # Initializing a model from the "facebook/hf-seamless-m4t-medium" style configuration
+     >>> model = SeamlessM4TModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "seamless_m4t"
+
+     def __init__(
+         self,
+         vocab_size=256102,
+         t2u_vocab_size=10082,
+         # shared config
+         hidden_size=1024,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         use_cache=True,
+         max_position_embeddings=1024,
+         is_encoder_decoder=True,
+         encoder_layerdrop=0.05,
+         decoder_layerdrop=0.05,
+         activation_function="relu",
+         dropout=0.1,
+         attention_dropout=0.1,
+         activation_dropout=0.0,
+         scale_embedding=True,
+         # text encoder|decoder
+         encoder_layers=24,
+         encoder_ffn_dim=8192,
+         encoder_attention_heads=16,
+         decoder_layers=24,
+         decoder_ffn_dim=8192,
+         decoder_attention_heads=16,
+         decoder_start_token_id=3,
+         max_new_tokens=256,
+         pad_token_id=0,
+         bos_token_id=2,
+         eos_token_id=3,
+         # speech_encoder
+         speech_encoder_layers=24,
+         speech_encoder_attention_heads=16,
+         speech_encoder_intermediate_size=4096,
+         speech_encoder_hidden_act="swish",
+         speech_encoder_dropout=0.0,
+         add_adapter=True,
+         speech_encoder_layerdrop=0.1,
+         feature_projection_input_dim=160,
+         num_conv_pos_embeddings=128,
+         num_conv_pos_embedding_groups=16,
+         adaptor_kernel_size=8,
+         adaptor_stride=8,
+         adaptor_dropout=0.1,
+         num_adapter_layers=1,
+         position_embeddings_type="relative",
+         rotary_embedding_base=10000,
+         max_source_positions=4096,
+         conv_depthwise_kernel_size=31,
+         # t2u config
+         t2u_bos_token_id=0,
+         t2u_pad_token_id=1,
+         t2u_eos_token_id=2,
+         t2u_decoder_start_token_id=2,
+         t2u_max_new_tokens=1024,
+         t2u_encoder_layers=6,
+         t2u_encoder_ffn_dim=8192,
+         t2u_encoder_attention_heads=16,
+         t2u_decoder_layers=6,
+         t2u_decoder_ffn_dim=8192,
+         t2u_decoder_attention_heads=16,
+         t2u_max_position_embeddings=2048,
+         # hifi-gan vocoder config
+         sampling_rate=16000,
+         upsample_initial_channel=512,
+         upsample_rates=[5, 4, 4, 2, 2],
+         upsample_kernel_sizes=[11, 8, 8, 4, 4],
+         resblock_kernel_sizes=[3, 7, 11],
+         resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+         leaky_relu_slope=0.1,
+         # specific to Code Hifi-Gan
+         unit_hifi_gan_vocab_size=10000,
+         unit_embed_dim=1280,
+         lang_embed_dim=256,
+         spkr_embed_dim=256,
+         vocoder_num_langs=36,
+         vocoder_num_spkrs=200,
+         variance_predictor_kernel_size=3,
+         var_pred_dropout=0.5,
+         vocoder_offset=4,
+         **kwargs,
+     ):
+         # overall_config
+         self.vocab_size = vocab_size
+         self.t2u_vocab_size = t2u_vocab_size
+         self.hidden_size = hidden_size
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.max_position_embeddings = max_position_embeddings
+         self.use_cache = use_cache
+         self.max_new_tokens = max_new_tokens
+         self.encoder_layerdrop = encoder_layerdrop
+         self.decoder_layerdrop = decoder_layerdrop
+         self.activation_function = activation_function
+         self.dropout = dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.scale_embedding = scale_embedding
+         # for proper config init
+         self.num_attention_heads = decoder_attention_heads
+         self.num_hidden_layers = decoder_layers
+
+         # text|unit encoder|decoder
+         self.encoder_layers = encoder_layers
+         self.encoder_ffn_dim = encoder_ffn_dim
+         self.encoder_attention_heads = encoder_attention_heads
+         self.decoder_layers = decoder_layers
+         self.decoder_ffn_dim = decoder_ffn_dim
+         self.decoder_attention_heads = decoder_attention_heads
+
+         # speech_encoder
+         self.speech_encoder_layers = speech_encoder_layers
+         self.speech_encoder_hidden_act = speech_encoder_hidden_act
+         self.speech_encoder_dropout = speech_encoder_dropout
+         self.speech_encoder_attention_heads = speech_encoder_attention_heads
+         self.speech_encoder_layerdrop = speech_encoder_layerdrop
+         self.speech_encoder_intermediate_size = speech_encoder_intermediate_size
+         self.feature_projection_input_dim = feature_projection_input_dim
+         self.num_conv_pos_embeddings = num_conv_pos_embeddings
+         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+         self.adaptor_kernel_size = adaptor_kernel_size
+         self.adaptor_stride = adaptor_stride
+         self.adaptor_dropout = adaptor_dropout
+         self.num_adapter_layers = num_adapter_layers
+         self.position_embeddings_type = position_embeddings_type
+         self.rotary_embedding_base = rotary_embedding_base
+         self.max_source_positions = max_source_positions
+         self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
+         self.add_adapter = add_adapter
+
+         # t2u config
+         self.t2u_bos_token_id = t2u_bos_token_id
+         self.t2u_pad_token_id = t2u_pad_token_id
+         self.t2u_eos_token_id = t2u_eos_token_id
+         self.t2u_decoder_start_token_id = t2u_decoder_start_token_id
+         self.t2u_max_new_tokens = t2u_max_new_tokens
+         self.t2u_encoder_layers = t2u_encoder_layers
+         self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim
+         self.t2u_encoder_attention_heads = t2u_encoder_attention_heads
+         self.t2u_decoder_layers = t2u_decoder_layers
+         self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim
+         self.t2u_decoder_attention_heads = t2u_decoder_attention_heads
+         self.t2u_max_position_embeddings = t2u_max_position_embeddings
+
+         # hifi-gan vocoder config
+         # original parameters specific to Hifi-Gan
+         self.sampling_rate = sampling_rate
+         self.upsample_initial_channel = upsample_initial_channel
+         self.upsample_rates = upsample_rates
+         self.upsample_kernel_sizes = upsample_kernel_sizes
+         self.resblock_kernel_sizes = resblock_kernel_sizes
+         self.resblock_dilation_sizes = resblock_dilation_sizes
+         self.leaky_relu_slope = leaky_relu_slope
+
+         # specific to Code Hifi-Gan
+         self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size
+         self.unit_embed_dim = unit_embed_dim
+         self.lang_embed_dim = lang_embed_dim
+         self.spkr_embed_dim = spkr_embed_dim
+         self.vocoder_num_langs = vocoder_num_langs
+         self.vocoder_num_spkrs = vocoder_num_spkrs
+         self.variance_predictor_kernel_size = variance_predictor_kernel_size
+         self.var_pred_dropout = var_pred_dropout
+         self.vocoder_offset = vocoder_offset
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             decoder_start_token_id=decoder_start_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             max_position_embeddings=max_position_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["SeamlessM4TConfig"]
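A small sketch of overriding the configuration's sub-model parameter groups (hypothetical sizes for a quick smoke test, not matching any released checkpoint):

```python
from transformers import SeamlessM4TConfig

config = SeamlessM4TConfig(
    encoder_layers=2,         # text encoder
    decoder_layers=2,         # text decoder
    speech_encoder_layers=2,  # Conformer speech encoder
    t2u_encoder_layers=1,     # text-to-unit stack
    t2u_decoder_layers=1,
)
# num_hidden_layers mirrors decoder_layers (set "for proper config init" above).
print(config.num_hidden_layers)  # -> 2
```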
docs/transformers/build/lib/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py ADDED
@@ -0,0 +1,309 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Feature extractor class for SeamlessM4T
+ """
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ from ...utils import is_torch_available
+
+
+ if is_torch_available():
+     import torch
+
+ from ...audio_utils import mel_filter_bank, spectrogram, window_function
+ from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from ...feature_extraction_utils import BatchFeature
+ from ...utils import PaddingStrategy, TensorType, logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
+     r"""
+     Constructs a SeamlessM4T feature extractor.
+
+     This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users
+     should refer to this superclass for more information regarding those methods.
+
+     This class extracts mel-filter bank features from raw speech.
+
+     Args:
+         feature_size (`int`, *optional*, defaults to 80):
+             The feature dimension of the extracted features.
+         sampling_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+         num_mel_bins (`int`, *optional*, defaults to 80):
+             Number of Mel-frequency bins.
+         padding_value (`float`, *optional*, defaults to 0.0):
+             The value that is used to fill the padding vectors.
+         stride (`int`, *optional*, defaults to 2):
+             Stride used to reshape audios from shape (batch_size, num_frames, num_mel_bins) to
+             (batch_size, num_frames // stride, num_mel_bins * stride).
+     """
+
+     model_input_names = ["input_features", "attention_mask"]
+
+     def __init__(
+         self,
+         feature_size=80,
+         sampling_rate=16000,
+         num_mel_bins=80,
+         padding_value=0.0,
+         stride=2,
+         **kwargs,
+     ):
+         self.num_mel_bins = num_mel_bins
+         self.return_attention_mask = True
+         self.stride = stride
+
+         mel_filters = mel_filter_bank(
+             num_frequency_bins=257,
+             num_mel_filters=self.num_mel_bins,
+             min_frequency=20,
+             max_frequency=sampling_rate // 2,
+             sampling_rate=sampling_rate,
+             norm=None,
+             mel_scale="kaldi",
+             triangularize_in_mel_space=True,
+         )
+
+         self.mel_filters = mel_filters
+         self.window = window_function(400, "povey", periodic=False)
+
+         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+     @staticmethod
+     # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
+     def zero_mean_unit_var_norm(
+         input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
+     ) -> List[np.ndarray]:
+         """
+         Every array in the list is normalized to have zero mean and unit variance
+         """
+         if attention_mask is not None:
+             attention_mask = np.array(attention_mask, np.int32)
+             normed_input_values = []
+
+             for vector, length in zip(input_values, attention_mask.sum(-1)):
+                 normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+                 if length < normed_slice.shape[0]:
+                     normed_slice[length:] = padding_value
+
+                 normed_input_values.append(normed_slice)
+         else:
+             normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
+
+         return normed_input_values
+
+     def _extract_fbank_features(
+         self,
+         waveform: np.ndarray,
+     ) -> np.ndarray:
+         """
+         Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
+         and hence the waveform should not be normalized before feature extraction.
+         """
+         # by default, it extracts the left channel if stereo
+         if len(waveform.shape) == 2:
+             waveform = waveform[0]
+
+         waveform = np.squeeze(waveform) * (2**15)  # Kaldi compliance: 16-bit signed integers
+         features = spectrogram(
+             waveform,
+             self.window,
+             frame_length=400,
+             hop_length=160,
+             fft_length=512,
+             power=2.0,
+             center=False,
+             preemphasis=0.97,
+             mel_filters=self.mel_filters,
+             log_mel="log",
+             mel_floor=1.192092955078125e-07,
+             remove_dc_offset=True,
+         ).T
+         return features
+
+     def __call__(
+         self,
+         raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+         padding: Union[bool, str, PaddingStrategy] = True,
+         pad_to_multiple_of: Optional[int] = 2,
+         max_length: Optional[int] = None,
+         truncation: bool = False,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         sampling_rate: Optional[int] = None,
+         return_attention_mask: Optional[bool] = None,
+         do_normalize_per_mel_bins: Optional[bool] = True,
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Main method to featurize and prepare for the model one or several sequence(s).
+
+         Args:
+             raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`,
+                 `List[List[float]]`, `List[List[List[float]]]`):
+                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
+                 a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
+                 a list of list of float values or a list of a list of list of float values.
+                 If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is
+                 considered a single-channel, single-sample sound. In all other cases, the first dimension of
+                 `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`,
+                 corresponds to the number of samples in the batch, and the number of channels
+                 (i.e. mono or stereo character) is derived from the other dimensions
+                 (1D -> single-channel waveform batches; 2D -> stereo-channel waveform batches).
+             padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                 index) among:
+
+                 - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                   sequence is provided).
+                 - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                   acceptable input length for the model if that argument is not provided.
+                 - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                   lengths).
+             pad_to_multiple_of (`int`, *optional*, defaults to 2):
+                 If set will pad the sequence to a multiple of the provided value.
+
+                 This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                 `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+             max_length (`int`, *optional*):
+                 Maximum length of the returned list and optionally padding length (see above).
+             truncation (`bool`):
+                 Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+             return_attention_mask (`bool`, *optional*):
+                 Whether to return the attention mask. If left to the default, will return the attention mask according
+                 to the specific feature_extractor's default.
+
+                 [What are attention masks?](../glossary#attention-mask)
+
+                 <Tip>
+
+                 For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle
+                 bugs.
+
+                 </Tip>
+
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return Numpy `np.ndarray` objects.
+             sampling_rate (`int`, *optional*):
+                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                 `sampling_rate` at the forward call to prevent silent errors.
+             do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`):
+                 Whether or not to zero-mean unit-variance normalize the input per mel-channel.
+             kwargs (*optional*):
+                 Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature
+                 extractor.
+         """
+         if sampling_rate is not None:
+             if sampling_rate != self.sampling_rate:
+                 raise ValueError(
+                     f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                     f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+                     f" {self.sampling_rate} and not {sampling_rate}."
+                 )
+         else:
+             logger.warning(
+                 f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+                 "Failing to do so can result in silent errors that might be hard to debug."
+             )
+
+         return_attention_mask = (
+             return_attention_mask if return_attention_mask is not None else self.return_attention_mask
+         )
+
+         is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+         if is_batched_numpy and len(raw_speech.shape) > 3:
+             raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
+
+         acceptable_types = (
+             (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list)
+         )
+         is_batched = is_batched_numpy or (
+             isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types))
+         )
+
+         if is_batched:
+             raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
+         elif not is_batched and not isinstance(raw_speech, np.ndarray):
+             raw_speech = np.asarray(raw_speech, dtype=np.float32)
+         elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+             raw_speech = raw_speech.astype(np.float32)
+
+         # always return batch
+         if not is_batched:
+             raw_speech = [raw_speech]
+
+         # extract fbank features
+         features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
+
+         if do_normalize_per_mel_bins:
+             # torch defaults to ddof=1, and numpy defaults to ddof=0
+             features = [
+                 (x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7)
+                 for x in features
+             ]
+
+         # convert into correct format for padding
+         encoded_inputs = BatchFeature({"input_features": features})
+
+         padded_inputs = self.pad(
+             encoded_inputs,
+             padding=padding,
+             max_length=max_length,
+             truncation=truncation,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_attention_mask=True,
+             return_tensors="np",
+         )
+
+         # SeamlessM4T needs to process extracted features
282
+ input_features = padded_inputs.get("input_features")
283
+ attention_mask = padded_inputs.pop("attention_mask")
284
+
285
+ batch_size, num_frames, num_channels = input_features.shape
286
+
287
+ remainder = num_frames % self.stride
288
+ if remainder != 0:
289
+ input_features = input_features[:, : num_frames - remainder, :]
290
+ attention_mask = attention_mask[:, : num_frames - remainder]
291
+
292
+ input_features = np.reshape(
293
+ input_features, (batch_size, num_frames // self.stride, num_channels * self.stride)
294
+ )
295
+
296
+ indices = np.arange(0, num_frames - remainder)
297
+ attention_mask = attention_mask[:, indices % self.stride == 1]
298
+
299
+ padded_inputs["input_features"] = input_features
300
+ if return_attention_mask:
301
+ padded_inputs["attention_mask"] = attention_mask
302
+
303
+ if return_tensors is not None:
304
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
305
+
306
+ return padded_inputs
307
+
308
+
309
+ __all__ = ["SeamlessM4TFeatureExtractor"]
docs/transformers/build/lib/transformers/models/seamless_m4t/tokenization_seamless_m4t.py ADDED
@@ -0,0 +1,567 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for SeamlessM4T."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import sentencepiece as spm
22
+
23
+ from ...convert_slow_tokenizer import import_protobuf
24
+ from ...tokenization_utils import (
25
+ BatchEncoding,
26
+ PreTokenizedInput,
27
+ PreTrainedTokenizer,
28
+ TextInput,
29
+ )
30
+ from ...tokenization_utils_base import AddedToken
31
+ from ...utils import PaddingStrategy, logging
32
+ from ...utils.import_utils import requires
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ SPIECE_UNDERLINE = "▁"
39
+
40
+
41
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
42
+
43
+
44
+ @requires(backends=("sentencepiece",))
45
+ class SeamlessM4TTokenizer(PreTrainedTokenizer):
46
+ """
47
+ Construct a SeamlessM4T tokenizer.
48
+
49
+ Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
50
+ [SentencePiece](https://github.com/google/sentencepiece).
51
+
52
+ The tokenization method is `<language code> <tokens> <eos>` for source language documents, and `<eos> <language
53
+ code> <tokens> <eos>` for target language documents.
54
+
55
+ Examples:
56
+
57
+ ```python
58
+ >>> from transformers import SeamlessM4TTokenizer
59
+
60
+ >>> tokenizer = SeamlessM4TTokenizer.from_pretrained(
61
+ ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra"
62
+ ... )
63
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
64
+ >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
65
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
66
+ ```
67
+
68
+ Args:
69
+ vocab_file (`str`):
70
+ Path to the vocabulary file.
71
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
72
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
73
+
74
+ <Tip>
75
+
76
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
77
+ sequence. The token used is the `cls_token`.
78
+
79
+ </Tip>
80
+
81
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
82
+ The end of sequence token.
83
+
84
+ <Tip>
85
+
86
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
87
+ The token used is the `sep_token`.
88
+
89
+ </Tip>
90
+
91
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
92
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
93
+ sequence classification or for a text and a question for question answering. It is also used as the last
94
+ token of a sequence built with special tokens.
95
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
96
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
97
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
98
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
99
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
100
+ token instead.
101
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
102
+ The token used for padding, for example when batching sequences of different lengths.
103
+ tokenizer_file (`str`, *optional*):
104
+ The path to a tokenizer file to use instead of the vocab file.
105
+ src_lang (`str`, *optional*, defaults to `"eng"`):
106
+ The language to use as source language for translation.
107
+ tgt_lang (`str`, *optional*, defaults to `"fra"`):
108
+ The language to use as target language for translation.
109
+ sp_model_kwargs (`Dict[str, Any]`, *optional*):
110
+ Additional keyword arguments to pass to the underlying `SentencePieceProcessor` initialization.
111
+ additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
112
+ A tuple or a list of additional special tokens. Can be used to specify the list of languages that will be
113
+ supported by the tokenizer.
114
+ add_prefix_space (`bool`, *optional*, defaults to `True`):
115
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
116
+ other word.
117
+ """
118
+
119
+ vocab_files_names = VOCAB_FILES_NAMES
120
+ model_input_names = ["input_ids", "attention_mask"]
121
+
122
+ prefix_tokens: List[int] = []
123
+ suffix_tokens: List[int] = []
124
+
125
+ def __init__(
126
+ self,
127
+ vocab_file,
128
+ bos_token="<s>",
129
+ eos_token="</s>",
130
+ sep_token="</s>",
131
+ cls_token="<s>",
132
+ unk_token="<unk>",
133
+ pad_token="<pad>",
134
+ tokenizer_file=None,
135
+ src_lang="eng",
136
+ tgt_lang="fra",
137
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
138
+ additional_special_tokens=None,
139
+ add_prefix_space=True,
140
+ **kwargs,
141
+ ):
142
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
143
+ # Add this unused argument to keep some important Copied from statements
144
+ self.legacy = False
145
+ self.vocab_file = vocab_file
146
+
147
+ self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
148
+
149
+ # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
150
+ # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
151
+ # spm | '<unk>' | '<s>' | '</s>' | 'an' | 'en' | '_d' | 'er' | 'in' | '_s' | '_a'
152
+ # fairseq | '<pad>' | '<unk>' | '<s>' | '</s>' | 'an' | 'en' | '▁d' | 'er' | 'in' | '▁s'
153
+
154
+ # Mimic fairseq token-to-id alignment for the first 4 tokens
155
+ self._added_tokens_decoder = {
156
+ 0: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token,
157
+ 1: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token,
158
+ 2: AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token,
159
+ 3: AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token,
160
+ }
161
+
162
+ # The first "real" token "an" has position 4 in the original fairseq vocab and position 3 in the spm vocab
163
+ self.fairseq_offset = 1
164
+
165
+ self.sp_model_size = len(self.sp_model)
166
+
167
+ self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang
168
+ self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang
169
+ self.add_prefix_space = add_prefix_space
170
+
171
+ super().__init__(
172
+ bos_token=bos_token,
173
+ eos_token=eos_token,
174
+ unk_token=unk_token,
175
+ sep_token=sep_token,
176
+ cls_token=cls_token,
177
+ pad_token=pad_token,
178
+ tokenizer_file=tokenizer_file,
179
+ src_lang=src_lang,
180
+ tgt_lang=tgt_lang,
181
+ additional_special_tokens=additional_special_tokens,
182
+ sp_model_kwargs=self.sp_model_kwargs,
183
+ add_prefix_space=add_prefix_space,
184
+ **kwargs,
185
+ )
186
+
187
+ self.set_src_lang_special_tokens(self._src_lang)
188
+ self.set_tgt_lang_special_tokens(self._tgt_lang)
189
+
190
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__getstate__
191
+ def __getstate__(self):
192
+ state = self.__dict__.copy()
193
+ state["sp_model"] = None
194
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
195
+ return state
196
+
197
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__setstate__
198
+ def __setstate__(self, d):
199
+ self.__dict__ = d
200
+
201
+ # for backward compatibility
202
+ if not hasattr(self, "sp_model_kwargs"):
203
+ self.sp_model_kwargs = {}
204
+
205
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
206
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
207
+
208
+ @property
209
+ def vocab_size(self):
210
+ return len(self.sp_model)
211
+
212
+ def __call__(
213
+ self,
214
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
215
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
216
+ text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
217
+ text_pair_target: Optional[
218
+ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
219
+ ] = None,
220
+ padding: Union[bool, str, PaddingStrategy] = True,
221
+ pad_to_multiple_of: Optional[int] = 2,
222
+ src_lang: Optional[str] = None,
223
+ tgt_lang: Optional[str] = None,
224
+ **kwargs,
225
+ ):
226
+ """
227
+ Args:
228
+ text (`str`, `List[str]`, `List[List[str]]`, *optional*):
229
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
230
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
231
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
232
+ text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
233
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
234
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
235
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
236
+ text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
237
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
238
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
239
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
240
+ text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
241
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
242
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
243
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
244
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
245
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
246
+ index) among:
247
+
248
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
249
+ sequence is provided).
250
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
251
+ acceptable input length for the model if that argument is not provided.
252
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
253
+ lengths).
254
+ pad_to_multiple_of (`int`, *optional*):
255
+ If set will pad the sequence to a multiple of the provided value.
256
+
257
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
258
+ `>= 7.5` (Volta).
259
+ src_lang (`str`, *optional*):
260
+ A string representing the source language. If not specified, the last `src_lang` specified (either
261
+ during initialization or when calling this tokenizer) will be used.
262
+ tgt_lang (`str`, *optional*):
263
+ A string representing the target language. If not specified, the last `tgt_lang` specified (either
264
+ during initialization or when calling this tokenizer) will be used.
265
+ kwargs (*optional*):
266
+ Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizer.__call__`].
267
+ """
268
+ if src_lang is not None:
269
+ self.src_lang = src_lang
270
+ if tgt_lang is not None:
271
+ self.tgt_lang = tgt_lang
272
+
273
+ output = super().__call__(
274
+ text=text,
275
+ text_pair=text_pair,
276
+ text_target=text_target,
277
+ text_pair_target=text_pair_target,
278
+ padding=padding,
279
+ pad_to_multiple_of=pad_to_multiple_of,
280
+ **kwargs,
281
+ )
282
+
283
+ return BatchEncoding(output, tensor_type=kwargs.get("return_tensors"))
284
+
285
+ @property
286
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang
287
+ def src_lang(self) -> str:
288
+ return self._src_lang
289
+
290
+ @src_lang.setter
291
+ def src_lang(self, new_src_lang: str) -> None:
292
+ if "__" not in new_src_lang:
293
+ self._src_lang = f"__{new_src_lang}__"
294
+ else:
295
+ self._src_lang = new_src_lang
296
+ self.set_src_lang_special_tokens(self._src_lang)
297
+
298
+ @property
299
+ def tgt_lang(self) -> str:
300
+ return self._tgt_lang
301
+
302
+ @tgt_lang.setter
303
+ def tgt_lang(self, new_tgt_lang: str) -> None:
304
+ if "__" not in new_tgt_lang:
305
+ self._tgt_lang = f"__{new_tgt_lang}__"
306
+ else:
307
+ self._tgt_lang = new_tgt_lang
308
+ self.set_tgt_lang_special_tokens(self._tgt_lang)
309
+
310
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.get_special_tokens_mask
311
+ def get_special_tokens_mask(
312
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
313
+ ) -> List[int]:
314
+ """
315
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
316
+ special tokens using the tokenizer `prepare_for_model` method.
317
+
318
+ Args:
319
+ token_ids_0 (`List[int]`):
320
+ List of IDs.
321
+ token_ids_1 (`List[int]`, *optional*):
322
+ Optional second list of IDs for sequence pairs.
323
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
324
+ Whether or not the token list is already formatted with special tokens for the model.
325
+
326
+ Returns:
327
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
328
+ """
329
+
330
+ if already_has_special_tokens:
331
+ return super().get_special_tokens_mask(
332
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
333
+ )
334
+
335
+ prefix_ones = [1] * len(self.prefix_tokens)
336
+ suffix_ones = [1] * len(self.suffix_tokens)
337
+ if token_ids_1 is None:
338
+ return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
339
+ return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
340
+
341
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.build_inputs_with_special_tokens
342
+ def build_inputs_with_special_tokens(
343
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
344
+ ) -> List[int]:
345
+ """
346
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
347
+ adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence:
348
+
349
+ - `input_ids` (for encoder) `X [eos, src_lang_code]`
350
+ - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
351
+
352
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
353
+ separator.
354
+
355
+ Args:
356
+ token_ids_0 (`List[int]`):
357
+ List of IDs to which the special tokens will be added.
358
+ token_ids_1 (`List[int]`, *optional*):
359
+ Optional second list of IDs for sequence pairs.
360
+
361
+ Returns:
362
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
363
+ """
364
+ if token_ids_1 is None:
365
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
366
+ # We don't expect to process pairs, but leave the pair logic for API consistency
367
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
368
+
369
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.create_token_type_ids_from_sequences
370
+ def create_token_type_ids_from_sequences(
371
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
372
+ ) -> List[int]:
373
+ """
374
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not
375
+ make use of token type ids, therefore a list of zeros is returned.
376
+
377
+ Args:
378
+ token_ids_0 (`List[int]`):
379
+ List of IDs.
380
+ token_ids_1 (`List[int]`, *optional*):
381
+ Optional second list of IDs for sequence pairs.
382
+
383
+ Returns:
384
+ `List[int]`: List of zeros.
385
+
386
+ """
387
+
388
+ sep = [self.sep_token_id]
389
+ cls = [self.cls_token_id]
390
+
391
+ if token_ids_1 is None:
392
+ return len(cls + token_ids_0 + sep) * [0]
393
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
394
+
395
+ def _build_translation_inputs(
396
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
397
+ ):
398
+ """Used by translation pipeline, to prepare inputs for the generate function"""
399
+ if src_lang is None or tgt_lang is None:
400
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model.")
401
+ self.src_lang = src_lang
402
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
403
+ if "__" not in tgt_lang:
404
+ tgt_lang = f"__{tgt_lang}__"
405
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
406
+ inputs["forced_bos_token_id"] = tgt_lang_id
407
+ return inputs
408
+
409
+ def get_vocab(self):
410
+ vocab = {
411
+ self.convert_ids_to_tokens(i): i for i in range(self.fairseq_offset, self.vocab_size + self.fairseq_offset)
412
+ }
413
+ vocab.update(self.added_tokens_encoder)
414
+ return vocab
415
+
416
+ @property
417
+ def unk_token_length(self):
418
+ return len(self.sp_model.encode(str(self.unk_token)))
419
+
420
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
421
+ def get_spm_processor(self, from_slow=False):
422
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
423
+ if self.legacy or from_slow: # no dependency on protobuf
424
+ tokenizer.Load(self.vocab_file)
425
+ return tokenizer
426
+
427
+ with open(self.vocab_file, "rb") as f:
428
+ sp_model = f.read()
429
+ model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
430
+ model = model_pb2.ModelProto.FromString(sp_model)
431
+ normalizer_spec = model_pb2.NormalizerSpec()
432
+ normalizer_spec.add_dummy_prefix = False
433
+ model.normalizer_spec.MergeFrom(normalizer_spec)
434
+ sp_model = model.SerializeToString()
435
+ tokenizer.LoadFromSerializedProto(sp_model)
436
+ return tokenizer
437
+
438
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
439
+ def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
440
+ """
441
+ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
442
+ first token is special.
443
+ """
444
+ if self.legacy or len(text) == 0:
445
+ return super().tokenize(text, **kwargs)
446
+
447
+ text = text.replace(SPIECE_UNDERLINE, " ")
448
+ if self.add_prefix_space:
449
+ text = SPIECE_UNDERLINE + text
450
+
451
+ tokens = super().tokenize(text, **kwargs)
452
+
453
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
454
+ tokens = tokens[1:]
455
+ return tokens
456
+
457
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
458
+ def _tokenize(self, text, **kwargs):
459
+ """
460
+ Returns a tokenized string.
461
+
462
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
463
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
464
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
465
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
466
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
467
+ """
468
+ if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
469
+ return self.sp_model.encode(text, out_type=str)
470
+
471
+ # 1. Encode string + prefix ex: "<unk> Hey"
472
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
473
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
474
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
475
+
476
+ def _convert_token_to_id(self, token):
477
+ """Converts a token (str) in an id using the vocab."""
478
+ spm_id = self.sp_model.PieceToId(token)
479
+
480
+ # Need to return unknown token if the SP model returned 0
481
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
482
+
483
+ def _convert_id_to_token(self, index):
484
+ """Converts an index (integer) in a token (str) using the vocab."""
485
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
486
+
487
+ def convert_tokens_to_string(self, tokens):
488
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
489
+ # since we manually add the prefix space, we have to remove it when decoding
490
+ if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
491
+ tokens[0] = tokens[0][1:]
492
+
493
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
494
+ return out_string
495
+
496
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.save_vocabulary
497
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
498
+ if not os.path.isdir(save_directory):
499
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
500
+ return
501
+ out_vocab_file = os.path.join(
502
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
503
+ )
504
+
505
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
506
+ copyfile(self.vocab_file, out_vocab_file)
507
+ elif not os.path.isfile(self.vocab_file):
508
+ with open(out_vocab_file, "wb") as fi:
509
+ content_spiece_model = self.sp_model.serialized_model_proto()
510
+ fi.write(content_spiece_model)
511
+
512
+ return (out_vocab_file,)
513
+
514
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.prepare_seq2seq_batch with eng_Latn->eng, fra_Latn->fra
515
+ def prepare_seq2seq_batch(
516
+ self,
517
+ src_texts: List[str],
518
+ src_lang: str = "eng",
519
+ tgt_texts: Optional[List[str]] = None,
520
+ tgt_lang: str = "fra",
521
+ **kwargs,
522
+ ) -> BatchEncoding:
523
+ self.src_lang = src_lang
524
+ self.tgt_lang = tgt_lang
525
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
526
+
527
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_input_mode
528
+ def _switch_to_input_mode(self):
529
+ return self.set_src_lang_special_tokens(self.src_lang)
530
+
531
+ # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_target_mode
532
+ def _switch_to_target_mode(self):
533
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
534
+
535
+ def set_src_lang_special_tokens(self, src_lang) -> None:
536
+ """Reset the special tokens to the source lang setting.
537
+ Prefix=[src_lang_code], suffix = [eos]
538
+ """
539
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
540
+ self.init_kwargs["src_lang"] = src_lang
541
+
542
+ if self.cur_lang_code == self.unk_token_id:
543
+ logger.warning_once(
544
+ f"`src_lang={src_lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
545
+ )
546
+
547
+ self.prefix_tokens = [self.cur_lang_code]
548
+ self.suffix_tokens = [self.eos_token_id]
549
+
550
+ # https://github.com/facebookresearch/fairseq2/blob/c53f18e6be6b8b46b722f2249b8397b7eccd7ad3/src/fairseq2/models/nllb/tokenizer.py#L112-L116
551
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
552
+ """Reset the special tokens to the target lang setting.
553
+ Prefix=[eos, tgt_lang_code] and suffix=[eos].
554
+ """
555
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
556
+ self.init_kwargs["tgt_lang"] = lang
557
+
558
+ if self.cur_lang_code == self.unk_token_id:
559
+ logger.warning_once(
560
+ f"`tgt_lang={lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
561
+ )
562
+
563
+ self.prefix_tokens = [self.eos_token_id, self.cur_lang_code]
564
+ self.suffix_tokens = [self.eos_token_id]
565
+
566
+
567
+ __all__ = ["SeamlessM4TTokenizer"]
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_seamless_m4t_v2 import *
22
+ from .modeling_seamless_m4t_v2 import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
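A side note on the `_LazyModule` pattern used in this `__init__.py`, with a small sketch of what it buys: importing the package stays cheap because the torch-heavy submodules are only imported when one of their attributes is first accessed:

```python
# Cheap: no modeling code is imported yet at this point.
from transformers.models import seamless_m4t_v2

# The real import of configuration_seamless_m4t_v2 happens on first attribute access.
config_cls = seamless_m4t_v2.SeamlessM4Tv2Config
print(config_cls.model_type)  # "seamless_m4t_v2"
```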
docs/transformers/build/lib/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py ADDED
@@ -0,0 +1,425 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """SeamlessM4Tv2 model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class SeamlessM4Tv2Config(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`~SeamlessM4Tv2Model`]. It is used to instantiate
27
+ an SeamlessM4Tv2 model according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of the SeamlessM4Tv2
29
+ [""](https://huggingface.co/"") architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 256102):
37
+ Vocabulary size of the text modality of the SeamlessM4Tv2 model. Defines the number of different tokens
38
+ that can be represented by the `inputs_ids` passed when calling [`~SeamlessM4Tv2Model`],
39
+ [`~SeamlessM4Tv2ForTextToSpeech`] or [`~SeamlessM4Tv2ForTextToText`].
40
+ t2u_vocab_size (`int`, *optional*, defaults to 10082):
41
+ Unit vocabulary size of the SeamlessM4Tv2 model. Defines the number of different "unit tokens" that can be
42
+ represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4Tv2Model`],
43
+ [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
44
+ char_vocab_size (`int`, *optional*, defaults to 10943):
45
+ Character vocabulary size of the SeamlessM4Tv2 model. Defines the number of different character tokens that
46
+ can be represented by the `char_inputs_ids` passed when calling the Text-To-Units sub-model of
47
+ [`~SeamlessM4Tv2Model`], [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
48
+
49
+ > Parameters shared across sub-models
50
+
51
+ hidden_size (`int`, *optional*, defaults to 1024):
52
+ Dimensionality of the "intermediate" layers in the architecture.
53
+ initializer_range (`float`, *optional*, defaults to 0.02):
54
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
55
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
56
+ The epsilon used by the layer normalization layers.
57
+ use_cache (`bool`, *optional*, defaults to `True`):
58
+ Whether or not the model should return the last key/values attentions (not used by all models).
59
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
60
+ The maximum sequence length that this model's text encoder and decoder might ever be used with. Typically set
61
+ this to something large just in case (e.g., 512 or 1024 or 2048).
62
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
63
+ Whether the model is used as an encoder/decoder or not.
64
+ encoder_layerdrop (`float`, *optional*, defaults to 0.05):
65
+ The LayerDrop probability for the encoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
66
+ for more details.
67
+ decoder_layerdrop (`float`, *optional*, defaults to 0.05):
68
+ The LayerDrop probability for the decoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
69
+ for more details.
70
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
71
+ The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
72
+ `"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
73
+ dropout (`float`, *optional*, defaults to 0.1):
74
+ The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
75
+ attention_dropout (`float`, *optional*, defaults to 0.1):
76
+ The dropout probability for all attention layers.
77
+ activation_dropout (`float`, *optional*, defaults to 0.0):
78
+ The dropout probability for all activation layers in the model.
79
+ scale_embedding (`bool`, *optional*, defaults to `True`):
80
+ Scale embeddings by dividing by sqrt(d_model).
81
+
82
+ > Text encoder and text decoder specific parameters
83
+
84
+ encoder_layers (`int`, *optional*, defaults to 24):
85
+ Number of hidden layers in the Transformer text encoder.
86
+ encoder_ffn_dim (`int`, *optional*, defaults to 8192):
87
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
88
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
89
+ Number of attention heads for each attention layer in the Transformer text encoder.
90
+ decoder_layers (`int`, *optional*, defaults to 24):
91
+ Number of hidden layers in the Transformer text decoder.
92
+ decoder_ffn_dim (`int`, *optional*, defaults to 8192):
93
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder.
94
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
95
+ Number of attention heads for each attention layer in the Transformer text decoder.
96
+ decoder_start_token_id (`int`, *optional*, defaults to 3):
97
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
98
+ applied in the text decoder.
99
+ max_new_tokens (`int`, *optional*, defaults to 256):
100
+ The maximum numbers of text tokens to generate, ignoring the number of tokens in the prompt.
101
+ pad_token_id (`int`, *optional*, defaults to 0):
102
+ The id of the _padding_ text token. Only applied to the text-decoder model.
103
+ bos_token_id (`int`, *optional*, defaults to 2):
104
+ The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model.
105
+ eos_token_id (`int`, *optional*, defaults to 3):
106
+ The id of the _end-of-stream_ text token. Only applied to the text-decoder model.
107
+
108
+ > Speech encoder specific parameters
109
+
110
+ speech_encoder_layers (`int`, *optional*, defaults to 24):
111
+ Number of hidden layers in the Transformer speech encoder.
112
+ speech_encoder_attention_heads (`int`, *optional*, defaults to 16):
113
+ Number of attention heads for each attention layer in the Transformer speech encoder.
114
+ speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096):
115
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder.
116
+ speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
117
+ The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`,
118
+ `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
119
+ speech_encoder_dropout (`float`, *optional*, defaults to 0.0):
120
+ The dropout probability for all layers in the speech encoder.
121
+ add_adapter (`bool`, *optional*, defaults to `True`):
122
+ Add an adapter layer on top of the speech encoder.
123
+ speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1):
124
+ The LayerDrop probability for the speech encoder. See the [LayerDrop paper](
125
+ https://arxiv.org/abs/1909.11556) for more details.
126
+ feature_projection_input_dim (`int`, *optional*, defaults to 160):
127
+ Input dimension of the input feature projection of the speech encoder, i.e. the dimension after processing
128
+ input audios with [`SeamlessM4TFeatureExtractor`].
129
+ adaptor_kernel_size (`int`, *optional*, defaults to 8):
130
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
131
+ adaptor_stride (`int`, *optional*, defaults to 8):
132
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
133
+ adaptor_dropout (`float`, *optional*, defaults to 0.1):
134
+ The dropout probability for all layers in the speech adapter.
135
+ num_adapter_layers (`int`, *optional*, defaults to 1):
136
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
137
+ True`.
138
+ position_embeddings_type (`str`, *optional*, defaults to `"relative_key"`):
139
+ Can be set to `relative_key`. If left as `None`, no relative position embedding is applied. Only
140
+ applied to the speech encoder. For more information on `"relative_key"`, please refer to [Self-Attention
141
+ with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
142
+ conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
143
+ Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder.
144
+ left_max_position_embeddings (`int`, *optional*, defaults to 64):
145
+ The left clipping value for relative positions.
146
+ right_max_position_embeddings (`int`, *optional*, defaults to 8):
147
+ The right clipping value for relative positions.
148
+ speech_encoder_chunk_size (`int`, *optional*, defaults to 20000): The size of each attention chunk.
149
+ speech_encoder_left_chunk_num (`int`, *optional*, defaults to 128):
150
+ Number of chunks on the left up to which lookahead is allowed.
151
+
152
+ > Text-To-Unit (t2u) model specific parameters
153
+
154
+ t2u_bos_token_id (`int`, *optional*, defaults to 0):
155
+ The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
156
+ t2u_pad_token_id (`int`, *optional*, defaults to 1):
157
+ The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model.
158
+ t2u_eos_token_id (`int`, *optional*, defaults to 2):
159
+ The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
160
+ t2u_encoder_layers (`int`, *optional*, defaults to 6):
161
+ Number of hidden layers in the Transformer text-to-unit encoder.
162
+ t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192):
163
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder.
164
+ t2u_encoder_attention_heads (`int`, *optional*, defaults to 16):
165
+ Number of attention heads for each attention layer in the Transformer text-to-unit encoder.
166
+ t2u_decoder_layers (`int`, *optional*, defaults to 6):
167
+ Number of hidden layers in the Transformer text-to-unit decoder.
168
+ t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192):
169
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder.
170
+ t2u_decoder_attention_heads (`int`, *optional*, defaults to 16):
171
+ Number of attention heads for each attention layer in the Transformer text-to-unit decoder.
172
+ t2u_max_position_embeddings (`int`, *optional*, defaults to 4096):
173
+ The maximum sequence length that this model's text-to-unit component might ever be used with. Typically set
174
+ this to something large just in case (e.g., 512 or 1024 or 2048).
175
+ t2u_variance_predictor_embed_dim (`int`, *optional*, defaults to 1024):
176
+ The projection dimension of the text-to-unit's duration predictor.
177
+ t2u_variance_predictor_hidden_dim (`int`, *optional*, defaults to 256):
178
+ Internal dimension of the text-to-unit's duration predictor.
179
+ t2u_variance_predictor_kernel_size (`int`, *optional*, defaults to 3):
180
+ Kernel size of the convolutional layers of the text-to-unit's duration predictor.
181
+ t2u_variance_pred_dropout (`float`, *optional*, defaults to 0.5):
182
+ The dropout probability of the text-to-unit's duration predictor.
183
+
184
+ > Hifi-Gan Vocoder specific parameters
185
+
186
+ sampling_rate (`int`, *optional*, defaults to 16000):
187
+ The sampling rate at which the output audio will be generated, expressed in hertz (Hz).
188
+ upsample_initial_channel (`int`, *optional*, defaults to 512):
189
+ The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only.
190
+ upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`):
191
+ A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network.
192
+ The length of *upsample_rates* defines the number of convolutional layers and has to match the length of
193
+ *upsample_kernel_sizes*. Applies to the vocoder only.
194
+ upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`):
195
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling
196
+ network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match
197
+ the length of *upsample_rates*. Applies to the vocoder only.
198
+ resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
199
+ A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive
200
+ field fusion (MRF) module. Applies to the vocoder only.
201
+ resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
202
+ A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in
203
+ the multi-receptive field fusion (MRF) module. Applies to the vocoder only.
204
+ leaky_relu_slope (`float`, *optional*, defaults to 0.1):
205
+ The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder
206
+ only.
207
+ unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000):
208
+ Vocabulary size of the SeamlessM4Tv2 vocoder. Defines the number of different unit tokens that can be
209
+ represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4Tv2Model`],
210
+ [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
211
+ unit_embed_dim (`int`, *optional*, defaults to 1280):
212
+ The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only.
213
+ lang_embed_dim (`int`, *optional*, defaults to 256):
214
+ The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only.
215
+ spkr_embed_dim (`int`, *optional*, defaults to 256):
216
+ The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only.
217
+ vocoder_num_langs (`int`, *optional*, defaults to 36):
218
+ Number of langs supported by the vocoder. Might be different from `t2u_num_langs`.
219
+ vocoder_num_spkrs (`int`, *optional*, defaults to 200):
220
+ Number of speakers supported by the vocoder.
221
+ variance_predictor_kernel_size (`int`, *optional*, defaults to 3):
222
+ Kernel size of the duration predictor. Applies to the vocoder only.
223
+ var_pred_dropout (`float`, *optional*, defaults to 0.5):
224
+ The dropout probability of the duration predictor. Applies to the vocoder only.
225
+ vocoder_offset (`int`, *optional*, defaults to 4):
226
+ Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only.
227
+
228
+ ```python
229
+ >>> from transformers import SeamlessM4Tv2Model, SeamlessM4Tv2Config
230
+
231
+ >>> # Initializing a SeamlessM4Tv2 "" style configuration
232
+ >>> configuration = SeamlessM4Tv2Config()
233
+
234
+ >>> # Initializing a model from the "" style configuration
235
+ >>> model = SeamlessM4Tv2Model(configuration)
236
+
237
+ >>> # Accessing the model configuration
238
+ >>> configuration = model.config
239
+ ```"""
240
+
241
+ model_type = "seamless_m4t_v2"
242
+
243
+ def __init__(
244
+ self,
245
+ vocab_size=256102,
246
+ t2u_vocab_size=10082,
247
+ char_vocab_size=10943,
248
+ # shared config
249
+ hidden_size=1024,
250
+ initializer_range=0.02,
251
+ layer_norm_eps=1e-5,
252
+ use_cache=True,
253
+ max_position_embeddings=4096,
254
+ is_encoder_decoder=True,
255
+ encoder_layerdrop=0.05,
256
+ decoder_layerdrop=0.05,
257
+ activation_function="relu",
258
+ dropout=0.1,
259
+ attention_dropout=0.1,
260
+ activation_dropout=0.0,
261
+ scale_embedding=True,
262
+ # text encoder|decoder
263
+ encoder_layers=24,
264
+ encoder_ffn_dim=8192,
265
+ encoder_attention_heads=16,
266
+ decoder_layers=24,
267
+ decoder_ffn_dim=8192,
268
+ decoder_attention_heads=16,
269
+ decoder_start_token_id=3,
270
+ max_new_tokens=256,
271
+ pad_token_id=0,
272
+ bos_token_id=2,
273
+ eos_token_id=3,
274
+ # speech_encoder
275
+ speech_encoder_layers=24,
276
+ speech_encoder_attention_heads=16,
277
+ speech_encoder_intermediate_size=4096,
278
+ speech_encoder_hidden_act="swish",
279
+ speech_encoder_dropout=0.0,
280
+ add_adapter=True,
281
+ speech_encoder_layerdrop=0.1,
282
+ feature_projection_input_dim=160,
283
+ adaptor_kernel_size=8,
284
+ adaptor_stride=8,
285
+ adaptor_dropout=0.1,
286
+ num_adapter_layers=1,
287
+ position_embeddings_type="relative_key",
288
+ conv_depthwise_kernel_size=31,
289
+ left_max_position_embeddings=64,
290
+ right_max_position_embeddings=8,
291
+ speech_encoder_chunk_size=20000,
292
+ speech_encoder_left_chunk_num=128,
293
+ # t2u config
294
+ t2u_bos_token_id=0,
295
+ t2u_pad_token_id=1,
296
+ t2u_eos_token_id=2,
297
+ t2u_encoder_layers=6,
298
+ t2u_encoder_ffn_dim=8192,
299
+ t2u_encoder_attention_heads=16,
300
+ t2u_decoder_layers=6,
301
+ t2u_decoder_ffn_dim=8192,
302
+ t2u_decoder_attention_heads=16,
303
+ t2u_max_position_embeddings=4096,
304
+ t2u_variance_predictor_embed_dim=1024,
305
+ t2u_variance_predictor_hidden_dim=256,
306
+ t2u_variance_predictor_kernel_size=3,
307
+ t2u_variance_pred_dropout=0.5,
308
+ # hifi-gan vocoder config
309
+ sampling_rate=16000,
310
+ upsample_initial_channel=512,
311
+ upsample_rates=[5, 4, 4, 2, 2],
312
+ upsample_kernel_sizes=[11, 8, 8, 4, 4],
313
+ resblock_kernel_sizes=[3, 7, 11],
314
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
315
+ leaky_relu_slope=0.1,
316
+ # specific to Code Hifi-Gan
317
+ unit_hifi_gan_vocab_size=10000,
318
+ unit_embed_dim=1280,
319
+ lang_embed_dim=256,
320
+ spkr_embed_dim=256,
321
+ vocoder_num_langs=36,
322
+ vocoder_num_spkrs=200,
323
+ variance_predictor_kernel_size=3,
324
+ var_pred_dropout=0.5,
325
+ vocoder_offset=4,
326
+ **kwargs,
327
+ ):
328
+ # overall_config
329
+ self.vocab_size = vocab_size
330
+ self.t2u_vocab_size = t2u_vocab_size
331
+ self.char_vocab_size = char_vocab_size
332
+ self.hidden_size = hidden_size
333
+ self.initializer_range = initializer_range
334
+ self.layer_norm_eps = layer_norm_eps
335
+ self.max_position_embeddings = max_position_embeddings
336
+ self.use_cache = use_cache
337
+ self.max_new_tokens = max_new_tokens
338
+ self.encoder_layerdrop = encoder_layerdrop
339
+ self.decoder_layerdrop = decoder_layerdrop
340
+ self.activation_function = activation_function
341
+ self.dropout = dropout
342
+ self.attention_dropout = attention_dropout
343
+ self.activation_dropout = activation_dropout
344
+ self.scale_embedding = scale_embedding
345
+ # for proper config init
346
+ self.num_attention_heads = decoder_attention_heads
347
+ self.num_hidden_layers = decoder_layers
348
+
349
+ # text|unit encoder|decoder
350
+ self.encoder_layers = encoder_layers
351
+ self.encoder_ffn_dim = encoder_ffn_dim
352
+ self.encoder_attention_heads = encoder_attention_heads
353
+ self.decoder_layers = decoder_layers
354
+ self.decoder_ffn_dim = decoder_ffn_dim
355
+ self.decoder_attention_heads = decoder_attention_heads
356
+
357
+ # speech_encoder
358
+ self.speech_encoder_layers = speech_encoder_layers
359
+ self.speech_encoder_hidden_act = speech_encoder_hidden_act
360
+ self.speech_encoder_dropout = speech_encoder_dropout
361
+ self.speech_encoder_attention_heads = speech_encoder_attention_heads
362
+ self.speech_encoder_layerdrop = speech_encoder_layerdrop
363
+ self.speech_encoder_intermediate_size = speech_encoder_intermediate_size
364
+ self.feature_projection_input_dim = feature_projection_input_dim
365
+ self.adaptor_kernel_size = adaptor_kernel_size
366
+ self.adaptor_stride = adaptor_stride
367
+ self.adaptor_dropout = adaptor_dropout
368
+ self.num_adapter_layers = num_adapter_layers
369
+ self.position_embeddings_type = position_embeddings_type
370
+ self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
371
+ self.add_adapter = add_adapter
372
+ self.left_max_position_embeddings = left_max_position_embeddings
373
+ self.right_max_position_embeddings = right_max_position_embeddings
374
+ self.speech_encoder_chunk_size = speech_encoder_chunk_size
375
+ self.speech_encoder_left_chunk_num = speech_encoder_left_chunk_num
376
+
377
+ # t2u config
378
+ self.t2u_bos_token_id = t2u_bos_token_id
379
+ self.t2u_pad_token_id = t2u_pad_token_id
380
+ self.t2u_eos_token_id = t2u_eos_token_id
381
+ self.t2u_encoder_layers = t2u_encoder_layers
382
+ self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim
383
+ self.t2u_encoder_attention_heads = t2u_encoder_attention_heads
384
+ self.t2u_decoder_layers = t2u_decoder_layers
385
+ self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim
386
+ self.t2u_decoder_attention_heads = t2u_decoder_attention_heads
387
+ self.t2u_max_position_embeddings = t2u_max_position_embeddings
388
+ self.t2u_variance_predictor_embed_dim = t2u_variance_predictor_embed_dim # TODO: add to docstrings
389
+ self.t2u_variance_predictor_hidden_dim = t2u_variance_predictor_hidden_dim # TODO: add to docstrings
390
+ self.t2u_variance_predictor_kernel_size = t2u_variance_predictor_kernel_size # TODO: add to docstrings
391
+ self.t2u_variance_pred_dropout = t2u_variance_pred_dropout # TODO: add to docstrings
392
+
393
+ # hifi-gan vocoder config
394
+ # original parameters specific to Hifi-Gan
395
+ self.sampling_rate = sampling_rate
396
+ self.upsample_initial_channel = upsample_initial_channel
397
+ self.upsample_rates = upsample_rates
398
+ self.upsample_kernel_sizes = upsample_kernel_sizes
399
+ self.resblock_kernel_sizes = resblock_kernel_sizes
400
+ self.resblock_dilation_sizes = resblock_dilation_sizes
401
+ self.leaky_relu_slope = leaky_relu_slope
402
+
403
+ # specific to Code Hifi-Gan
404
+ self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size
405
+ self.unit_embed_dim = unit_embed_dim
406
+ self.lang_embed_dim = lang_embed_dim
407
+ self.spkr_embed_dim = spkr_embed_dim
408
+ self.vocoder_num_langs = vocoder_num_langs
409
+ self.vocoder_num_spkrs = vocoder_num_spkrs
410
+ self.variance_predictor_kernel_size = variance_predictor_kernel_size
411
+ self.var_pred_dropout = var_pred_dropout
412
+ self.vocoder_offset = vocoder_offset
413
+
414
+ super().__init__(
415
+ pad_token_id=pad_token_id,
416
+ bos_token_id=bos_token_id,
417
+ eos_token_id=eos_token_id,
418
+ decoder_start_token_id=decoder_start_token_id,
419
+ is_encoder_decoder=is_encoder_decoder,
420
+ max_position_embeddings=max_position_embeddings,
421
+ **kwargs,
422
+ )
423
+
424
+
425
+ __all__ = ["SeamlessM4Tv2Config"]
test.sh ADDED
@@ -0,0 +1,6 @@
1
+ CUDA_VISIBLE_DEVICES=0 \
2
+ swift infer \
3
+ --adapters /root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324\
4
+ --stream true \
5
+ --temperature 0 \
6
+ --max_new_tokens 2048
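
Note on this script: as far as the swift CLI is documented, swift infer with --adapters loads the LoRA checkpoint on top of the base model recorded in the adapter's saved training arguments, --stream true prints tokens as they are generated, and --temperature 0 makes decoding greedy, matching the deterministic settings used by test_qwenOmni.py below.
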
test_qwenOmni.py ADDED
@@ -0,0 +1,177 @@
+ import json
+ import logging
+ import os
+ import re
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+ from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig, get_template
+ from transformers import HfArgumentParser
+ from transformers import Qwen2_5OmniProcessor
+ from dataset.dataset2 import AudioDataset
+
+ @dataclass
+ class TestArguments:
+     """
+     Arguments specifying the model, LoRA adapter, test data, and output file for inference.
+     """
+     MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B"  # base model path
+     LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324"  # LoRA checkpoint path
+     DATA_FILE = "/root/ms-swift/silence_overlaps/test"  # test data directory
+     OUTPUT_DIR = "omini_inference_7B_overlap5sVal_SFT_allset.json"  # output file for inference results
+
+     model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
+     lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "lora model dir"})
+     out_file: Optional[str] = field(default=OUTPUT_DIR, metadata={"help": "output file for test"})
+     data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
+     DEVICE: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
+     force: Optional[bool] = field(default=False, metadata={"help": "force test"})
+     batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for processing"})
+
+     def __post_init__(self):
+         if self.model_path is None:
+             raise ValueError("model_path should not be None")
+         if self.data_dir is None:
+             raise ValueError("data_dir should not be None")
+
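
Because TestArguments is parsed through HfArgumentParser (see main below), every field above doubles as a command-line flag, so the default paths can be overridden without editing the file, e.g. python test_qwenOmni.py --batch_size 4 --force true.
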
+ def get_prompt_templates():
+     prompt_template = (
+         "You are an expert at analyzing overlapping speech in conversations. Please analyze the speech dialogue and focus specifically on:\n"
+         "Please summarize if any overlaps exceed the 3-second threshold."
+     )
+     return prompt_template
+
+ def extract_overall_score(output_str):
+     """Extract the integer from an <overall score>X</overall score> tag in the model output."""
+     score_pattern = r"<overall score>(\d+)</overall score>"
+     match = re.search(score_pattern, output_str)
+     if match:
+         try:
+             return int(match.group(1))
+         except ValueError:
+             pass
+     return None
+
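
A quick sanity check of the tag extraction above (the strings are made-up model outputs, purely to exercise the regex):

    assert extract_overall_score("analysis ... <overall score>3</overall score>") == 3
    assert extract_overall_score("no score tag in this output") is None
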
+ def main():
+     parser = HfArgumentParser(TestArguments)
+     data_args = parser.parse_args_into_dataclasses()[0]
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+     logging.info("Starting inference with arguments: %s", data_args)
+
+     if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
+         logging.info(f"The {data_args.out_file} exists. Do not regenerate it.")
+         return
+
+     # Select the GPU device
+     device = torch.device(data_args.DEVICE if torch.cuda.is_available() else "cpu")
+     logging.info(f"Using device: {device}")
+
+     # Initialize the audio processor
+     logging.info("Loading processor...")
+     processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path)
+
+     # Initialize the inference engine
+     logging.info("Initializing inference engine...")
+     engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path])
+     engine.processor = processor
+     template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.")
+     engine.default_template = template
+     template.processor = processor
+     # Initialize the dataset
+     logging.info("Initializing dataset from %s", data_args.data_dir)
+     dataset = AudioDataset(data_args.data_dir)
+     logging.info(f"Dataset loaded successfully with {len(dataset)} samples")
+
+     # Fetch the prompt template (note: not used below; prompts come from the dataset entries)
+     prompt_template = get_prompt_templates()
+
+     all_outputs = []
+     batch_size = data_args.batch_size
+     total_batches = (len(dataset) + batch_size - 1) // batch_size
+     logging.info(f"Starting batch processing with batch size {batch_size}, total batches: {total_batches}")
+
+     for i in range(0, len(dataset), batch_size):
+         current_batch = i // batch_size + 1
+         logging.info(f"Processing batch {current_batch}/{total_batches}")
+
+         batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
+
+         # Process each sample
+         batch_outputs = []
+         for bd in batch_data:
+             # Build the inference request
+             infer_request = InferRequest(
+                 messages=bd["prompt"],
+                 audios=[bd["audio"]]
+             )
+
+             # Configure deterministic decoding
+             request_config = RequestConfig(
+                 max_tokens=512,
+                 temperature=0,
+                 do_sample=False,
+                 num_beams=1
+             )
+
+             # Run inference
+             resp_list = engine.infer([infer_request], request_config)
+             response = resp_list[0].choices[0].message.content
+             batch_outputs.append(response)
+
+         all_outputs.extend(batch_outputs)
+         logging.info(f"Completed batch {current_batch}/{total_batches}")
+
+     final_output = []
+     correct_count = 0
+     total_count = 0
+     true_positive = 0
+     false_positive = 0
+     false_negative = 0
+
+     for input_example, model_output in zip(dataset, all_outputs):
+         pred_score = extract_overall_score(model_output)
+         gt_score = input_example.get("solution", None)
+
+         result = {
+             "id": input_example.get("id", None),
+             "gt_score": gt_score,
+             "model_output": model_output,
+             "predicted_score": pred_score
+         }
+         final_output.append(result)
+
+         # Note: with this counting scheme (a match increments TP, any mismatch
+         # increments both FP and FN), precision and recall both reduce to accuracy.
+         if pred_score is not None and gt_score is not None:
+             total_count += 1
+             if pred_score == gt_score:
+                 correct_count += 1
+                 true_positive += 1
+             else:
+                 false_positive += 1
+                 false_negative += 1
+
+     accuracy = correct_count / total_count if total_count > 0 else 0
+     precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
+     recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
+
+     # Append the aggregate metrics to the final output
+     metrics = {
+         "accuracy": accuracy,
+         "precision": precision,
+         "recall": recall,
+         "correct_count": correct_count,
+         "total_count": total_count
+     }
+     final_output.append({"metrics": metrics})
+
+     logging.info("Saving results to %s", data_args.out_file)
+     with open(data_args.out_file, "w") as f:
+         json.dump(final_output, f, indent=2)
+
+     logging.info("Results saved successfully.")
+     logging.info(f"Accuracy: {accuracy:.4f} ({correct_count}/{total_count})")
+     logging.info(f"Recall: {recall:.4f}")
+     logging.info(f"Precision: {precision:.4f}")
+
+ if __name__ == "__main__":
+     main()
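
A possible refinement, sketched under the assumption that PtEngine.infer accepts a list of requests (the loop above already passes it a single-element list): submitting the whole batch in one call lets the engine batch the forward passes instead of decoding sample by sample.

    # Hypothetical batched variant of the per-sample loop in main():
    requests = [InferRequest(messages=bd["prompt"], audios=[bd["audio"]]) for bd in batch_data]
    resp_list = engine.infer(requests, request_config)
    batch_outputs = [resp.choices[0].message.content for resp in resp_list]
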
train.sh ADDED
@@ -0,0 +1,64 @@
+
+ CUDA_VISIBLE_DEVICES=0 swift sft \
+ --model /root/autodl-tmp/output_7B_FULL_4JOB/v2-20250621-170947/checkpoint-608 \
+ --dataset ./dataset_newCOTSFT1_filtered_90.0s_resampled_16000.jsonl \
+ --model_type qwen2_5_omni \
+ --train_type full \
+ --output_dir /root/autodl-tmp/output_7B_FULL_cotSFT \
+ --torch_dtype bfloat16 \
+ --learning_rate 1e-4 \
+ --num_train_epochs 2 \
+ --freeze_vit false \
+ --freeze_aligner false \
+ --per_device_train_batch_size 1 \
+ --per_device_eval_batch_size 1 \
+ # ...
+
+
+ # CUDA_VISIBLE_DEVICES=0 swift sft \
+ # --model /root/autodl-tmp/Qwen2.5-Omni-7B \
+ # --dataset /root/ms-swift/dataset_cotSFT.json \
+ # --model_type qwen2_5_omni \
+ # --train_type lora \
+ # --output_dir /root/autodl-tmp/output_7B_Lora_cotSFT \
+ # --torch_dtype bfloat16 \
+ # --learning_rate 1e-4 \
+ # --lora_rank 8 \
+ # --lora_alpha 32 \
+ # --target_modules all-linear \
+ # --num_train_epochs 3 \
+ # --freeze_vit false \
+ # --freeze_aligner false \
+ # --per_device_train_batch_size 3 \
+ # --per_device_eval_batch_size 1 \
+ # ...
+ # # 8*A100
+ # NPROC_PER_NODE=8 \
+ # CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ # swift pt \
+ # --model Qwen/Qwen2.5-7B \
+ # --dataset swift/chinese-c4 \
+ # --streaming true \
+ # --train_type full \
+ # --deepspeed zero2 \
+ # --output_dir output \
+ # --max_steps 10000 \
+ # ...
+
+
+
+ # --lora_rank 8 \
+ # --lora_alpha 32 \
+ # --target_modules all-linear \
+ # --gradient_accumulation_steps 16 \
+ # --eval_steps 50 \
+ # --save_steps 50 \
+ # --save_total_limit 2 \
+ # --logging_steps 5 \
+ # --max_length 2048 \
+ # --output_dir output \
+ # --system 'You are a helpful assistant.' \
+ # --warmup_ratio 0.05 \
+ # --dataloader_num_workers 4 \
+ # --model_author swift \
+ # --model_name swift-robot
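
For context, a hedged sketch of how one line of the --dataset JSONL file used above might be produced. The "messages"/"audios" keys follow ms-swift's custom multimodal dataset convention, with <audio> marking where the clip is injected into the prompt; the concrete prompt, label, path, and output filename here are placeholders, not values taken from this repo:

    import json

    # Hypothetical example line in the ms-swift custom-dataset format.
    sample = {
        "messages": [
            {"role": "user", "content": "<audio>Analyze the overlapping speech in this clip."},
            {"role": "assistant", "content": "<overall score>3</overall score>"},
        ],
        "audios": ["/path/to/clip.wav"],
    }
    with open("train_data.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
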